
Deep Learning for

Time Series Prediction &


Decision Making Over Time

Bryan Lim
St Cross College
University of Oxford

A thesis submitted for the degree of


Doctor of Philosophy

Trinity 2020
Acknowledgements

Firstly, I would like to express my immense gratitude to my supervisors Dr. Stefan


Zohren and Prof. Stephen Roberts for their guidance throughout my DPhil. Their
support and encouragement have been invaluable in broadening my research horizons
and has helped me deepen my expertise in the field. I would like to thank everyone
I have worked with at the Oxford-Man Institute – both for the provision of funding
support and for being an integral part of my experience at Oxford. I also thank
Xiaowen Dong, Martin Schmalz and Michael Osborne for their feedback during my
confirmation and transfer of status, which has helped improve the quality of this
thesis.
I am grateful to all the research collaborators with whom I have had the pleasure of
working during my DPhil. In particular, I would like to thank Sercan Arik and Nicolas
Loeff at Google Cloud AI for all the thoughtful discussions on time series forecasting,
which sparked many research ideas and led to the development of the TFT. To
Anthony Ledford and Thomas Flury at Man AHL, thank you for your insights into
systematic trading which were very helpful in the presentation of the DMN. I also
thank Prof. Mihaela van der Schaar and her group, especially Ahmed Alaa, Jinsung
Yoon, James Jordon and Alexis Bellot, for their patience in the early phases of my
studies – their advice and assistance were vital in helping me shape my approach to
research.
Above all, I would like to thank my parents, Goh Soo Cheng and Michael Lim,
and my siblings, Shawn Lim and Michelle Lim, for their infinite support through the
ups and downs of this academic adventure.

Abstract

In this thesis, we develop a collection of state-of-the-art deep learning models for time
series forecasting. Primarily focusing on a closer alignment with traditional meth-
ods in time series modelling, we adopt three main directions of research – 1) novel
architectures, 2) hybrid models, and 3) feature extraction. Firstly, we propose two
new architectures for general one-step-ahead and multi-horizon forecasting. With the
Recurrent Neural Filter (RNF), we take a closer look at the relationship between
recurrent neural networks and Bayesian filtering, so as to improve representation
learning for one-step-ahead forecasts. For multi-horizon forecasting, we propose the
Temporal Fusion Transformer (TFT) – an attention-based model designed to accom-
modate the full range of inputs present in common problem scenarios. Secondly,
we investigate the use of hybrid models to enhance traditional quantitative models
with deep learning components – using domain-specific knowledge to guide neural
network training. Through applications in finance (Deep Momentum Networks)
and medicine (Disease-Atlas), we demonstrate that hybrid models
can effectively improve forecasting performance over pure methods in either category.
Finally, we explore the feature learning capabilities of deep neural networks to devise
features for general forecasting models. Considering an application in systemic risk
management, we devise the Autoencoder Reconstruction Ratio (ARR) – an indicator
to measure the degree of co-movement between asset returns. When fed as an in-
put into a variety of models, we show that the ARR can help to improve short-term
predictions of various risk metrics.
On top of improvements in forecasting performance, we also investigate extensions
to enable decision support using deep neural networks, by helping users to better
understand their data. With Recurrent Marginal Structural Networks (RMSNs), we
introduce a general framework to train deep neural networks to learn causal effects over
time, using ideas from marginal structural modelling in epidemiology. In addition,
we also propose three practical interpretability use-cases for the TFT, demonstrating
how attention weights can be analysed to provide insights into temporal dynamics.

Contents

1 Introduction 1
1.1 Motivations . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 1
1.2 Contributions & Outline of Report . . . . . . . . . . . . . . . . . . . 2
1.3 Publications . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 6

2 Literature Review 8
2.1 Time Series Forecasting With Deep Learning: A Survey . . . . . . . . 9

3 Novel Architectures For Time Series Forecasting 22


3.1 Temporal Fusion Transformers for Interpretable Multi-horizon Time
Series Forecasting . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 23
3.2 Recurrent Neural Filters: Learning Independent Bayesian Filtering
Steps for Time Series Prediction . . . . . . . . . . . . . . . . . . . . . 50

4 Incorporating Domain Knowledge With Hybrid Models 62


4.1 Enhancing Time Series Momentum Using Deep Neural Networks . . . 63
4.2 Disease-Atlas: Navigating Disease Trajectories Using Deep Learning . 84

5 Learning Time-Dependent Causal Effects 106


5.1 Forecasting Treatment Responses Over Time Using Recurrent Marginal
Structural Networks . . . . . . . . . . . . . . . . . . . . . . . . . . . 107

6 Extracting Useful Features From High Frequency Data 124


6.1 Detecting Changes in Asset Co-Movement Using the Autoencoder Re-
construction Ratio . . . . . . . . . . . . . . . . . . . . . . . . . . . . 125

7 Discussion & Conclusions 134


7.1 Extensions and Future Work . . . . . . . . . . . . . . . . . . . . . . . 136

Bibliography 138

Chapter 1

Introduction

1.1 Motivations
Time series analysis is an integral part of many disciplines – providing researchers
with the analytical tools required to study temporal dynamics. Ranging from simple
autoregressive models regularly used in econometrics [4], to more sophisticated
particle filtering from signal processing [2], numerous methods have been developed to
make predictions of and reason about time-varying behavior. In addition to achieving
high degrees of predictive accuracy, these traditional methods also provide a degree
of insight into the relationships of the underlying models, given the generative or
parametric nature of the models used. For instance, analyzing p-values of coefficients
in vector autoregressive models allows researchers to identify significant inputs for
the prediction problem, and state space models allow for the inference of otherwise
unobservable latent states.
With the rapid growth in data availability over time, machine learning methods
have increasingly been used for time series prediction tasks in applications such as
retail demand forecasting [6], healthcare [15], and finance [8]. Favoured for their
ability to learn relationships in a purely data driven manner, time series forecasting
with deep learning has risen to prominence in recent times – fueled by phenomenal
successes in sequence modelling [28, 29] and reinforcement learning [27] tasks in other
domains. However, new architectures have largely developed independently from the
field of time series analysis, focusing on alleviating specific limitations in the context of
sequence prediction – e.g. difficulties in learning long-term dependencies in Recurrent
Neural Networks (RNNs) [11]. As such, improving the alignment of neural network
architectures with traditional time series models could help to improve forecasting
accuracy – potentially by taking into account nuances in the dataset, or by improving
representation learning [1].

Furthermore, while researchers primarily concentrate on prediction accuracy when
developing models, practitioners rely on model forecasts more to inform their decision
making and optimise their actions over time. In retail applications for instance, stores
can use sales forecasting models to determine product demand, which allows them
to optimise inventory management [24]. As such, providing components which help
users to understand the temporal relationships within their time series datasets can
also be beneficial for decision making. Moreover, with more deep learning systems
featuring in mission-critical applications, such as patient diagnosis [15] and portfolio
management [8], understanding the key drivers of a model’s prediction can improve
the user’s trust in its forecast. Given the black-box nature of deep neural networks,
this motivates the need for the development of components which provide insights
into the relationships learnt by the model.

1.2 Contributions & Outline of Report


1.2.1 Research Directions
This thesis focuses on the development of state-of-the-art deep learning models
for time series forecasting, along with extensions to use deep neural networks to
facilitate decision support. We divide our proposed improvements into the two main
research areas below – with a full outline of research trends in the literature review
of Chapter 2.

Time Series Forecasting: We adopt three main approaches to utilise deep learn-
ing in high-performance time series forecasting applications. For a start, we develop
novel deep learning architectures to improve representation learning in one-step-ahead
and multi-horizon time series forecasting. Next, we demonstrate the use of hybrid
models in a variety of applications, using deep learning components to enhance well-
studied quantitative models for a given domain. Lastly, we assess the use of deep
neural networks as a feature extraction mechanism for high-frequency data – using
autoencoders to extract useful features to feed into other forecasting models.

Facilitating Decision Making Over Time: In addition to improvements in fore-


casting, we also explore extensions to improve the decision support capabilities of deep
neural networks in several research directions. Firstly, we investigate the use of deep
neural networks to learn time-varying causal effects, training them with biased ob-
servational data typically collected in many industrial settings. With causal deep

learning models, decision makers can then perform scenario analyses through coun-
terfactual simulations, allowing them to optimise their decisions in scenarios that are
historically uncommon or not present in the data. Secondly, we examine methods that
incorporate explainability into high-performance deep learning architectures, which
allow users to better understand general relationships present within their datasets.
Finally, while we do not explicitly propose new methods for probabilistic modelling,
we design most of our forecasting models to incorporate measures of uncertainty as
well. These uncertainty estimates allow users to determine how confident a model is
in its own forecast, so as to enhance the user’s trust in a model’s outputs.

1.2.2 Thesis Contributions & Outline


According to the research directions outlined in the previous section, we group meth-
ods into chapters based on their main area of contribution:

Novel Architectures: In Chapter 3, we propose two new state-of-the-art deep


learning architectures for time series forecasting. With Recurrent Neural Filters
(RNF) [18], we explore the relationships between recurrent neural networks (RNNs)
and Bayesian filtering – identifying that RNNs can be viewed as a simultaneous ap-
proximation of both state transition and update steps. As such, we propose a novel
RNN architecture that closely aligns with the Bayesian filtering steps, using separate
encoders to capture distinct representations for each step. From a network calibration
standpoint, we improve representation learning by training the RNF in a multitask
fashion – with each encoder trained directly using a common emissions decoder.
In addition, our proposed skip training approach acts as a form of input
dropout, and improves generalisation by allowing encoders to operate independently
from each other. From a filtering point of view, the RNF encoders can be separated
at run-time based on the availability of inputs, allowing them to be used in forecasting
with missing data or in an autoregressive fashion for multi-horizon prediction. Con-
sidering applications in finance and electricity load forecasting, we demonstrate the
performance improvements over other recurrent models for both one-step-ahead and
multi-horizon forecasts, while providing comparable uncertainty estimates.
Despite their simplicity, autoregressive models for multi-horizon prediction typi-
cally assume that inputs are homogeneously available across time. However, many
practical applications of multi-horizon forecasting typically contain a complex mix
of inputs – including static covariates, known future inputs and other time series

that can only be observed in the past. To account for the full range of inputs avail-
able, we propose a new multi-horizon forecasting model we call the Temporal Fusion
Transformer (TFT) [14], which combines state-of-the-art performance with inherent
interpretability. To capture temporal representations at different scales, the TFT uses
a sequence-to-sequence layer for local processing and interpretable attention blocks to
learn long-term dependencies. On top of temporal self-attention layers that attend to
important time steps in the lookback window, the TFT also facilitates interpretability
by using variable selection networks to select relevant inputs for the prediction task.
Using a range of real-world datasets, we demonstrate significant performance im-
provements over other multi-horizon forecasting models, and present three practical
interpretability use-cases for the TFT.

Hybrid Models: While the flexibility of deep neural networks allows them to learn
complex relationships in a data-driven fashion, they also can face issues with overfit-
ting in noisy time series datasets [19] – particularly when datasets are small [22]. To
alleviate this, a recent trend in deep learning research has emerged in the development
of hybrid models, which parameterise well-studied quantitative models for a given do-
main using neural network components. Hybrid models allow for model builders to
inject prior knowledge into neural network designs, guiding network calibration by
restricting the form of the output learnt by the model. In Chapter 4, we introduce two
new hybrid models which are customised for domain-specific applications. With Deep
Momentum Networks (DMNs) [16], we consider an application in finance – adopt-
ing the use of hybrid models for systematic trading. DMNs extend common trend
following signals using deep learning, using neural network components to generate
trading rules which are combined with the volatility scaling framework of time series
momentum. We also consider a medical application with the Disease-Atlas [15], which
utilises a hybrid model for co-morbidity management in patients with Cystic Fibrosis.
The Disease-Atlas extends joint models for longitudinal and time-to-event data com-
monly used in biostatistics – using a shared network architecture to learn associations
between longitudinal variables and generate parameters for predictive distributions
of each variable. This allows medical professionals to jointly forecast multiple clinical
outcomes, including the survival and disease onset probabilities along with expected
biomarker values. For both DMNs and Disease-Atlas, we demonstrate performance
improvements over pure quantitative or deep learning benchmarks – highlighting the
benefits of the hybrid modelling approach.
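The volatility-scaling step that DMNs build on can be sketched concisely. In the snippet below, the function name, the EWMA volatility estimator, the 60-day span and the 15% annualised target are illustrative assumptions rather than the thesis's exact specification; in a DMN the positions themselves would be produced by the neural network components, whereas here they are supplied directly.

```python
import numpy as np

def vol_scaled_returns(positions, returns, sigma_tgt=0.15, span=60):
    """Sketch of the volatility-scaling step used in time series momentum.

    positions : array of trading signals in [-1, 1] (in a DMN, the network
                outputs; supplied directly here).
    returns   : array of daily asset returns, aligned with positions.
    """
    returns = np.asarray(returns, dtype=float)
    positions = np.asarray(positions, dtype=float)
    # Ex-ante volatility: exponentially weighted variance of past returns,
    # annualised over 252 trading days, using only information up to t-1.
    sq = np.square(returns)
    ewma_var = np.zeros_like(returns)
    alpha = 2.0 / (span + 1.0)
    for t in range(1, len(returns)):
        ewma_var[t] = (1 - alpha) * ewma_var[t - 1] + alpha * sq[t - 1]
    sigma_t = np.sqrt(ewma_var * 252) + 1e-8
    # Scale each position so the strategy targets sigma_tgt annualised vol.
    return positions * (sigma_tgt / sigma_t) * returns
```

A flat (zero) position produces zero scaled returns regardless of the volatility estimate, which makes the scaling step easy to sanity-check in isolation.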

Causal Inference Over Time: In line with new methods to facilitate decision sup-
port, Chapter 5 investigates the use of deep neural networks to learn time-dependent
causal effects. Motivated by increasing prevalence of electronic health records in
hospitals, we focus on the problem of forecasting time-dependent treatment effects
in medical settings – although similar methods can be used in other domains. The
key challenge with observational data is the presence of time-dependent confounders,
which can introduce bias when no adjustments are made to neural network training.
To account for time-dependent confounding effects, we developed a class of models
known as Recurrent Marginal Structural Networks (RMSNs) [13], which extend
marginal structural models (MSMs) in epidemiology for neural network training.
RMSNs adopt an inverse probability of treatment weighting (IPTW) approach to
adjust for time-dependent confounding – using a set of networks to estimate proba-
bilities of treatment application and censoring at each time step. Propensity weights
are computed based on these probabilities, which then adjust the loss function of a
separate sequence-to-sequence model that predicts treatment effects. Using a clinically
realistic simulation model for non-small cell lung cancer, we demonstrate that
RMSNs improve on the performance of both standard MSMs and other machine learning
baselines in learning unbiased treatment effects.
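The IPTW adjustment described above reduces, at its core, to a cumulative product of per-step likelihood ratios. The sketch below is a hypothetical NumPy illustration that assumes the propensity networks have already produced the required per-step probabilities; the stabilized (numerator/denominator) form and the percentile truncation are common MSM practice, not necessarily the thesis's exact procedure.

```python
import numpy as np

def stabilized_iptw(p_num, p_den):
    """Sketch of stabilized inverse-probability-of-treatment weights.

    p_num : (T,) probability of the observed treatment at each step given
            treatment history only (stabilizing numerator model).
    p_den : (T,) probability of the observed treatment given treatment
            history AND time-dependent covariates (denominator model).
    In an RMSN these would come from the propensity networks.
    """
    p_num = np.asarray(p_num, dtype=float)
    p_den = np.asarray(p_den, dtype=float)
    # Cumulative product over time of the per-step likelihood ratios.
    w = np.cumprod(p_num / p_den)
    # Weights are commonly truncated at extreme percentiles for stability.
    lo, hi = np.percentile(w, [1, 99])
    return np.clip(w, lo, hi)
```

When covariates carry no information about treatment assignment (numerator equals denominator), every weight is 1 and the weighted loss reduces to the ordinary one.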

Feature Extraction from High Frequency Data: Apart from being used di-
rectly for prediction tasks, the representation learning capabilities of deep neural
networks have also been re-purposed for high-quality feature extraction across a vari-
ety of tasks [7, 30, 23]. Considering a financial use-case in systemic risk prediction, we
explore the use of autoencoders to extract useful features for long-term predictions
in Chapter 6. In the spirit of traditional techniques that compute real-time esti-
mates from high-frequency data – such as HEAVY models [25, 21, 26] which estimate
volatility using intra-day returns – we present the Autoencoder Reconstruction Ratio
(ARR) [17] as an early-warning indicator of changes in asset correlations. Using a deep
sparse denoising autoencoder for dimensionality reduction, the ARR is computed by
aggregating the reconstruction errors of high-frequency returns over various horizons.
Through tests on intra-day index returns, we demonstrate that autoencoders do pro-
vide a better model for high-frequency data – reducing reconstruction errors when
compared to simpler methods based on principal component analysis. Furthermore,
we show that the ARR can enhance the accuracy of short-term predictions of volatility
and market crashes when included as an additional input into forecasting models.
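As a loose illustration of the aggregation step, the sketch below turns a trained autoencoder's outputs into a rolling reconstruction-error ratio. Only the idea of aggregating reconstruction errors of high-frequency returns over a horizon comes from the text; the normalisation by squared return magnitudes and the rolling-window form are illustrative assumptions.

```python
import numpy as np

def reconstruction_ratio(returns, encode, decode, window):
    """Loose sketch of an autoencoder reconstruction-error indicator.

    returns : (T, N) high-frequency returns for N assets.
    encode/decode : trained autoencoder maps (stand-ins for the deep
        sparse denoising autoencoder described in the thesis).
    window : number of periods over which errors are aggregated.
    """
    recon = decode(encode(returns))
    err = np.sum((returns - recon) ** 2, axis=1)    # per-period error
    scale = np.sum(returns ** 2, axis=1) + 1e-12    # per-period magnitude
    # Rolling aggregation of errors relative to total return magnitude.
    ratio = (np.convolve(err, np.ones(window), "valid")
             / np.convolve(scale, np.ones(window), "valid"))
    return ratio
```

A perfect reconstruction yields a ratio of zero; rising values would signal that the low-dimensional factor structure captures asset co-movement less well.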

1.2.2.1 Presentation Format

This thesis is presented in the integrated thesis format 1 , with the main chapters
consisting of first-author papers that have been accepted or are under consideration
for publication. This includes the literature review, which has been submitted as a
survey paper. To ensure compliance with any copyright restrictions, the papers have
been reproduced in the original format used in their pre-print versions.

1.3 Publications
At the time of writing, the papers presented in this thesis have been accepted for
publication at the venues below:

Ch. 2: Literature Review


• B. Lim, S. Zohren. Time Series Forecasting With Deep Learning: A Survey.
Philosophical Transactions of the Royal Society of London A, 2020. Weblink:
https://arxiv.org/abs/2004.13408.

Ch. 3: Novel Architectures for Time Series Forecasting


• B. Lim, S. O. Arik, N. Loeff, T. Pfister. Temporal Fusion Transformers for In-
terpretable Multi-horizon Time Series Forecasting. Submitted, 2020. Weblink:
https://arxiv.org/abs/1912.09363.

• B. Lim, S. Zohren, S. Roberts. Recurrent Neural Filters: Learning Indepen-


dent Bayesian Filtering Steps for Time Series Prediction. Proceedings of the
International Joint Conference On Neural Networks (IJCNN), 2020. Weblink:
https://arxiv.org/abs/1901.08096.

Ch. 4: Incorporating Domain Knowledge With Hybrid Mod-


els
• B. Lim, S. Zohren, S. Roberts. Enhancing Time Series Momentum Strategies
Using Deep Neural Networks. Journal of Financial Data Science (JFDS), 2019.
Weblink: https://arxiv.org/abs/1904.04912.
1
https://www.mpls.ox.ac.uk/graduate-school/information-for-postgraduate-research-students/submitting-your-thesis

• B. Lim, M. van der Schaar. Disease-Atlas: Navigating Disease Trajectories
with Deep Learning. Proceedings of the Machine Learning for Healthcare Con-
ference (MLHC), 2018. Weblink: https://arxiv.org/abs/1803.10254.

The Disease-Atlas has also been presented at the following workshops:

• B. Lim, M. van der Schaar. Disease-Atlas: Navigating Disease Trajectories


with Deep Learning. IJCAI International Workshop on Biomedical informatics
with Optimization and Machine learning (IJCAI-BOOM), 2018. [Best Paper
Award]

• B. Lim, M. van der Schaar. Forecasting Disease Trajectories in Alzheimer’s


Disease Using Deep Learning. KDD Workshop on Machine Learning for Medicine
and Healthcare, 2018.

• B. Lim, T. Daniels, A. Floto, M. van der Schaar. Forecasting Clinical Trajec-


tories in Cystic Fibrosis using Deep Learning. North American Cystic Fibrosis
Conference, 2018.

Ch. 5: Learning Time-Dependent Causal Effects


• B. Lim, A. Alaa, M. van der Schaar. Forecasting Treatment Responses Over
Time Using Recurrent Marginal Structural Networks. Advances in Neural Information
Processing Systems (NeurIPS), 2018. Weblink: http://papers.nips.cc/paper/7977-forecasting-treatment-responses-over-time-using-recurrent-marginal-structural-networks.

Ch. 6: Extracting Useful Features From High Frequency Data


• B. Lim, S. Zohren, S. Roberts. Detecting Changes in Asset Co-Movement
Using the Autoencoder Reconstruction Ratio. Risk, 2020. Weblink: https://arxiv.org/abs/2002.02008.

Chapter 2

Literature Review

Publications Included
• B. Lim, S. Zohren. Time Series Forecasting With Deep Learning: A Survey.
Philosophical Transactions of the Royal Society A, 2020. Weblink: https://arxiv.org/abs/2004.13408.

Time Series Forecasting With Deep Learning: A Survey

Bryan Lim1 and Stefan Zohren1
1 Department of Engineering Science, University of Oxford, Oxford, UK

rsta.royalsocietypublishing.org
Research. Article submitted to journal.
Subject Areas: Deep learning, time series modelling
Keywords: Deep neural networks, time series forecasting, uncertainty estimation, hybrid models, interpretability, counterfactual prediction
Author for correspondence: Bryan Lim, e-mail: [email protected]

Abstract: Numerous deep learning architectures have been developed to accommodate the diversity of time series datasets across different domains. In this article, we survey common encoder and decoder designs used in both one-step-ahead and multi-horizon time series forecasting – describing how temporal information is incorporated into predictions by each model. Next, we highlight recent developments in hybrid deep learning models, which combine well-studied statistical models with neural network components to improve pure methods in either category. Lastly, we outline some ways in which deep learning can also facilitate decision support with time series data.

1. Introduction

Time series modelling has historically been a key area of academic research – forming an integral part of applications in topics such as climate modelling [1], biological sciences [2] and medicine [3], as well as commercial decision making in retail [4] and finance [5] to name a few. While traditional methods have focused on parametric models informed by domain expertise – such as autoregressive (AR) [6], exponential smoothing [7, 8] or structural time series models [9] – modern machine learning methods provide a means to learn temporal dynamics in a purely data-driven manner [10]. With the increasing data availability and computing power in recent times, machine learning has become a vital part of the next generation of time series forecasting models.

Deep learning in particular has gained popularity in recent times, inspired by notable achievements in image classification [11], natural language processing [12] and reinforcement learning [13]. By incorporating bespoke architectural assumptions – or inductive biases [14] – that reflect the nuances of underlying datasets, deep neural networks are able to learn complex data representations [15], which alleviates the need for manual feature engineering and model design. The availability of open-source backpropagation frameworks [16, 17] has also simplified network training, allowing for the customisation of network components and loss functions.

© The Authors. Published by the Royal Society under the terms of the Creative Commons Attribution License http://creativecommons.org/licenses/by/4.0/, which permits unrestricted use, provided the original author and source are credited.
Given the diversity of time-series problems across various domains, numerous neural network
design choices have emerged. In this article, we summarise the common approaches to time
series prediction using deep neural networks. Firstly, we describe the state-of-the-art techniques
available for common forecasting problems – such as multi-horizon forecasting and uncertainty
estimation. Secondly, we analyse the emergence of a new trend in hybrid models, which combine
both domain-specific quantitative models with deep learning components to improve forecasting
performance. Next, we outline two key approaches in which neural networks can be used to
facilitate decision support, specifically through methods in interpretability and counterfactual
prediction. Finally, we conclude with some promising future research directions in deep learning
for time series prediction – specifically in the form of continuous-time and hierarchical models.
While we endeavour to provide a comprehensive overview of modern methods in deep learning,
we note that our survey is by no means all-encompassing. Indeed, a rich body of literature exists for
automated approaches to time series forecasting - including automatic parametric model selection
[18], and traditional machine learning methods such as kernel regression [19] and support vector
regression [20]. In addition, Gaussian processes [21] have been extensively used for time series
prediction – with recent extensions including deep Gaussian processes [22], and parallels in deep
learning via neural processes [23]. Furthermore, older models of neural networks have been used
historically in time series applications, as seen in [24] and [25].

2. Deep Learning Architectures for Time Series Forecasting


Time series forecasting models predict future values of a target $y_{i,t}$ for a given entity $i$ at time $t$.
Each entity represents a logical grouping of temporal information – such as measurements from
individual weather stations in climatology, or vital signs from different patients in medicine – and
can be observed at the same time. In the simplest case, one-step-ahead forecasting models take the
form:

$$\hat{y}_{i,t+1} = f(y_{i,t-k:t}, \, x_{i,t-k:t}, \, s_i), \quad (2.1)$$

where $\hat{y}_{i,t+1}$ is the model forecast, $y_{i,t-k:t} = \{y_{i,t-k}, \dots, y_{i,t}\}$ and $x_{i,t-k:t} = \{x_{i,t-k}, \dots, x_{i,t}\}$ are
observations of the target and exogenous inputs respectively over a look-back window $k$, $s_i$ is
static metadata associated with the entity (e.g. sensor location), and $f(\cdot)$ is the prediction function
learnt by the model. While we focus on univariate forecasting in this survey (i.e. 1-D targets), we
note that the same components can be extended to multivariate models without loss of generality
[26, 27, 28, 29, 30]. For notational simplicity, we omit the entity index $i$ in subsequent sections
unless explicitly required.

(a) Basic Building Blocks


Deep neural networks learn predictive relationships by using a series of non-linear layers to
construct intermediate feature representations [15]. In time series settings, this can be viewed as
encoding relevant historical information into a latent variable $z_t$, with the final forecast produced
using $z_t$ alone:

$$f(y_{t-k:t}, \, x_{t-k:t}, \, s) = g_{\mathrm{dec}}(z_t), \quad (2.2)$$
$$z_t = g_{\mathrm{enc}}(y_{t-k:t}, \, x_{t-k:t}, \, s), \quad (2.3)$$

where $g_{\mathrm{enc}}(\cdot)$, $g_{\mathrm{dec}}(\cdot)$ are encoder and decoder functions respectively, recalling that the
subscript $i$ from Equation (2.1) has been removed to simplify notation (e.g. $y_{i,t}$ replaced by $y_t$). These
encoders and decoders hence form the basic building blocks of deep learning architectures, with
the choice of network determining the types of relationships that can be learnt by our model. In
this section, we examine modern design choices for encoders, as overviewed in Figure 1, and their
relationship to traditional temporal models. In addition, we explore common network outputs and
loss functions used in time series forecasting applications.
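The decomposition in Equations (2.2)–(2.3) can be written down directly. The sketch below is generic: `g_enc` and `g_dec` stand in for any of the encoder and decoder choices discussed in this section, and the linear toy instantiation at the end is purely illustrative (its shapes and weights are assumptions, not a model from the survey).

```python
import numpy as np

def forecast(y_hist, x_hist, s, g_enc, g_dec):
    """Generic encoder/decoder decomposition of Equations (2.2)-(2.3).

    y_hist : (k+1,) past targets; x_hist : (k+1, m) past exogenous inputs;
    s : (p,) static metadata. g_enc/g_dec are placeholders for any of the
    encoder and decoder designs surveyed here.
    """
    z_t = g_enc(y_hist, x_hist, s)   # latent summary of history (Eq 2.3)
    return g_dec(z_t)                # one-step-ahead forecast (Eq 2.2)

# Toy instantiation: a tanh encoder over the concatenated inputs,
# with a linear decoder producing a scalar forecast.
rng = np.random.default_rng(0)
k, m, p, H = 5, 3, 2, 4
W = rng.normal(size=(H, (k + 1) * (1 + m) + p))
v = rng.normal(size=H)
g_enc = lambda y, x, s: np.tanh(W @ np.concatenate([y, x.ravel(), s]))
g_dec = lambda z: float(v @ z)
y_hat = forecast(rng.normal(size=k + 1), rng.normal(size=(k + 1, m)),
                 rng.normal(size=p), g_enc, g_dec)
```

Swapping in a CNN, RNN, or attention-based encoder changes only `g_enc`; the overall prediction interface stays the same.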
Figure 1: Incorporating temporal information using different encoder architectures – (a) CNN model, (b) RNN model, (c) attention-based model.

(i) Convolutional Neural Networks


Traditionally designed for image datasets, convolutional neural networks (CNNs) extract local
relationships that are invariant across spatial dimensions [11, 31]. To adapt CNNs to time series
datasets, researchers utilise multiple layers of causal convolutions [32, 33, 34] – i.e. convolutional
filters designed to ensure only past information is used for forecasting. For an intermediate feature
at hidden layer $l$, each causal convolutional filter takes the form below:

$$h_t^{l+1} = A\big((W \ast h)(l, t)\big), \quad (2.4)$$
$$(W \ast h)(l, t) = \sum_{\tau=0}^{k} W(l, \tau)\, h_{t-\tau}^{l}, \quad (2.5)$$

where $h_t^l \in \mathbb{R}^{H_{in}}$ is an intermediate state at layer $l$ at time $t$, $\ast$ is the convolution operator, $W(l, \tau) \in \mathbb{R}^{H_{out} \times H_{in}}$ is a fixed filter weight at layer $l$, and $A(\cdot)$ is an activation function, such as a sigmoid
function, representing any architecture-specific non-linear processing. For CNNs that use a total of
$L$ convolutional layers, we note that the encoder output is then $z_t = h_t^L$.
Considering the 1-D case, we can see that Equation (2.5) bears a strong resemblance to finite
impulse response (FIR) filters in digital signal processing [35]. This leads to two key implications
for temporal relationships learnt by CNNs. Firstly, in line with the spatial invariance assumptions
for standard CNNs, temporal CNNs assume that relationships are time-invariant – using the same
set of filter weights at each time step and across all time. In addition, CNNs are only able to use
inputs within their defined lookback window, or receptive field, to make forecasts. As such, the
receptive field size $k$ needs to be tuned carefully to ensure that the model can make use of all
relevant historical information. It is worth noting that a single causal CNN layer with a linear
activation function is equivalent to an auto-regressive (AR) model.
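A direct NumPy transcription of Equation (2.5) for the single-channel (1-D) case makes the causal structure explicit. The function below is an illustrative sketch rather than an efficient implementation, and it omits the activation $A(\cdot)$ so that the AR equivalence noted above is visible.

```python
import numpy as np

def causal_conv(h, W):
    """1-D causal convolution (Equation (2.5)), single channel.

    h : (T,) input sequence; W : (k+1,) filter weights.
    The output at time t depends only on h[t-k], ..., h[t].
    """
    T, k = len(h), len(W) - 1
    out = np.zeros(T)
    for t in range(T):
        for tau in range(k + 1):
            if t - tau >= 0:            # only past/present taps contribute
                out[t] += W[tau] * h[t - tau]
    return out

# With a linear activation, one causal layer is an AR(k) model:
# y_hat_t = sum_tau W[tau] * y_{t-tau}.
```

Perturbing a future input leaves all earlier outputs unchanged, which is exactly the causality property the masked filter is designed to enforce.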

Dilated Convolutions Using standard convolutional layers can be computationally challenging


where long-term dependencies are significant, as the number of parameters scales directly with the
size of the receptive field. To alleviate this, modern architectures frequently make use of dilated
covolutional layers [32, 33], which extend Equation (2.5) as below:

bk/dl c
X
(W ∗ h) (l, t, dl ) = W (l, τ )hlt−dl τ , (2.6)
τ =0

where b.c is the floor operator and dl is a layer-specific dilation rate. Dilated convolutions can hence
be interpreted as convolutions of a down-sampled version of the lower layer features – reducing
resolution to incorporate information from the distant past. As such, by increasing the dilation rate
with each layer, dilated convolutions can gradually aggregate information at different time blocks,
allowing for more history to be used in an efficient manner. With the WaveNet architecture of [32]
for instance, dilation rates are increased in powers of 2, with adjacent time blocks aggregated in
each layer – allowing for 2^l time steps to be used at layer l, as shown in Figure 1a.
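A minimal sketch of the dilated convolution in Equation (2.6), together with the receptive-field arithmetic for a WaveNet-style stack; the function names and toy values are illustrative assumptions:

```python
import numpy as np

def dilated_causal_conv(h, W, d):
    """Dilated causal convolution (Equation (2.6)): taps the input every d steps."""
    out = np.zeros(len(h))
    for t in range(len(h)):
        for tau, w in enumerate(W):
            if t - d * tau >= 0:
                out[t] += w * h[t - d * tau]
    return out

def receptive_field(num_layers, kernel_size=2):
    """Receptive field of a stack whose dilation rate doubles at each layer."""
    return (kernel_size - 1) * sum(2 ** l for l in range(num_layers)) + 1

out = dilated_causal_conv(np.ones(5), [1.0, 1.0], d=2)
rf = receptive_field(4)  # four width-2 layers cover 16 time steps
```

Doubling the dilation rate each layer is what lets the receptive field grow exponentially with depth, rather than linearly as in the standard causal stack.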
rsta.royalsocietypublishing.org Phil. Trans. R. Soc. A 0000000

(ii) Recurrent Neural Networks

Recurrent neural networks (RNNs) have historically been used in sequence modelling [31],
with strong results on a variety of natural language processing tasks [36]. Given the natural
interpretation of time series data as sequences of inputs and targets, many RNN-based architectures
have been developed for temporal forecasting applications [37, 38, 39, 40]. At their core, RNN cells
contain an internal memory state which acts as a compact summary of past information. The
memory state is recursively updated with new observations at each time step, as shown in Figure
1b, i.e.:

z_t = \nu(z_{t-1}, y_t, x_t, s),    (2.7)

where z_t ∈ R^H here is the hidden internal state of the RNN, and \nu(·) is the learnt memory update
function. For instance, the Elman RNN [41], one of the simplest RNN variants, would take the
form below:

y_{t+1} = \gamma_y(W_y z_t + b_y),    (2.8)

z_t = \gamma_z(W_{z_1} z_{t-1} + W_{z_2} y_t + W_{z_3} x_t + W_{z_4} s + b_z),    (2.9)

where W_{(·)}, b_{(·)} are the linear weights and biases of the network respectively, and \gamma_y(·), \gamma_z(·) are
network activation functions. Note that RNNs do not require the explicit specification of a lookback
window as per the CNN case. From a signal processing perspective, the main recurrent layer – i.e.
Equation (2.9) – thus resembles a non-linear version of infinite impulse response (IIR) filters.
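The memory update of Equation (2.9) amounts to a single matrix recursion per time step. The sketch below uses randomly initialised weights as stand-ins for learnt parameters, with all names and dimensions chosen purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
H, Y, X, S = 4, 1, 3, 2  # hidden, target, observed-input and static dimensions (toy sizes)

# Randomly initialised weights standing in for learnt parameters.
Wz1, Wz2 = rng.normal(size=(H, H)), rng.normal(size=(H, Y))
Wz3, Wz4 = rng.normal(size=(H, X)), rng.normal(size=(H, S))
bz = np.zeros(H)

def elman_step(z_prev, y, x, s):
    """One memory update of Equation (2.9), with tanh as the activation."""
    return np.tanh(Wz1 @ z_prev + Wz2 @ y + Wz3 @ x + Wz4 @ s + bz)

z = np.zeros(H)
for t in range(5):  # the same weights are reused at every time step
    z = elman_step(z, rng.normal(size=Y), rng.normal(size=X), rng.normal(size=S))
```

Note that, unlike the CNN sketch, no lookback window appears anywhere: all history is compressed into the fixed-size state z.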

Long Short-term Memory Due to the infinite lookback window, older variants of RNNs can
suffer from limitations in learning long-range dependencies in the data [42, 43] – due to issues with
exploding and vanishing gradients [31]. Intuitively, this can be seen as a form of resonance in the
memory state. Long Short-Term Memory networks (LSTMs) [44] were hence developed to address
these limitations, by improving gradient flow within the network. This is achieved through the use
of a cell state ct which stores long-term information, modulated through a series of gates as below:

Input gate:  i_t = \sigma(W_{i_1} z_{t-1} + W_{i_2} y_t + W_{i_3} x_t + W_{i_4} s + b_i),    (2.10)

Output gate: o_t = \sigma(W_{o_1} z_{t-1} + W_{o_2} y_t + W_{o_3} x_t + W_{o_4} s + b_o),    (2.11)

Forget gate: f_t = \sigma(W_{f_1} z_{t-1} + W_{f_2} y_t + W_{f_3} x_t + W_{f_4} s + b_f),    (2.12)

where z_{t-1} is the hidden state of the LSTM, and \sigma(·) is the sigmoid activation function. The gates
modify the hidden and cell states of the LSTM as below:

Hidden state: z_t = o_t ⊙ tanh(c_t),    (2.13)

Cell state:   c_t = f_t ⊙ c_{t-1} + i_t ⊙ tanh(W_{c_1} z_{t-1} + W_{c_2} y_t + W_{c_3} x_t + W_{c_4} s + b_c),    (2.14)

where ⊙ is the element-wise (Hadamard) product, and tanh(·) is the tanh activation function.
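Equations (2.10)-(2.14) translate into one gated update per time step. For brevity, the sketch below applies each gate's weight matrix to a single concatenated input v = [z_{t-1}, y_t, x_t, s] rather than keeping four separate matrices per gate; this simplification, and all toy dimensions, are assumptions of the sketch:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_step(z_prev, c_prev, v, params):
    """One LSTM update (Equations (2.10)-(2.14)); v concatenates [z_prev, y_t, x_t, s]."""
    Wi, Wo, Wf, Wc, bi, bo, bf, bc = params
    i = sigmoid(Wi @ v + bi)  # input gate (2.10)
    o = sigmoid(Wo @ v + bo)  # output gate (2.11)
    f = sigmoid(Wf @ v + bf)  # forget gate (2.12)
    c = f * c_prev + i * np.tanh(Wc @ v + bc)  # cell state (2.14)
    z = o * np.tanh(c)                         # hidden state (2.13)
    return z, c

rng = np.random.default_rng(1)
H, D = 3, 6  # hidden size and concatenated-input size (toy values)
params = tuple(rng.normal(size=(H, D)) for _ in range(4)) + \
         tuple(np.zeros(H) for _ in range(4))
z, c = lstm_step(np.zeros(H), np.zeros(H), rng.normal(size=D), params)
```

The additive update of the cell state c is what gives the LSTM its improved gradient flow relative to the purely multiplicative Elman recursion.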

Relationship to Bayesian Filtering As examined in [39], Bayesian filters [45] and RNNs are both
similar in their maintenance of a hidden state which is recursively updated over time. For Bayesian
filters, such as the Kalman filter [46], inference is performed by updating the sufficient statistics
of the latent state – using a series of state transition and error correction steps. As the Bayesian
filtering steps use deterministic equations to modify sufficient statistics, the RNN can be viewed
as a simultaneous approximation of both steps – with the memory vector containing all relevant
information required for prediction.
(iii) Attention Mechanisms

The development of attention mechanisms [47, 48] has also led to improvements in long-term
dependency learning – with Transformer architectures achieving state-of-the-art performance in
multiple natural language processing applications [12, 49, 50]. Attention layers aggregate temporal
features using dynamically generated weights (see Figure 1c), allowing the network to directly
focus on significant time steps in the past – even if they are very far back in the lookback window.
Conceptually, attention is a mechanism for a key-value lookup based on a given query [51], taking
the form below:
h_t = \sum_{\tau=0}^{k} \alpha(\kappa_t, q_\tau)\, v_{t-\tau},    (2.15)

where the key \kappa_t, query q_\tau and value v_{t-\tau} are intermediate features produced at different time
steps by lower levels of the network. Furthermore, \alpha(\kappa_t, q_\tau) ∈ [0, 1] is the attention weight for
t − \tau generated at time t, and h_t is the context vector output of the attention layer. Note that
multiple attention layers can also be used together as per the CNN case, with the output from the
final layer forming the encoded latent variable z_t.
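A sketch of Equation (2.15), assuming simple dot-product scores for α(κ_t, q_τ); other score functions, such as the additive form used later in Equation (2.17), would slot in the same way:

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max())  # shift for numerical stability
    return e / e.sum()

def attention(query, keys, values):
    """Context vector as a weighted average of values (Equation (2.15)),
    with weights from query-key similarity (dot-product scores assumed)."""
    scores = keys @ query    # one score per past time step
    alpha = softmax(scores)  # weights lie in [0, 1] and sum to 1
    return alpha, alpha @ values

rng = np.random.default_rng(2)
k, d = 5, 3  # lookback window and feature dimension (toy sizes)
alpha, h = attention(rng.normal(size=d),
                     rng.normal(size=(k, d)),
                     rng.normal(size=(k, d)))
```

Because the weights are generated dynamically from the query and keys, a large weight can land on any step in the window, however distant.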
Recent work has also demonstrated the benefits of using attention mechanisms in time series
forecasting applications, with improved performance over comparable recurrent networks [52,
53, 54]. For instance, [52] use attention to aggregate features extracted by RNN encoders, with
attention weights produced as below:

\alpha(t) = softmax(\eta_t),    (2.16)

\eta_t = W_{\eta_1} tanh(W_{\eta_2} \kappa_{t-1} + W_{\eta_3} q_\tau + b_\eta),    (2.17)

where \alpha(t) = [\alpha(t, 0), . . . , \alpha(t, k)] is a vector of attention weights, \kappa_{t-1}, q_\tau are outputs from LSTM
encoders used for feature extraction, and softmax(·) is the softmax activation function. More
recently, Transformer architectures have also been considered in [53, 54], which apply scaled
dot-product self-attention [49] to features extracted within the lookback window. From a time series
modelling perspective, attention provides two key benefits. Firstly, networks with attention are
able to directly attend to any significant events that occur. In retail forecasting applications, for
example, this includes holiday or promotional periods which can have a positive effect on sales.
Secondly, as shown in [54], attention-based networks can also learn regime-specific temporal
dynamics – by using distinct attention weight patterns for each regime.

(iv) Outputs and Loss Functions


Given the flexibility of neural networks, deep neural networks have been used to model both
discrete [55] and continuous [37, 56] targets – by customising the decoder and output layer of the
neural network to match the desired target type. In one-step-ahead prediction problems, this
can be as simple as combining a linear transformation of encoder outputs (i.e. Equation (2.2))
together with an appropriate output activation for the target. Regardless of the form of the target,
predictions can be further divided into two different categories – point estimates and probabilistic
forecasts.

Point Estimates A common approach to forecasting is to determine the expected value of a
future target. This essentially involves reformulating the problem as a classification task for
discrete outputs (e.g. forecasting future events), and a regression task for continuous outputs – using
the encoders described above. For the binary classification case, the final layer of the decoder then
features a linear layer with a sigmoid activation function – allowing the network to predict the
probability of event occurrence at a given time step. For one-step-ahead forecasts of binary and
continuous targets, networks are trained using binary cross-entropy and mean square error loss
functions respectively:

\mathcal{L}_{classification} = -\frac{1}{T} \sum_{t=1}^{T} \big( y_t \log(\hat{y}_t) + (1 − y_t) \log(1 − \hat{y}_t) \big),    (2.18)

\mathcal{L}_{regression} = \frac{1}{T} \sum_{t=1}^{T} (y_t − \hat{y}_t)^2.    (2.19)

While the loss functions above are the most common across applications, we note that the
flexibility of neural networks also allows for more complex losses to be adopted - e.g. losses for
quantile regression [56] and multinomial classification [32].
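The two losses of Equations (2.18)-(2.19) are straightforward to compute; the toy targets and predictions below are purely illustrative:

```python
import numpy as np

def binary_cross_entropy(y, y_hat):
    """Equation (2.18), averaged over T one-step-ahead predictions."""
    return -np.mean(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

def mean_squared_error(y, y_hat):
    """Equation (2.19)."""
    return np.mean((y - y_hat) ** 2)

bce = binary_cross_entropy(np.array([1.0, 0.0]), np.array([0.9, 0.1]))
mse = mean_squared_error(np.array([1.0, 2.0]), np.array([1.5, 1.5]))
```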

Probabilistic Outputs While point estimates are crucial to predicting the future value of a target,
understanding the uncertainty of a model’s forecast can be useful for decision makers in different
domains. When forecast uncertainties are wide, for instance, model users can exercise more caution
when incorporating predictions into their decision making, or alternatively rely on other sources
of information. In some applications, such as financial risk management, having access to the full
predictive distribution will allow decision makers to optimise their actions in the presence of rare
events – e.g. allowing risk managers to insulate portfolios against market crashes.
A common way to model uncertainties is to use deep neural networks to generate parameters
of known distributions [27, 37, 38]. For example, Gaussian distributions are typically used for
forecasting problems with continuous targets, with the networks outputting means and variance
parameters for the predictive distributions at each step as below:

y_{t+\tau} \sim N(\mu(t, \tau), \zeta(t, \tau)^2),    (2.20)

\mu(t, \tau) = W_\mu h_t^L + b_\mu,    (2.21)

\zeta(t, \tau) = softplus(W_\Sigma h_t^L + b_\Sigma),    (2.22)

where h_t^L is the final layer of the network, and softplus(·) is the softplus activation function, which
ensures that standard deviations take only positive values.
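Equations (2.21)-(2.22) correspond to a small output head on the final hidden layer; the randomly drawn weights below stand in for learnt parameters:

```python
import numpy as np

def softplus(a):
    """softplus(a) = log(1 + exp(a)), strictly positive for any input."""
    return np.log1p(np.exp(a))

def gaussian_head(h_L, W_mu, b_mu, W_sigma, b_sigma):
    """Map the final hidden layer to a predictive mean (2.21) and a
    strictly positive standard deviation (2.22)."""
    mu = W_mu @ h_L + b_mu
    zeta = softplus(W_sigma @ h_L + b_sigma)
    return mu, zeta

rng = np.random.default_rng(3)
H = 4  # toy hidden size
mu, zeta = gaussian_head(rng.normal(size=H),
                         rng.normal(size=H), 0.0,
                         rng.normal(size=H), 0.0)
```

The softplus output guarantees a valid Gaussian however the weights are set, so no constraint needs to be enforced during training.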

(b) Multi-horizon Forecasting Models


In many applications, it is often beneficial to have access to predictive estimates at multiple points
in the future – allowing decision makers to visualise trends over a future horizon, and optimise
their actions across the entire path. From a statistical perspective, multi-horizon forecasting can be
viewed as a slight modification of one-step-ahead prediction problem (i.e. Equation (2.1)) as below:

\hat{y}_{t+\tau} = f(y_{t-k:t}, x_{t-k:t}, u_{t-k:t+\tau}, s, \tau),    (2.23)

where τ ∈ {1, . . . , τmax } is a discrete forecast horizon, ut are known future inputs (e.g. date
information, such as the day-of-week or month) across the entire horizon, and xt are inputs
that can only be observed historically. In line with traditional econometric approaches [57, 58],
deep learning architectures for multi-horizon forecasting can be divided into iterative and direct
methods – as shown in Figure 2 and described in detail below.

(i) Iterative Methods


Iterative approaches to multi-horizon forecasting typically make use of autoregressive deep
learning architectures [37, 39, 40, 53] – producing multi-horizon forecasts by recursively feeding
samples of the target into future time steps (see Figure 2a). By repeating the procedure to generate
multiple trajectories, forecasts are then produced using the sampling distributions for target values
at each step. For instance, predictive means can be obtained using the Monte Carlo estimate
\hat{y}_{t+\tau} = \sum_{j=1}^{J} \tilde{y}_{t+\tau}^{(j)} / J, where \tilde{y}_{t+\tau}^{(j)} is a sample taken based on the model of Equation (2.20). As
autoregressive models are trained in the exact same fashion as one-step-ahead prediction models

Figure 2: Main types of multi-horizon forecasting models: (a) iterative methods, (b) direct methods.
Colours are used to distinguish between model weights – with iterative models using a common
model across the entire horizon and direct methods taking a sequence-to-sequence approach.

(i.e. via backpropagation through time), the iterative approach allows for the easy generalisation
of standard models to multi-step forecasting. However, as a small amount of error is produced
at each time step, the recursive structure of iterative methods can potentially lead to large error
accumulations over longer forecasting horizons. In addition, iterative methods assume that all
inputs but the target are known at run-time – requiring only samples of the target to be fed into
future time steps. This can be a limitation in many practical scenarios where observed inputs exist,
motivating the need for more flexible methods.
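The recursive sampling procedure can be sketched as below, with a fixed AR(1) predictive distribution standing in for a trained model of Equation (2.20); all names and coefficients are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(4)

def one_step_model(history):
    """Stand-in for a learnt one-step predictive distribution: here a
    fixed AR(1) with coefficient 0.8 and noise std 0.1, for illustration."""
    mu, zeta = 0.8 * history[-1], 0.1
    return mu, zeta

def iterative_forecast(history, tau_max, num_samples=1000):
    """Recursively feed sampled targets back in, then average trajectories."""
    trajectories = np.empty((num_samples, tau_max))
    for j in range(num_samples):
        h = list(history)
        for tau in range(tau_max):
            mu, zeta = one_step_model(h)
            y_sample = rng.normal(mu, zeta)
            h.append(y_sample)  # the sampled target becomes the next input
            trajectories[j, tau] = y_sample
    return trajectories.mean(axis=0)  # Monte Carlo estimate of the mean

y_hat = iterative_forecast([1.0], tau_max=3)
```

The noise injected at each step is fed forward into all subsequent steps, which is exactly the mechanism behind the error accumulation discussed above.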

(ii) Direct Methods


Direct methods alleviate the issues with iterative methods by producing forecasts directly using all
available inputs. They typically make use of sequence-to-sequence architectures [52, 54, 56], using
an encoder to summarise past information (i.e. targets, observed inputs and a priori known inputs),
and a decoder to combine them with known future inputs – as depicted in Figure 2b. As described
in [59], alternative approach is to use simpler models to directly produce a fixed-length vector
matching the desired forecast horizon. This, however, does require the specification of a maximum
forecast horizon (i.e. τmax ), with predictions made only at the predefined discrete intervals.

3. Incorporating Domain Knowledge with Hybrid Models


Despite its popularity, the efficacy of machine learning for time series prediction has historically
been questioned – as evidenced by forecasting competitions such as the M-competitions [60]. Prior
to the M4 competition of 2018 [61], the prevailing wisdom was that sophisticated methods do not
produce more accurate forecasts, and simple models with ensembling had a tendency to do better
[59, 62, 63]. Two key reasons have been identified to explain the underperformance of machine
learning methods. Firstly, the flexibility of machine learning methods can be a double-edged sword
– making them prone to overfitting [59]. Hence, simpler models may potentially do better in low
data regimes, which are particularly common in forecasting problems with a small number of
historical observations (e.g. quarterly macroeconomic forecasts). Secondly, similar to the stationarity
requirements of statistical models, machine learning models can be sensitive to how inputs are
pre-processed [26, 37, 59], which ensures that data distributions at training and test time are similar.
A recent trend in deep learning has been in developing hybrid models which address these
limitations, demonstrating improved performance over pure statistical or machine learning models
in a variety of applications [38, 64, 65, 66]. Hybrid methods combine well-studied quantitative
time series models together with deep learning – using deep neural networks to generate model
parameters at each time step. On the one hand, hybrid models allow domain experts to inform
neural network training using prior information – reducing the hypothesis space of the network
and improving generalisation. This is especially useful for small datasets [38], where there is a
greater risk of overfitting for deep learning models. Furthermore, hybrid models allow for the
separation of stationary and non-stationary components, and avoid the need for custom input
pre-processing. An example of this is the Exponential Smoothing RNN (ES-RNN) [64], winner
of the M4 competition, which uses exponential smoothing to capture non-stationary trends and
learns additional effects with the RNN. In general, hybrid models utilise deep neural networks
in two manners: a) to encode time-varying parameters for non-probabilistic parametric models
[64, 65, 67], and b) to produce parameters of distributions used by probabilistic models [38, 40, 66].

(a) Non-probabilistic Hybrid Models
With parametric time series models, forecasting equations are typically defined analytically and
provide point forecasts for future targets. Non-probabilistic hybrid models hence modify these
forecasting equations to combine statistical and deep learning components. The ES-RNN for
example, utilises the update equations of the Holt-Winters exponential smoothing model [8] –
combining multiplicative level and seasonality components with deep learning outputs as below:

\hat{y}_{i,t+\tau} = \exp(W_{ES} h_{i,t+\tau}^{L} + b_{ES}) \times l_{i,t} \times \gamma_{i,t+\tau},    (3.1)

l_{i,t} = \beta_1^{(i)} y_{i,t} / \gamma_{i,t} + (1 − \beta_1^{(i)}) l_{i,t-1},    (3.2)

\gamma_{i,t} = \beta_2^{(i)} y_{i,t} / l_{i,t} + (1 − \beta_2^{(i)}) \gamma_{i,t-\kappa},    (3.3)

where h_{i,t+\tau}^{L} is the final layer of the network for the \tau-th step-ahead forecast, l_{i,t} is a level
component, \gamma_{i,t} is a seasonality component with period \kappa, and \beta_1^{(i)}, \beta_2^{(i)} are entity-specific static
coefficients. From the above equations, we can see that the exponential smoothing components
(l_{i,t}, \gamma_{i,t}) handle the broader (e.g. exponential) trends within the datasets, reducing the need for
additional input scaling.
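The level and seasonality recursions of Equations (3.2)-(3.3) can be sketched as below; the deep learning term of Equation (3.1) is omitted, and the smoothing coefficients and toy data are illustrative assumptions:

```python
def smoothing_updates(y, beta1, beta2, seasonal_init, level_init):
    """Holt-Winters-style level/seasonality recursions (Equations (3.2)-(3.3));
    the season length kappa is len(seasonal_init)."""
    level, season = level_init, list(seasonal_init)
    for t, y_t in enumerate(y):
        g_prev = season[t]  # seasonal factor from kappa steps back
        level = beta1 * y_t / g_prev + (1 - beta1) * level        # (3.2)
        season.append(beta2 * y_t / level + (1 - beta2) * g_prev)  # (3.3)
    return level, season

# Toy alternating series with period 2: the recursions absorb the seasonal
# swing, leaving the RNN of Equation (3.1) to model only the residual effects.
level, season = smoothing_updates([10.0, 20.0, 12.0, 22.0], 0.5, 0.5,
                                  seasonal_init=[1.0, 2.0], level_init=10.0)
```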

(b) Probabilistic Hybrid Models


Probabilistic hybrid models can also be used in applications where distribution modelling is
important – utilising probabilistic generative models for temporal dynamics such as Gaussian
processes [40] and linear state space models [38]. Rather than modifying forecasting equations,
probabilistic hybrid models use neural networks to produce parameters for predictive distributions
at each step. For instance, Deep State Space Models [38] encode time-varying parameters for linear
state space models as below – performing inference via the Kalman filtering equations [46]:

y_t = a(h_{i,t+\tau}^{L})^T l_t + \phi(h_{i,t+\tau}^{L})\, \epsilon_t,    (3.4)

l_t = F(h_{i,t+\tau}^{L})\, l_{t-1} + q(h_{i,t+\tau}^{L}) + \Sigma(h_{i,t+\tau}^{L}) ⊙ \Sigma_t,    (3.5)

where l_t is the hidden latent state, a(·), F(·), q(·) are linear transformations of h_{i,t+\tau}^{L}, \phi(·), \Sigma(·)
are linear transformations with softmax activations, \epsilon_t ∼ N(0, 1) is a univariate residual and
\Sigma_t ∼ N(0, I) is a multivariate normal random variable.

4. Facilitating Decision Support Using Deep Neural Networks


Although model builders are mainly concerned with the accuracy of their forecasts, end-users
typically use predictions to guide their future actions. For instance, doctors can make use of clinical
forecasts (e.g. probabilities of disease onset and mortality) to help them prioritise tests to order,
formulate a diagnosis and determine a course of treatment. As such, while time series forecasting is
a crucial preliminary step, a better understanding of both temporal dynamics and the motivations
behind a model’s forecast can help users further optimise their actions. In this section, we explore
two directions in which neural networks have been extended to facilitate decision support with
time series data – focusing on methods in interpretability and causal inference.

(a) Interpretability With Time Series Data


With the deployment of neural networks in mission-critical applications [68], there is an increasing
need to understand both how and why a model makes a certain prediction. Moreover, end-users can
have little prior knowledge with regards to the relationships present in their data, with datasets
growing in size and complexity in recent times. Given the black-box nature of standard neural
network architectures, a new body of research has emerged in methods for interpreting deep

learning models. We present a summary below – referring the reader to dedicated surveys for
more in-depth analyses [69, 70].

Techniques for Post-hoc Interpretability Post-hoc interpretable models are developed to
interpret trained networks, helping to identify important features or examples without
modifying the original weights. Methods can be divided into two main categories. Firstly,
one possible approach is to apply simpler interpretable surrogate models between the inputs and
outputs of the neural network, and rely on the approximate model to provide explanations. For
instance, Local Interpretable Model-Agnostic Explanations (LIME) [71] identify relevant features
by fitting instance-specific linear models to perturbations of the input, with the linear coefficients
providing a measure of importance. Shapley additive explanations (SHAP) [72] provide another
surrogate approach, which utilises Shapley values from cooperative game theory to identify
important features across the dataset. Next, gradient-based methods – such as saliency maps [73, 74]
and influence functions [75] – have been proposed, which analyse network gradients to determine
which input features have the greatest impact on loss functions. While post-hoc interpretability
methods can help with feature attributions, they typically ignore any sequential dependencies
between inputs – making it difficult to apply them to complex time series datasets.

Inherent Interpretability with Attention Weights An alternative approach is to directly design


architectures with explainable components, typically in the form of strategically placed attention
layers. As attention weights are produced as outputs from a softmax layer, the weights are
constrained to sum to 1, i.e. \sum_{\tau=0}^{k} \alpha(t, \tau) = 1. For time series models, the outputs of Equation (2.15)
can hence also be interpreted as a weighted average over temporal features, using the weights
supplied by the attention layer at each step. An analysis of attention weights can then be used to
understand the relative importance of features at each time step. Instance-wise interpretability
studies have been performed in [53, 55, 76], where the authors used specific examples to show how
the magnitudes of α(t, τ ) can indicate which time points were most significant for predictions. By
analysing distributions of attention vectors across time, [54] also shows how attention mechanisms
can be used to identify persistent temporal relationships – such as seasonal patterns – in the dataset.

(b) Counterfactual Predictions & Causal Inference Over Time


In addition to understanding the relationships learnt by the networks, deep learning can also help
to facilitate decision support by producing predictions outside of their observational datasets, or
counterfactual forecasts. Counterfactual predictions are particularly useful for scenario analysis
applications – allowing users to evaluate how different sets of actions can impact target trajectories.
This can be useful both from a historical angle, i.e. determining what would have happened if a
different set of circumstances had occurred, and from a forecasting perspective, i.e. determining
which actions to take to optimise future outcomes.
While a large class of deep learning methods exists for estimating causal effects in static
settings [77, 78, 79], the key challenge in time series datasets is the presence of time-dependent
confounding effects. This arises due to circular dependencies when actions that can affect the
target are also conditional on observations of the target. Without adjusting for time-dependent
confounders, straightforward estimation techniques can yield biased results, as shown in [80].
Recently, several methods have emerged to train deep neural networks while adjusting for time-
dependent confounding, based on extensions of statistical techniques and the design of new loss
functions. With statistical methods, [81] extends the inverse-probability-of-treatment-weighting
(IPTW) approach of marginal structural models in epidemiology – using one set of networks to
estimate treatment application probabilities, and a sequence-to-sequence model to learn unbiased
predictions. Another approach in [82] extends the G-computation framework, jointly modelling
distributions of the target and actions using deep learning. In addition, new loss functions have
been proposed in [83], which adopts domain adversarial training to learn balanced representations
of patient history.

5. Conclusions and Future Directions
With the growth in data availability and computing power in recent times, deep neural network
architectures have achieved much success in forecasting problems across multiple domains. In
this article, we survey the main architectures used for time series forecasting – highlighting the
key building blocks used in neural network design. We examine how they incorporate temporal
information for one-step-ahead predictions, and describe how they can be extended for use in
multi-horizon forecasting. Furthermore, we outline the recent trend of hybrid deep learning models,
which combine statistical and deep learning components to outperform pure methods in either
category. Finally, we summarise two ways in which deep learning can be extended to improve
decision support over time, focusing on methods in interpretability and counterfactual prediction.
Although a large number of deep learning models have been developed for time series
forecasting, some limitations still exist. Firstly, deep neural networks typically require time series to
be discretised at regular intervals, making it difficult to forecast datasets where observations can be
missing or arrive at random intervals. While some preliminary research on continuous-time models
has been done via Neural Ordinary Differential Equations [84], additional work needs to be done
to extend this work for datasets with complex inputs (e.g. static variables) and to benchmark them
against existing models. In addition, as mentioned in [85], time series often have a hierarchical
structure with logical groupings between trajectories – e.g. in retail forecasting, where product
sales in the same geography can be affected by common trends. As such, the development of
architectures which explicitly account for such hierarchies could be an interesting research direction,
and potentially improve forecasting performance over existing univariate or multivariate models.

Competing Interests. The author(s) declare that they have no competing interests.

References
1 Mudelsee M. Trend analysis of climate time series: A review of methods. Earth-Science Reviews.
2019;190:310 – 322.
2 Stoffer DS, Ombao H. Editorial: Special issue on time series analysis in the biological sciences.
Journal of Time Series Analysis. 2012;33(5):701–703.
3 Topol EJ. High-performance medicine: the convergence of human and artificial intelligence.
Nature Medicine. 2019 Jan;25(1):44–56.
4 Böse JH, Flunkert V, Gasthaus J, Januschowski T, Lange D, Salinas D, et al. Probabilistic Demand
Forecasting at Scale. Proc VLDB Endow. 2017 Aug;10(12):1694–1705.
5 Andersen TG, Bollerslev T, Christoffersen PF, Diebold FX. Volatility Forecasting. National
Bureau of Economic Research; 2005. 11188.
6 Box GEP, Jenkins GM. Time Series Analysis: Forecasting and Control. Holden-Day; 1976.
7 Gardner Jr ES. Exponential smoothing: The state of the art. Journal of Forecasting. 1985;4(1):1–28.
8 Winters PR. Forecasting Sales by Exponentially Weighted Moving Averages. Management
Science. 1960;6(3):324–342.
9 Harvey AC. Forecasting, Structural Time Series Models and the Kalman Filter. Cambridge
University Press; 1990.
10 Ahmed NK, Atiya AF, Gayar NE, El-Shishiny H. An Empirical Comparison of Machine Learning
Models for Time Series Forecasting. Econometric Reviews. 2010;29(5-6):594–621.
11 Krizhevsky A, Sutskever I, Hinton GE. ImageNet Classification with Deep Convolutional
Neural Networks. In: Pereira F, Burges CJC, Bottou L, Weinberger KQ, editors. Advances in
Neural Information Processing Systems 25 (NIPS); 2012. p. 1097–1105.
12 Devlin J, Chang MW, Lee K, Toutanova K. BERT: Pre-training of Deep Bidirectional
Transformers for Language Understanding. In: Proceedings of the 2019 Conference of the
North American Chapter of the Association for Computational Linguistics: Human Language
Technologies, Volume 1 (Long and Short Papers); 2019. p. 4171–4186.
13 Silver D, Huang A, Maddison CJ, Guez A, Sifre L, van den Driessche G, et al. Mastering the
game of Go with deep neural networks and tree search. Nature. 2016;529:484–503.
14 Baxter J. A Model of Inductive Bias Learning. J Artif Int Res. 2000;12(1):149–198.

15 Bengio Y, Courville A, Vincent P. Representation Learning: A Review and New Perspectives.
IEEE Transactions on Pattern Analysis and Machine Intelligence. 2013;35(8):1798–1828.
16 Abadi M, Agarwal A, Barham P, Brevdo E, Chen Z, Citro C, et al.. TensorFlow: Large-Scale
Machine Learning on Heterogeneous Systems; 2015. Software available from tensorflow.org.
Available from: http://tensorflow.org/.
17 Paszke A, Gross S, Massa F, Lerer A, Bradbury J, Chanan G, et al. PyTorch: An Imperative Style,
High-Performance Deep Learning Library. In: Advances in Neural Information Processing
Systems 32; 2019. p. 8024–8035.
18 Hyndman RJ, Khandakar Y. Automatic time series forecasting: the forecast package for R.
Journal of Statistical Software. 2008;26(3):1–22.
19 Nadaraya EA. On Estimating Regression. Theory of Probability and Its Applications.
1964;9(1):141–142.
20 Smola AJ, Schölkopf B. A Tutorial on Support Vector Regression. Statistics and Computing.
2004;14(3):199–222.
21 Williams CKI, Rasmussen CE. Gaussian Processes for Regression. In: Advances in Neural
Information Processing Systems (NIPS); 1996. .
22 Damianou A, Lawrence N. Deep Gaussian Processes. In: Proceedings of the Conference on
Artificial Intelligence and Statistics (AISTATS); 2013. .
23 Garnelo M, Rosenbaum D, Maddison C, Ramalho T, Saxton D, Shanahan M, et al. Conditional
Neural Processes. In: Proceedings of the International Conference on Machine Learning (ICML);
2018. .
24 Waibel A. Modular Construction of Time-Delay Neural Networks for Speech Recognition.
Neural Comput. 1989;1(1):39–46.
25 Wan E. Time Series Prediction by Using a Connectionist Network with Internal Delay Lines. In:
Time Series Prediction. Addison-Wesley; 1994. p. 195–217.
26 Sen R, Yu HF, Dhillon I. Think Globally, Act Locally: A Deep Neural Network Approach to
High-Dimensional Time Series Forecasting. In: Advances in Neural Information Processing
Systems (NeurIPS); 2019. .
27 Wen R, Torkkola K. Deep Generative Quantile-Copula Models for Probabilistic Forecasting. In:
ICML Time Series Workshop; 2019. .
28 Li Y, Yu R, Shahabi C, Liu Y. Diffusion Convolutional Recurrent Neural Network: Data-
Driven Traffic Forecasting. In: (Proceedings of the International Conference on Learning
Representations ICLR); 2018. .
29 Ghaderi A, Sanandaji BM, Ghaderi F. Deep Forecast: Deep Learning-based Spatio-Temporal
Forecasting. In: ICML Time Series Workshop; 2017. .
30 Salinas D, Bohlke-Schneider M, Callot L, Medico R, Gasthaus J. High-dimensional multivariate
forecasting with low-rank Gaussian Copula Processes. In: Advances in Neural Information
Processing Systems (NeurIPS); 2019. .
31 Goodfellow I, Bengio Y, Courville A. Deep Learning. MIT Press; 2016. http://www.
deeplearningbook.org.
32 van den Oord A, Dieleman S, Zen H, Simonyan K, Vinyals O, Graves A, et al. WaveNet: A
Generative Model for Raw Audio. arXiv e-prints. 2016 Sep;p. arXiv:1609.03499.
33 Bai S, Zico Kolter J, Koltun V. An Empirical Evaluation of Generic Convolutional and Recurrent
Networks for Sequence Modeling. arXiv e-prints. 2018;p. arXiv:1803.01271.
34 Borovykh A, Bohte S, Oosterlee CW. Conditional Time Series Forecasting with Convolutional
Neural Networks. arXiv e-prints. 2017;p. arXiv:1703.04691.
35 Lyons RG. Understanding Digital Signal Processing (2nd Edition). USA: Prentice Hall PTR;
2004.
36 Young T, Hazarika D, Poria S, Cambria E. Recent Trends in Deep Learning Based
Natural Language Processing [Review Article]. IEEE Computational Intelligence Magazine.
2018;13(3):55–75.
37 Salinas D, Flunkert V, Gasthaus J. DeepAR: Probabilistic Forecasting with Autoregressive
Recurrent Networks. arXiv e-prints. 2017;p. arXiv:1704.04110.
38 Rangapuram SS, Seeger MW, Gasthaus J, Stella L, Wang Y, Januschowski T. Deep State Space
Models for Time Series Forecasting. In: Advances in Neural Information Processing Systems
(NIPS); 2018. .
39 Lim B, Zohren S, Roberts S. Recurrent Neural Filters: Learning Independent Bayesian Filtering
Steps for Time Series Prediction. In: International Joint Conference on Neural Networks (IJCNN);
2020. .
rsta.royalsocietypublishing.org Phil. Trans. R. Soc. A 0000000
40 Wang Y, Smola A, Maddix D, Gasthaus J, Foster D, Januschowski T. Deep Factors for Forecasting.
In: Proceedings of the International Conference on Machine Learning (ICML); 2019. .
41 Elman JL. Finding structure in time. Cognitive Science. 1990;14(2):179 – 211.
42 Bengio Y, Simard P, Frasconi P. Learning long-term dependencies with gradient descent is
difficult. IEEE Transactions on Neural Networks. 1994;5(2):157–166.
43 Kolen JF, Kremer SC. Gradient Flow in Recurrent Nets: The Difficulty of Learning Long-Term
Dependencies; 2001. p. 237–243.
44 Hochreiter S, Schmidhuber J. Long Short-Term Memory. Neural Computation. 1997
Nov;9(8):1735–1780.
45 Särkkä S. Bayesian Filtering and Smoothing. Cambridge University Press; 2013.
46 Kalman RE. A New Approach to Linear Filtering and Prediction Problems. Journal of Basic
Engineering. 1960;82(1):35.
47 Bahdanau D, Cho K, Bengio Y. Neural Machine Translation by Jointly Learning to Align and
Translate. In: Proceedings of the International Conference on Learning Representations (ICLR);
2015. .
48 Cho K, van Merriënboer B, Gulcehre C, Bahdanau D, Bougares F, Schwenk H, et al. Learning
Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation. In:
Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing
(EMNLP); 2014. .
49 Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, et al. Attention is All you
Need. In: Advances in Neural Information Processing Systems (NIPS); 2017. .
50 Dai Z, Yang Z, Yang Y, Carbonell J, Le Q, Salakhutdinov R. Transformer-XL: Attentive Language
Models beyond a Fixed-Length Context. In: Proceedings of the 57th Annual Meeting of the
Association for Computational Linguistics (ACL); 2019. .
51 Graves A, Wayne G, Danihelka I. Neural Turing Machines. CoRR. 2014;abs/1410.5401.
52 Fan C, Zhang Y, Pan Y, Li X, Zhang C, Yuan R, et al. Multi-Horizon Time Series Forecasting with
Temporal Attention Learning. In: Proceedings of the ACM SIGKDD international conference
on Knowledge discovery and data mining (KDD); 2019. .
53 Li S, Jin X, Xuan Y, Zhou X, Chen W, Wang YX, et al. Enhancing the Locality and Breaking
the Memory Bottleneck of Transformer on Time Series Forecasting. In: Advances in Neural
Information Processing Systems (NeurIPS); 2019. .
54 Lim B, Arik SO, Loeff N, Pfister T. Temporal Fusion Transformers for Interpretable Multi-horizon
Time Series Forecasting. arXiv e-prints. 2019;p. arXiv:1912.09363.
55 Choi E, Bahadori MT, Sun J, Kulas JA, Schuetz A, Stewart WF. RETAIN: An Interpretable
Predictive Model for Healthcare using Reverse Time Attention Mechanism. In: Advances in
Neural Information Processing Systems (NIPS); 2016. .
56 Wen R, et al. A Multi-Horizon Quantile Recurrent Forecaster. In: NIPS 2017 Time Series
Workshop; 2017. .
57 Taieb SB, Sorjamaa A, Bontempi G. Multiple-output modeling for multi-step-ahead time series
forecasting. Neurocomputing. 2010;73(10):1950 – 1957.
58 Marcellino M, Stock J, Watson M. A Comparison of Direct and Iterated Multistep AR Methods
for Forecasting Macroeconomic Time Series. Journal of Econometrics. 2006;135:499–526.
59 Makridakis S, Spiliotis E, Assimakopoulos V. Statistical and Machine Learning forecasting
methods: Concerns and ways forward. PLOS ONE. 2018 03;13(3):1–26.
60 Hyndman R. A brief history of forecasting competitions. International Journal of Forecasting.
2020;36(1):7–14.
61 Makridakis S, Spiliotis E, Assimakopoulos V. The M4 Competition: 100,000 time series and 61
forecasting methods. International Journal of Forecasting. 2020;36(1):54–74.
62 Fildes R, Hibon M, Makridakis S, Meade N. Generalising about univariate forecasting methods:
further empirical evidence. International Journal of Forecasting. 1998;14(3):339 – 358.
63 Makridakis S, Hibon M. The M3-Competition: results, conclusions and implications.
International Journal of Forecasting. 2000;16(4):451–476.
64 Smyl S. A hybrid method of exponential smoothing and recurrent neural networks for time
series forecasting. International Journal of Forecasting. 2020;36(1):75–85.
65 Lim B, Zohren S, Roberts S. Enhancing Time-Series Momentum Strategies Using Deep Neural
Networks. The Journal of Financial Data Science. 2019;.
66 Grover A, Kapoor A, Horvitz E. A Deep Hybrid Model for Weather Forecasting. In: Proceedings
of the ACM SIGKDD international conference on knowledge discovery and data mining (KDD);
2015.
67 Binkowski M, Marti G, Donnat P. Autoregressive Convolutional Neural Networks for
Asynchronous Time Series. In: Proceedings of the International Conference on Machine
Learning (ICML); 2018. .
68 Moraffah R, Karami M, Guo R, Raglin A, Liu H. Causal Interpretability for Machine Learning –
Problems, Methods and Evaluation. arXiv e-prints. 2020;p. arXiv:2003.03934.
69 Chakraborty S, Tomsett R, Raghavendra R, Harborne D, Alzantot M, Cerutti F, et al.
Interpretability of deep learning models: A survey of results. In: 2017 IEEE SmartWorld
Conference Proceedings; 2017. p. 1–6.
70 Rudin C. Stop explaining black box machine learning models for high stakes decisions and use
interpretable models instead. Nature Machine Intelligence. 2019 May;1(5):206–215.
71 Ribeiro M, Singh S, Guestrin C. "Why Should I Trust You?" Explaining the Predictions of Any
Classifier. In: KDD; 2016.
72 Lundberg S, Lee SI. A Unified Approach to Interpreting Model Predictions. In: Advances in
Neural Information Processing Systems (NIPS); 2017. .
73 Simonyan K, Vedaldi A, Zisserman A. Deep Inside Convolutional Networks: Visualising Image
Classification Models and Saliency Maps. arXiv e-prints. 2013;p. arXiv:1312.6034.
74 Siddiqui SA, Mercier D, Munir M, Dengel A, Ahmed S. TSViz: Demystification of Deep Learning
Models for Time-Series Analysis. IEEE Access. 2019;7:67027–67040.
75 Koh PW, Liang P. Understanding Black-box Predictions via Influence Functions. In: Proceedings
of the International Conference on Machine Learning (ICML); 2017.
76 Bai T, Zhang S, Egleston BL, Vucetic S. Interpretable Representation Learning for Healthcare
via Capturing Disease Progression through Time. In: Proceedings of the ACM SIGKDD
International Conference on Knowledge Discovery & Data Mining (KDD); 2018. .
77 Yoon J, Jordon J, van der Schaar M. GANITE: Estimation of Individualized Treatment Effects
using Generative Adversarial Nets. In: International Conference on Learning Representations
(ICLR); 2018. .
78 Hartford J, Lewis G, Leyton-Brown K, Taddy M. Deep IV: A Flexible Approach for
Counterfactual Prediction. In: Proceedings of the 34th International Conference on Machine
Learning (ICML); 2017. .
79 Alaa AM, Weisz M, van der Schaar M. Deep Counterfactual Networks with Propensity Dropout.
In: Proceedings of the 34th International Conference on Machine Learning (ICML); 2017. .
80 Mansournia MA, Etminan M, Danaei G, Kaufman JS, Collins G. Handling time varying
confounding in observational research. BMJ. 2017;359.
81 Lim B, Alaa A, van der Schaar M. Forecasting Treatment Responses Over Time Using Recurrent
Marginal Structural Networks. In: NeurIPS; 2018. .
82 Li R, Shahn Z, Li J, Lu M, Chakraborty P, Sow D, et al. G-Net: A Deep Learning Approach to
G-computation for Counterfactual Outcome Prediction Under Dynamic Treatment Regimes.
arXiv e-prints. 2020;p. arXiv:2003.10551.
83 Bica I, Alaa AM, Jordon J, van der Schaar M. Estimating counterfactual treatment outcomes
over time through adversarially balanced representations. In: International Conference on
Learning Representations (ICLR); 2020.
84 Chen RTQ, Rubanova Y, Bettencourt J, Duvenaud D. Neural Ordinary Differential Equations.
In: Proceedings of the International Conference on Neural Information Processing Systems
(NIPS); 2018. .
85 Fry C, Brundage M. The M4 Forecasting Competition – A Practitioner’s View. International
Journal of Forecasting. 2019;.
Chapter 3

Novel Architectures For Time Series Forecasting

Publications Included

• B. Lim, S. O. Arik, N. Loeff, T. Pfister. Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting. Submitted, 2020. Weblink: https://arxiv.org/abs/1912.09363.

• B. Lim, S. Zohren, S. Roberts. Recurrent Neural Filters: Learning Independent Bayesian Filtering Steps for Time Series Prediction. Proceedings of the International Joint Conference On Neural Networks (IJCNN), 2020. Weblink: https://arxiv.org/abs/1901.08096.
Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting

Bryan Lim^{a,1,*}, Sercan Ö. Arık^b, Nicolas Loeff^b, Tomas Pfister^b

^a University of Oxford, UK
^b Google Cloud AI, USA
Abstract
Multi-horizon forecasting often contains a complex mix of inputs – including
static (i.e. time-invariant) covariates, known future inputs, and other exogenous
time series that are only observed in the past – without any prior information
on how they interact with the target. Several deep learning methods have been
proposed, but they are typically ‘black-box’ models which do not shed light on
how they use the full range of inputs present in practical scenarios. In this pa-
per, we introduce the Temporal Fusion Transformer (TFT) – a novel attention-
based architecture which combines high-performance multi-horizon forecasting
with interpretable insights into temporal dynamics. To learn temporal rela-
tionships at different scales, TFT uses recurrent layers for local processing and
interpretable self-attention layers for long-term dependencies. TFT utilizes spe-
cialized components to select relevant features and a series of gating layers to
suppress unnecessary components, enabling high performance in a wide range of
scenarios. On a variety of real-world datasets, we demonstrate significant per-
formance improvements over existing benchmarks, and showcase three practical
interpretability use cases of TFT.
Keywords: Deep learning, Interpretability, Time series, Multi-horizon
forecasting, Attention mechanisms, Explainable AI.

1. Introduction

Multi-horizon forecasting, i.e. the prediction of variables-of-interest at mul-


tiple future time steps, is a crucial problem within time series machine learning.
In contrast to one-step-ahead predictions, multi-horizon forecasts provide users
with access to estimates across the entire path, allowing them to optimize their
actions at multiple steps in future (e.g. retailers optimizing the inventory for

∗ Corresponding authors
Email addresses: [email protected] (Bryan Lim), [email protected] (Sercan Ö.
Arık), [email protected] (Nicolas Loeff), [email protected] (Tomas Pfister)
1 Completed as part of internship with Google Cloud AI Research.

Preprint submitted to Elsevier September 27, 2020
the entire upcoming season, or clinicians optimizing a treatment plan for a
patient). Multi-horizon forecasting has many impactful real-world applications in
retail [1, 2], healthcare [3, 4] and economics [5] – performance improvements
to existing methods in such applications are highly valuable.

Figure 1: Illustration of multi-horizon forecasting with static covariates, past-observed and a priori known future time-dependent inputs.

Practical multi-horizon forecasting applications commonly have access to
a variety of data sources, as shown in Fig. 1, including known information
about the future (e.g. upcoming holiday dates), other exogenous time series
(e.g. historical customer foot traffic), and static metadata (e.g. location of the
store) – without any prior knowledge on how they interact. This heterogeneity
of data sources together with little information about their interactions makes
multi-horizon time series forecasting particularly challenging.
Deep neural networks (DNNs) have increasingly been used in multi-horizon
forecasting, demonstrating strong performance improvements over traditional
time series models [6, 7, 8]. While many architectures have focused on variants
of recurrent neural network (RNN) architectures [9, 6, 10], recent improvements
have also used attention-based methods to enhance the selection of relevant time
steps in the past [11] – including Transformer-based models [12]. However, these
often fail to consider the different types of inputs commonly present in multi-
horizon forecasting, and either assume that all exogenous inputs are known
into the future [9, 6, 12] – a common problem with autoregressive models – or
neglect important static covariates [10] – which are simply concatenated with
other time-dependent features at each step. Many recent improvements in time
series models have resulted from the alignment of architectures with unique data
characteristics [13, 14]. We argue and demonstrate that similar performance
gains can also be reaped by designing networks with suitable inductive biases
for multi-horizon forecasting.
In addition to not considering the heterogeneity of common multi-horizon
forecasting inputs, most current architectures are ‘black-box’ models where fore-
casts are controlled by complex nonlinear interactions between many parame-
ters. This makes it difficult to explain how models arrive at their predictions,
and in turn makes it challenging for users to trust a model’s outputs and model
builders to debug it. Unfortunately, commonly-used explainability methods for
DNNs are not well-suited for applying to time series. In their conventional
form, post-hoc methods (e.g. LIME [15] and SHAP [16]) do not consider the
time ordering of input features. For example, for LIME, surrogate models are
independently constructed for each data-point, and for SHAP, features are con-
sidered independently for neighboring time steps. Such post-hoc approaches
would lead to poor explanation quality as dependencies between time steps are
typically significant in time series. On the other hand, some attention-based
architectures are proposed with inherent interpretability for sequential data,
primarily language or speech – such as the Transformer architecture [17]. The
fundamental caveat to applying them is that multi-horizon forecasting includes
many different types of input features, as opposed to language or speech. In
their conventional form, these architectures can provide insights into relevant
time steps for multi-horizon forecasting, but they cannot distinguish the impor-
tance of different features at a given timestep. Overall, in addition to the need
for new methods to tackle the heterogeneity of data in multi-horizon forecasting
for high performance, new methods are also needed to render these forecasts
interpretable, given the needs of the use cases.
In this paper we propose the Temporal Fusion Transformer (TFT) – an
attention-based DNN architecture for multi-horizon forecasting that achieves
high performance while enabling new forms of interpretability. To obtain signif-
icant performance improvements over state-of-the-art benchmarks, we introduce
multiple novel ideas to align the architecture with the full range of potential in-
puts and temporal relationships common to multi-horizon forecasting – specif-
ically incorporating (1) static covariate encoders which encode context vectors
for use in other parts of the network, (2) gating mechanisms throughout and
sample-dependent variable selection to minimize the contributions of irrelevant
inputs, (3) a sequence-to-sequence layer to locally process known and observed
inputs, and (4) a temporal self-attention decoder to learn any long-term depen-
dencies present within the dataset. The use of these specialized components
also facilitates interpretability; in particular, we show that TFT enables three
valuable interpretability use cases: helping users identify (i) globally-important
variables for the prediction problem, (ii) persistent temporal patterns, and (iii)
significant events. On a variety of real-world datasets, we demonstrate how
TFT can be practically applied, as well as the insights and benefits it provides.

2. Related Work

DNNs for Multi-horizon Forecasting: Similarly to traditional multi-
horizon forecasting methods [18, 19], recent deep learning methods can be cate-
gorized into iterated approaches using autoregressive models [9, 6, 12] or direct
methods based on sequence-to-sequence models [10, 11].

Iterated approaches utilize one-step-ahead prediction models, with multi-
step predictions obtained by recursively feeding predictions into future inputs.
Approaches with Long Short-term Memory (LSTM) [20] networks have been
considered, such as Deep AR [9] which uses stacked LSTM layers to generate pa-
rameters of one-step-ahead Gaussian predictive distributions. Deep State-Space
Models (DSSM) [6] adopt a similar approach, utilizing LSTMs to generate pa-
rameters of a predefined linear state-space model with predictive distributions
produced via Kalman filtering – with extensions for multivariate time series data
in [21]. More recently, Transformer-based architectures have been explored in
[12], which proposes the use of convolutional layers for local processing and a
sparse attention mechanism to increase the size of the receptive field during
forecasting. Despite their simplicity, iterative methods rely on the assumption
that the values of all variables excluding the target are known at forecast time
– such that only the target needs to be recursively fed into future inputs. How-
ever, in many practical scenarios, numerous useful time-varying inputs exist,
with many unknown in advance. Their straightforward use is hence limited for
iterative approaches. TFT, on the other hand, explicitly accounts for the di-
versity of inputs – naturally handling static covariates and (past-observed and
future-known) time-varying inputs.
In contrast, direct methods are trained to explicitly generate forecasts for
multiple predefined horizons at each time step. Their architectures typically rely
on sequence-to-sequence models, e.g. LSTM encoders to summarize past inputs,
and a variety of methods to generate future predictions. The Multi-horizon
Quantile Recurrent Forecaster (MQRNN) [10] uses LSTM or convolutional en-
coders to generate context vectors which are fed into multi-layer perceptrons
(MLPs) for each horizon. In [11] a multi-modal attention mechanism is used
with LSTM encoders to construct context vectors for a bi-directional LSTM
decoder. Despite performing better than LSTM-based iterative methods, inter-
pretability remains challenging for such standard direct methods. In contrast,
we show that by interpreting attention patterns, TFT can provide insightful
explanations about temporal dynamics, and do so while maintaining state-of-
the-art performance on a variety of datasets.
Time Series Interpretability with Attention: Attention mechanisms
are used in translation [17], image classification [22] or tabular learning [23]
to identify salient portions of input for each instance using the magnitude of
attention weights. Recently, they have been adapted for time series with inter-
pretability motivations [7, 12, 24], using LSTM-based [25] and transformer-based
[12] architectures. However, this was done without considering the importance
of static covariates (as the above methods blend variables at each input). TFT
alleviates this by using separate encoder-decoder attention for static features
at each time step on top of the self-attention to determine the contribution of
time-varying inputs.
Instance-wise Variable Importance with DNNs: Instance (i.e. sample)-
wise variable importance can be obtained with post-hoc explanation methods
[15, 16, 26] and inherently interpretable models [27, 24]. Post-hoc explanation
methods, e.g. LIME [15], SHAP [16] and RL-LIM [26], are applied on pre-
trained black-box models and often based on distilling into a surrogate inter-
pretable model, or decomposing into feature attributions. They are not de-
signed to take into account the time ordering of inputs, limiting their use for
complex time series data. Inherently-interpretable modeling approaches build
components for feature selection directly into the architecture. For time series
forecasting specifically, they are based on explicitly quantifying time-dependent
variable contributions. For example, Interpretable Multi-Variable LSTMs [27]
partitions the hidden state such that each variable contributes uniquely to its
own memory segment, and weights memory segments to determine variable
contributions. Methods combining temporal importance and variable selection
have also been considered in [24], which computes a single contribution coeffi-
cient based on attention weights from each. However, in addition to the short-
coming of modelling only one-step-ahead forecasts, existing methods also focus
on instance-specific (i.e. sample-specific) interpretations of attention weights
– without providing insights into global temporal dynamics. In contrast, the
use cases in Sec. 7 demonstrate that TFT is able to analyze global temporal
relationships and allows users to interpret global behaviors of the model on the
whole dataset – specifically in the identification of any persistent patterns (e.g.
seasonality or lag effects) and regimes present.

3. Multi-horizon Forecasting

Let there be $I$ unique entities in a given time series dataset – such as different stores in retail or patients in healthcare. Each entity $i$ is associated with a set of static covariates $s_i \in \mathbb{R}^{m_s}$, as well as inputs $\chi_{i,t} \in \mathbb{R}^{m_\chi}$ and scalar targets $y_{i,t} \in \mathbb{R}$ at each time-step $t \in [0, T_i]$. Time-dependent input features are subdivided into two categories $\chi_{i,t} = [z_{i,t}^T, x_{i,t}^T]^T$ – observed inputs $z_{i,t} \in \mathbb{R}^{m_z}$ which can only be measured at each step and are unknown beforehand, and known inputs $x_{i,t} \in \mathbb{R}^{m_x}$ which can be predetermined (e.g. the day-of-week at time $t$).

In many scenarios, the provision of prediction intervals can be useful for optimizing decisions and risk management by yielding an indication of likely best and worst-case values that the target can take. As such, we adopt quantile regression in our multi-horizon forecasting setting (e.g. outputting the 10th, 50th and 90th percentiles at each time step). Each quantile forecast takes the form:

$$\hat{y}_i(q, t, \tau) = f_q(\tau, y_{i,t-k:t}, z_{i,t-k:t}, x_{i,t-k:t+\tau}, s_i), \quad (1)$$

where $\hat{y}_i(q, t, \tau)$ is the predicted $q$th sample quantile of the $\tau$-step-ahead forecast at time $t$, and $f_q(\cdot)$ is a prediction model. In line with other direct methods, we simultaneously output forecasts for $\tau_{max}$ time steps – i.e. $\tau \in \{1, \ldots, \tau_{max}\}$. We incorporate all past information within a finite look-back window $k$, using target and known inputs only up until and including the forecast start time $t$ (i.e. $y_{i,t-k:t} = \{y_{i,t-k}, \ldots, y_{i,t}\}$) and known inputs across the entire range (i.e. $x_{i,t-k:t+\tau} = \{x_{i,t-k}, \ldots, x_{i,t}, \ldots, x_{i,t+\tau}\}$).²
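In practice, each $f_q$ in Eq. (1) is trained by minimising the quantile (pinball) loss over all horizons and quantiles. A minimal NumPy sketch of the loss itself, with illustrative values not drawn from the paper's experiments:

```python
import numpy as np

def quantile_loss(y, y_hat, q):
    """Pinball loss for quantile q: under-prediction is weighted by q,
    over-prediction by (1 - q)."""
    error = y - y_hat
    return np.mean(np.maximum(q * error, (q - 1) * error))

# For the 90th percentile, under-predicting by 2 costs 0.9 * 2 = 1.8,
# while over-predicting by 2 costs only 0.1 * 2 = 0.2.
y = np.array([10.0, 12.0, 9.0])
under = quantile_loss(y, y - 2.0, q=0.9)  # -> 1.8
over = quantile_loss(y, y + 2.0, q=0.9)   # -> 0.2
```

This asymmetry is what pushes the 90th-percentile output above the median, yielding the upper edge of the prediction interval.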

4. Model Architecture

Figure 2: TFT architecture. TFT inputs static metadata, time-varying past inputs and time-
varying a priori known future inputs. Variable Selection is used for judicious selection of
the most salient features based on the input. Gated Residual Network blocks enable efficient
information flow with skip connections and gating layers. Time-dependent processing is based
on LSTMs for local processing, and multi-head attention for integrating information from any
time step.

We design TFT to use canonical components to efficiently build feature
representations for each input type (i.e. static, known, observed inputs) for high
forecasting performance on a wide range of problems. The major constituents
of TFT are:
1. Gating mechanisms to skip over any unused components of the architec-
ture, providing adaptive depth and network complexity to accommodate a
wide range of datasets and scenarios.
2. Variable selection networks to select relevant input variables at each time
step.
3. Static covariate encoders to integrate static features into the network,
through encoding of context vectors to condition temporal dynamics.
4. Temporal processing to learn both long- and short-term temporal rela-
tionships from both observed and known time-varying inputs. A sequence-
to-sequence layer is employed for local processing, whereas long-term depen-
dencies are captured using a novel interpretable multi-head attention block.

2 For notation simplicity, we omit the subscript i unless explicitly required.
5. Prediction intervals via quantile forecasts to determine the range of likely
target values at each prediction horizon.
Fig. 2 shows the high level architecture of Temporal Fusion Transformer
(TFT), with individual components described in detail in the subsequent sec-
tions.

4.1. Gating Mechanisms

The precise relationship between exogenous inputs and targets is often un-
known in advance, making it difficult to anticipate which variables are relevant.
Moreover, it is difficult to determine the extent of required non-linear process-
ing, and there may be instances where simpler models can be beneficial – e.g.
when datasets are small or noisy. With the motivation of giving the model the
flexibility to apply non-linear processing only where needed, we propose Gated
Residual Network (GRN), as shown in Fig. 2, as a building block of TFT. The GRN takes in a primary input $a$ and an optional context vector $c$ and yields:

$$\text{GRN}_\omega(a, c) = \text{LayerNorm}(a + \text{GLU}_\omega(\eta_1)), \quad (2)$$
$$\eta_1 = W_{1,\omega}\, \eta_2 + b_{1,\omega}, \quad (3)$$
$$\eta_2 = \text{ELU}(W_{2,\omega}\, a + W_{3,\omega}\, c + b_{2,\omega}), \quad (4)$$

where ELU is the Exponential Linear Unit activation function [28], $\eta_1 \in \mathbb{R}^{d_{model}}$, $\eta_2 \in \mathbb{R}^{d_{model}}$ are intermediate layers, LayerNorm is standard layer normalization of [29], and $\omega$ is an index to denote weight sharing. When $W_{2,\omega}\, a + W_{3,\omega}\, c + b_{2,\omega} \gg 0$, the ELU activation acts as an identity function, and when $W_{2,\omega}\, a + W_{3,\omega}\, c + b_{2,\omega} \ll 0$, the ELU activation generates a constant output, resulting in linear layer behavior. We use component gating layers based on Gated Linear Units (GLUs) [30] to provide the flexibility to suppress any parts of the architecture that are not required for a given dataset. Letting $\gamma \in \mathbb{R}^{d_{model}}$ be the input, the GLU takes the form:

$$\text{GLU}_\omega(\gamma) = \sigma(W_{4,\omega}\, \gamma + b_{4,\omega}) \odot (W_{5,\omega}\, \gamma + b_{5,\omega}), \quad (5)$$

where $\sigma(\cdot)$ is the sigmoid activation function, $W_{(\cdot)} \in \mathbb{R}^{d_{model} \times d_{model}}$, $b_{(\cdot)} \in \mathbb{R}^{d_{model}}$ are the weights and biases, $\odot$ is the element-wise Hadamard product, and $d_{model}$ is the hidden state size (common across TFT). The GLU allows TFT to control the extent to which the GRN contributes to the original input $a$ – potentially skipping over the layer entirely if necessary, as the GLU outputs could all be close to 0, suppressing the nonlinear contribution. For instances without a context vector, the GRN simply treats the context input as zero – i.e. $c = 0$ in Eq. (4). During training, dropout is applied before the gating layer and layer normalization – i.e. to $\eta_1$ in Eq. (3).
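Eqs. (2)–(5) can be sketched in a few lines of NumPy. The weights `W1`–`W5` below are random stand-ins for trained parameters, the dimension `d = 8` is arbitrary, and biases are dropped for brevity:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # hidden state size d_model

def elu(x):
    return np.where(x > 0, x, np.exp(x) - 1.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def layer_norm(x, eps=1e-5):
    # Standard zero-mean, unit-variance layer normalization.
    return (x - x.mean()) / np.sqrt(x.var() + eps)

# Randomly initialised stand-ins for the learned weights W_1..W_5.
W1, W2, W3, W4, W5 = (0.1 * rng.standard_normal((d, d)) for _ in range(5))

def glu(gamma):
    # Eq. (5): sigmoid gate, element-wise product with a linear map.
    return sigmoid(W4 @ gamma) * (W5 @ gamma)

def grn(a, c=None):
    # Eqs. (2)-(4); a missing context vector is treated as c = 0.
    c = np.zeros_like(a) if c is None else c
    eta2 = elu(W2 @ a + W3 @ c)        # Eq. (4)
    eta1 = W1 @ eta2                   # Eq. (3)
    return layer_norm(a + glu(eta1))   # Eq. (2)

out = grn(rng.standard_normal(d))
```

Note how the gated nonlinearity is added back onto `a` through a residual connection: if the sigmoid gate saturates near zero, the GRN output reduces to a normalized copy of its input.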

4.2. Variable Selection Networks
While multiple variables may be available, their relevance and specific con-
tribution to the output are typically unknown. TFT is designed to provide
instance-wise variable selection through the use of variable selection networks
applied to both static covariates and time-dependent covariates. Beyond provid-
ing insights into which variables are most significant for the prediction problem,
variable selection also allows TFT to remove any unnecessary noisy inputs which
could negatively impact performance. Most real-world time series datasets con-
tain features with less predictive content, thus variable selection can greatly
help model performance via utilization of learning capacity only on the most
salient ones.
We use entity embeddings [31] for categorical variables as feature represen-
tations, and linear transformations for continuous variables – transforming each
input variable into a (dmodel )-dimensional vector which matches the dimensions
in subsequent layers for skip connections. All static, past and future inputs
make use of separate variable selection networks (as denoted by different colors
in Fig. 2). Without loss of generality, we present the variable selection network
for past inputs – noting that those for other inputs take the same form.
Let $\xi_t^{(j)} \in \mathbb{R}^{d_{model}}$ denote the transformed input of the $j$-th variable at time $t$, with $\Xi_t = \left[\xi_t^{(1)T}, \ldots, \xi_t^{(m_\chi)T}\right]^T$ being the flattened vector of all past inputs at time $t$. Variable selection weights are generated by feeding both $\Xi_t$ and an external context vector $c_s$ through a GRN, followed by a Softmax layer:

$$v_{\chi_t} = \text{Softmax}\left(\text{GRN}_{v_\chi}(\Xi_t, c_s)\right), \quad (6)$$

where $v_{\chi_t} \in \mathbb{R}^{m_\chi}$ is a vector of variable selection weights, and $c_s$ is obtained from a static covariate encoder (see Sec. 4.3). For static variables, we note that the context vector $c_s$ is omitted – given that it already has access to static information.

At each time step, an additional layer of non-linear processing is employed by feeding each $\xi_t^{(j)}$ through its own GRN:

$$\tilde{\xi}_t^{(j)} = \text{GRN}_{\tilde{\xi}(j)}\left(\xi_t^{(j)}\right), \quad (7)$$

where $\tilde{\xi}_t^{(j)}$ is the processed feature vector for variable $j$. We note that each variable has its own $\text{GRN}_{\tilde{\xi}(j)}$, with weights shared across all time steps $t$. Processed features are then weighted by their variable selection weights and combined:

$$\tilde{\xi}_t = \sum_{j=1}^{m_\chi} v_{\chi_t}^{(j)}\, \tilde{\xi}_t^{(j)}, \quad (8)$$

where $v_{\chi_t}^{(j)}$ is the $j$-th element of vector $v_{\chi_t}$.
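The selection-and-combine steps of Eqs. (6)–(8) can be illustrated with a NumPy sketch. For brevity, the GRNs are replaced by random linear maps (`W_sel`, `W_var` are illustrative stand-ins, not part of the paper), and the context vector is folded into the selection weights:

```python
import numpy as np

rng = np.random.default_rng(1)
m_chi, d = 4, 8  # number of past variables, d_model

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Transformed per-variable inputs xi_t^(j), flattened into Xi_t.
xi = rng.standard_normal((m_chi, d))
Xi = xi.reshape(-1)

# Linear stand-ins for GRN_{v_chi} (Eq. 6) and the per-variable GRNs (Eq. 7).
W_sel = 0.1 * rng.standard_normal((m_chi, m_chi * d))
W_var = 0.1 * rng.standard_normal((m_chi, d, d))

v = softmax(W_sel @ Xi)                         # Eq. (6): selection weights
xi_tilde = np.einsum('jab,jb->ja', W_var, xi)   # Eq. (7): per-variable GRNs
combined = (v[:, None] * xi_tilde).sum(axis=0)  # Eq. (8): weighted sum
```

Because `v` is a softmax output, the weights are non-negative and sum to one, so they can be read off directly as per-sample variable importances.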

4.3. Static Covariate Encoders

In contrast with other time series forecasting architectures, the TFT is carefully designed to integrate information from static metadata, using separate GRN encoders to produce four different context vectors, $c_s$, $c_e$, $c_c$ and $c_h$. These context vectors are wired into various locations in the temporal fusion decoder (Sec. 4.5) where static variables play an important role in processing. Specifically, this includes contexts for (1) temporal variable selection ($c_s$), (2) local processing of temporal features ($c_c$, $c_h$), and (3) enriching of temporal features with static information ($c_e$). As an example, taking $\zeta$ to be the output of the static variable selection network, contexts for temporal variable selection would be encoded according to $c_s = \text{GRN}_{c_s}(\zeta)$.

4.4. Interpretable Multi-Head Attention
The TFT employs a self-attention mechanism to learn long-term relation-
ships across different time steps, which we modify from multi-head attention in
transformer-based architectures [17, 12] to enhance explainability. In general,
attention mechanisms scale values V ∈ RN ×dV based on relationships between
keys K ∈ RN ×dattn and queries Q ∈ RN ×dattn as below:

Attention(Q, K, V ) = A(Q, K)V , (9)

where A() is a normalization function. A common choice is scaled dot-product


attention [17]:
p
A(Q, K) = Softmax(QK T / dattn ). (10)

To improve the learning capacity of the standard attention mechanism,


multi-head attention is proposed in [17], employing different heads for differ-
ent representation subspaces:

MultiHead(Q, K, V ) = [H1 , . . . , HmH ] WH , (11)


(h) (h) (h)
Hh = Attention(Q WQ , K WK , V WV ), (12)
(h) (h) (h)
where WK ∈ Rdmodel ×dattn , WQ ∈ Rdmodel ×dattn , WV ∈ Rdmodel ×dV are
head-specific weights for keys, queries and values, and WH ∈ R(mH ·dV )×dmodel
linearly combines outputs concatenated from all heads Hh .
Given that different values are used in each head, attention weights alone
would not be indicative of a particular feature’s importance. As such, we modify
multi-head attention to share values in each head, and employ additive aggre-
gation of all heads:

InterpretableMultiHead(Q, K, V ) = H̃ WH , (13)

H̃ = Ã(Q, K) V WV , (14)
 XmH  
(h) (h)
= 1/H A Q W Q , K WK V WV , (15)
h=1
XmH (h) (h)
= 1/H Attention(Q WQ , K WK , V WV ), (16)
h=1

where WV ∈ R^(dmodel×dV) are value weights shared across all heads, and WH ∈ R^(dattn×dmodel) is used for the final linear mapping. From Eq. (15), we see that each head can learn different temporal patterns while attending to a common set of input features – which can be interpreted as a simple ensemble of attention weights, combined into the matrix Ã(Q, K) of Eq. (14). Compared to A(Q, K) in Eq. (10), Ã(Q, K) yields increased representation capacity in an efficient way.
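The modification above can be sketched as follows (a NumPy illustration under our own naming; per-head projection matrices are passed in as lists):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def interpretable_multi_head(Q, K, V, W_Q, W_K, W_V, W_H):
    # Eqs. (13)-(16): head-specific query/key projections W_Q[h], W_K[h],
    # a single value projection W_V shared by all heads, and additive
    # (averaged) aggregation of the per-head attention matrices.
    m_H = len(W_Q)
    d_attn = W_Q[0].shape[1]
    A_tilde = np.zeros((Q.shape[0], K.shape[0]))
    for h in range(m_H):
        A_tilde += softmax((Q @ W_Q[h]) @ (K @ W_K[h]).T / np.sqrt(d_attn))
    A_tilde /= m_H                  # combined attention weights of Eq. (14)
    H_tilde = A_tilde @ (V @ W_V)   # shared values make A_tilde interpretable
    return H_tilde @ W_H, A_tilde
```

Because every head scales the same projected values, A_tilde alone indicates which positions drive the output.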

4.5. Temporal Fusion Decoder


The temporal fusion decoder uses the series of layers described below to
learn temporal relationships present in the dataset:

4.5.1. Locality Enhancement with Sequence-to-Sequence Layer


In time series data, points of significance are often identified in relation
to their surrounding values – such as anomalies, change-points or cyclical pat-
terns. Leveraging local context, through the construction of features that utilize
pattern information on top of point-wise values, can thus lead to performance
improvements in attention-based architectures. For instance, [12] adopts a sin-
gle convolutional layer for locality enhancement – extracting local patterns us-
ing the same filter across all time. However, this might not be suitable for
cases when observed inputs exist, due to the differing number of past and fu-
ture inputs. As such, we propose the application of a sequence-to-sequence
model to naturally handle these differences – feeding ξ̃t−k:t into the encoder
and ξ̃t+1:t+τmax into the decoder. This then generates a set of uniform temporal
features which serve as inputs into the temporal fusion decoder itself – denoted
by φ(t, n) ∈ {φ(t, −k), . . . , φ(t, τmax )} with n being a position index. For com-
parability with commonly-used sequence-to-sequence baselines, we consider the
use of an LSTM encoder-decoder – although other models can potentially be
adopted as well. This also serves as a replacement for standard positional en-
coding, providing an appropriate inductive bias for the time ordering of the
inputs. Moreover, to allow static metadata to influence local processing, we use
the cc , ch context vectors from the static covariate encoders to initialize the cell
state and hidden state respectively for the first LSTM in the layer. We also
employ a gated skip connection over this layer:
 
φ̃(t, n) = LayerNorm ξ̃t+n + GLUφ̃ (φ(t, n)) , (17)

where n ∈ [−k, τmax ] is a position index.
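The gated skip connection of Eq. (17) composes a GLU with a residual addition and layer normalization. A minimal NumPy sketch is given below, assuming the GLU form of Eq. (5) defined earlier in the paper, and omitting LayerNorm's learnable scale and shift:

```python
import numpy as np

def glu(x, W_a, b_a, W_b, b_b):
    # Gated linear unit: a sigmoid gate elementwise-scales a linear transform,
    # letting the layer suppress its own contribution where unhelpful.
    gate = 1.0 / (1.0 + np.exp(-(x @ W_a + b_a)))
    return gate * (x @ W_b + b_b)

def layer_norm(x, eps=1e-6):
    # Normalize each feature vector to zero mean and unit variance.
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + eps)

def gated_skip(xi, phi, W_a, b_a, W_b, b_b):
    # Eq. (17): the seq2seq output phi is gated before being added back
    # to the input xi that skips around the layer.
    return layer_norm(xi + glu(phi, W_a, b_a, W_b, b_b))
```

When the gate saturates near zero, the layer reduces to a normalized identity over its input, which is how gating lets the network skip unneeded processing.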

4.5.2. Static Enrichment Layer


As static covariates often have a significant influence on the temporal dynam-
ics (e.g. genetic information on disease risk), we introduce a static enrichment
layer that enhances temporal features with static metadata. For a given position
index n, static enrichment takes the form:
 
θ(t, n) = GRNθ φ̃(t, n), ce , (18)

where the weights of GRNθ are shared across the entire layer, and ce is a context vector from a static covariate encoder.

4.5.3. Temporal Self-Attention Layer
Following static enrichment, we next apply self-attention. All static-enriched temporal features are first grouped into a single matrix – i.e. Θ(t) = [θ(t, −k), . . . , θ(t, τmax)]^T – and interpretable multi-head attention (see Sec. 4.4) is applied at each forecast time (with N = τmax + k + 1):

B(t) = InterpretableMultiHead(Θ(t), Θ(t), Θ(t)), (19)

to yield B(t) = [β(t, −k), . . . , β(t, τmax)]. We choose dV = dattn = dmodel/mH, where mH is the number of heads. Decoder masking [17, 12] is applied to
the multi-head attention layer to ensure that each temporal dimension can only
attend to features preceding it. Besides preserving causal information flow via
masking, the self-attention layer allows TFT to pick up long-range dependen-
cies that may be challenging for RNN-based architectures to learn. Following
the self-attention layer, an additional gating layer is also applied to facilitate
training:

δ(t, n) = LayerNorm(θ(t, n) + GLUδ (β(t, n))). (20)
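Decoder masking can be implemented by setting all strictly-future entries of the score matrix to −∞ before normalization; a minimal NumPy sketch (our own naming):

```python
import numpy as np

def masked_attention_weights(scores):
    # Decoder masking: position i may only attend to positions j <= i.
    # Strictly-future scores are set to -inf so they receive zero weight
    # after the softmax, preserving causal information flow.
    N = scores.shape[0]
    future = np.triu(np.ones((N, N), dtype=bool), k=1)
    masked = np.where(future, -np.inf, scores)
    e = np.exp(masked - masked.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)
```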

4.5.4. Position-wise Feed-forward Layer


We apply additional non-linear processing to the outputs of the self-attention layer. Similar to the static enrichment layer, this makes use of GRNs:

ψ(t, n) = GRNψ (δ(t, n)) , (21)

where the weights of GRNψ are shared across the entire layer. As per Fig. 2, we
also apply a gated residual connection which skips over the entire transformer
block, providing a direct path to the sequence-to-sequence layer – yielding a
simpler model if additional complexity is not required, as shown below:
 
ψ̃(t, n) = LayerNorm(φ̃(t, n) + GLUψ̃(ψ(t, n))). (22)

4.6. Quantile Outputs


In line with previous work [10], TFT also generates prediction intervals on
top of point forecasts. This is achieved by the simultaneous prediction of various
percentiles (e.g. 10th , 50th and 90th ) at each time step. Quantile forecasts are
generated using linear transformation of the output from the temporal fusion
decoder:
ŷ(q, t, τ ) = Wq ψ̃(t, τ ) + bq , (23)
where Wq ∈ R1×d , bq ∈ R are linear coefficients for the specified quantile q.
We note that forecasts are only generated for horizons in the future – i.e. τ ∈
{1, . . . , τmax }.

5. Loss Functions

TFT is trained by jointly minimizing the quantile loss [10], summed across
all quantile outputs:
L(Ω, W) = Σ_{yt∈Ω} Σ_{q∈Q} Σ_{τ=1}^{τmax} QL(yt, ŷ(q, t − τ, τ), q) / (M τmax), (24)

QL(y, ŷ, q) = q(y − ŷ)+ + (1 − q)(ŷ − y)+, (25)


where Ω is the domain of training data containing M samples, W represents the weights of TFT, Q is the set of output quantiles (we use Q = {0.1, 0.5, 0.9} in our experiments), and (·)+ = max(0, ·). For out-of-sample testing, we evaluate
the normalized quantile losses across the entire forecasting horizon – focusing
on P50 and P90 risk for consistency with previous work [9, 6, 12]:
q-Risk = 2 Σ_{yt∈Ω̃} Σ_{τ=1}^{τmax} QL(yt, ŷ(q, t − τ, τ), q) / (Σ_{yt∈Ω̃} Σ_{τ=1}^{τmax} |yt|), (26)

where Ω̃ is the domain of test samples. Full details on hyperparameter optimization and training can be found in Appendix A.
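The quantile loss of Eq. (25) and the q-Risk of Eq. (26) can be expressed compactly; below is a minimal NumPy sketch operating on flattened arrays of targets and forecasts (function names are ours):

```python
import numpy as np

def quantile_loss(y, y_hat, q):
    # Eq. (25): the "pinball" loss, penalizing under-prediction with
    # weight q and over-prediction with weight (1 - q).
    return q * np.maximum(y - y_hat, 0.0) + (1 - q) * np.maximum(y_hat - y, 0.0)

def q_risk(y, y_hat, q):
    # Eq. (26): quantile loss summed over the test set and all horizons,
    # normalized by the summed absolute targets.
    return 2.0 * quantile_loss(y, y_hat, q).sum() / np.abs(y).sum()
```

For q = 0.5 the loss reduces to half the absolute error, so the P50 risk is a scaled MAE.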

6. Performance Evaluation

6.1. Datasets
We choose datasets to reflect commonly observed characteristics across a
wide range of challenging multi-horizon forecasting problems. To establish a
baseline and position with respect to prior academic work, we first evaluate
performance on the Electricity and Traffic datasets used in [9, 6, 12] – which
focus on simpler univariate time series containing known inputs only alongside
the target. Next, the Retail dataset helps us benchmark the model using the
full range of complex inputs observed in multi-horizon prediction applications
(see Sec. 3) – including rich static metadata and observed time-varying inputs.
Finally, to evaluate robustness to over-fitting on smaller noisy datasets, we
consider the financial application of volatility forecasting – using a dataset much
smaller than others. Broad descriptions of each dataset can be found below:
• Electricity: The UCI Electricity Load Diagrams Dataset, containing the
electricity consumption of 370 customers – aggregated on an hourly level as
in [32]. In accordance with [9], we use the past week (i.e. 168 hours) to
forecast over the next 24 hours.
• Traffic: The UCI PEM-SF Traffic Dataset describes the occupancy rate (with
yt ∈ [0, 1]) of 440 SF Bay Area freeways – as in [32]. It is also aggregated on
an hourly level as per the electricity dataset, with the same look back window
and forecast horizon.

• Retail: Favorita Grocery Sales Dataset from the Kaggle competition [33],
that combines metadata for different products and the stores, along with
other exogenous time-varying inputs sampled at the daily level. We forecast
log product sales 30 days into the future, using 90 days of past information.
• Volatility (or Vol.): The OMI realized library [34] contains daily realized
volatility values of 31 stock indices computed from intraday data, along with
their daily returns. For our experiments, we consider forecasts over the next
week (i.e. 5 business days) using information over the past year (i.e. 252
business days).

6.2. Training Procedure


For each dataset, we partition all time series into 3 parts – a training set
for learning, a validation set for hyperparameter tuning, and a hold-out test
set for performance evaluation. Hyperparameter optimization is conducted via
random search, using 240 iterations for Volatility, and 60 iterations for others.
Full search ranges for all hyperparameters are below, with datasets and optimal
model parameters listed in Table 1.

• State size – 10, 20, 40, 80, 160, 240, 320


• Dropout rate – 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9
• Minibatch size – 64, 128, 256
• Learning rate – 0.0001, 0.001, 0.01
• Max. gradient norm – 0.01, 1.0, 100.0
• Num. heads – 1, 4

To preserve explainability, we adopt only a single interpretable multi-head


attention layer. For ConvTrans [12], we use the same fixed stack size (3 layers)
and number of heads (8 heads) as in [12]. We keep the same attention model,
and treat kernel sizes for the convolutional processing layer as a hyperparameter
(∈ {1, 3, 6, 9}) – as optimal kernel sizes are observed to be dataset dependent
[12]. An open-source implementation of the TFT on these datasets can be found
on GitHub3 for full reproducibility.

6.3. Computational Cost


Across all datasets, each TFT model was trained on a single GPU, and can be deployed without the need for extensive computing resources. For instance, using an NVIDIA Tesla V100 GPU, our optimal TFT model (for the Electricity dataset) takes just over 6 hours to train (each epoch taking roughly 52 mins). Batched inference on the entire validation dataset (consisting of 50,000 samples) takes 8 minutes. TFT training and inference times can be
further reduced with hardware-specific optimizations.

Table 1: Information on dataset and optimal TFT configuration.

Electricity Traffic Retail Vol.


Dataset Details
Target Type R [0, 1] R R
Number of Entities 370 440 130k 41
Number of Samples 500k 500k 500k ∼100k
Network Parameters
k 168 168 90 252
τmax 24 24 30 5
Dropout Rate 0.1 0.3 0.1 0.3
State Size 160 320 240 160
Number of Heads 4 4 4 1
Training Parameters
Minibatch Size 64 128 128 64
Learning Rate 0.001 0.001 0.001 0.01
Max Gradient Norm 0.01 100 100 0.01

6.4. Benchmarks
We extensively compare TFT to a wide range of models for multi-horizon
forecasting, based on the categories described in Sec. 2. Hyperparameter op-
timization is conducted using random search over a pre-defined search space,
using the same number of iterations across all benchmarks for a given dataset.
Additional details are included in Appendix A.
Direct methods: As TFT falls within this class of multi-horizon models,
we primarily focus comparisons on deep learning models which directly generate
prediction at future horizons, including: 1) simple sequence-to-sequence models
with global contexts (Seq2Seq), and 2) the Multi-horizon Quantile Recurrent
Forecaster (MQRNN) [10].
Iterative methods: To position with respect to the rich body of work on
iterative models, we evaluate TFT using the same setup as [9] for the Electricity
and Traffic datasets. This extends the results from [12] for 1) DeepAR [9], 2)
DSSM [6], and 3) the Transformer-based architecture of [12] with local convolutional processing – which we refer to as ConvTrans. For more complex datasets,
we focus on the ConvTrans model given its strong outperformance over other
iterative models in prior work, and DeepAR due to its popularity among practi-
tioners. As models in this category require knowledge of all inputs in the future
to generate predictions, we accommodate this for complex datasets by imputing
unknown inputs with their last available value.
For simpler univariate datasets, we note that the results for ARIMA, ETS,
TRMF, DeepAR, DSSM and ConvTrans have been reproduced from [12] in

3 URL: https://github.com/google-research/google-research/tree/master/tft

Table 2: P50 and P90 quantile losses on a range of real-world datasets. Percentages in brackets
reflect the increase in quantile loss versus TFT (lower q-Risk better), with TFT outperforming
competing methods across all experiments, improving on the next best alternative method
(underlined) between 3% and 26%.

ARIMA ETS TRMF DeepAR DSSM


Electricity 0.154 (+180% ) 0.102 (+85% ) 0.084 (+53% ) 0.075 (+36% ) 0.083 (+51% )
Traffic 0.223 (+135% ) 0.236 (+148% ) 0.186 (+96% ) 0.161 (+69% ) 0.167 (+76% )
ConvTrans Seq2Seq MQRNN TFT
Electricity 0.059 (+7% ) 0.067 (+22% ) 0.077 (+40% ) 0.055*
Traffic 0.122 (+28% ) 0.105 (+11% ) 0.117 (+23% ) 0.095*
(a) P50 losses on simpler univariate datasets.

ARIMA ETS TRMF DeepAR DSSM


Electricity 0.102 (+278% ) 0.077 (+185% ) - 0.040 (+48% ) 0.056 (+107% )
Traffic 0.137 (+94% ) 0.148 (+110% ) - 0.099 (+40% ) 0.113 (+60% )
ConvTrans Seq2Seq MQRNN TFT
Electricity 0.034 (+26% ) 0.036 (+33% ) 0.036 (+33% ) 0.027*
Traffic 0.081 (+15% ) 0.075 (+6% ) 0.082 (+16% ) 0.070*
(b) P90 losses on simpler univariate datasets.

DeepAR ConvTrans Seq2Seq MQRNN TFT


Vol. 0.050 (+28% ) 0.047 (+20% ) 0.042 (+7% ) 0.042 (+7% ) 0.039*
Retail 0.574 (+62% ) 0.429 (+21% ) 0.411 (+16% ) 0.379 (+7% ) 0.354*
(c) P50 losses on datasets with rich static or observed inputs.

DeepAR ConvTrans Seq2Seq MQRNN TFT


Vol. 0.024 (+21% ) 0.024 (+22% ) 0.021 (+8% ) 0.021 (+9% ) 0.020*
Retail 0.230 (+56% ) 0.192 (+30% ) 0.157 (+7% ) 0.152 (+3% ) 0.147*
(d) P90 losses on datasets with rich static or observed inputs.

Table 2 for consistency.

6.5. Results and Discussion


Table 2 shows that TFT significantly outperforms all benchmarks over the
variety of datasets described in Sec. 6.1. For median forecasts, TFT yields
7% lower P50 and 9% lower P90 losses on average compared to the next best
model – demonstrating the benefits of explicitly aligning the architecture with
the general multi-horizon forecasting problem.
Comparing direct and iterative models, we observe the importance of ac-
counting for the observed inputs – noting the poorer results of ConvTrans on
complex datasets where observed input imputation is required (i.e. Volatility
and Retail). Furthermore, the benefits of quantile regression are also observed
when targets are not captured well by Gaussian distributions, with direct models
outperforming in those scenarios. This can be seen, for example, from the Traf-
fic dataset where target distribution is significantly skewed – with more than
90% of occupancy rates falling between 0 and 0.1, and the remainder distributed
evenly until 1.0.

6.6. Ablation Analysis


To quantify the benefits of each of our proposed architectural contributions,
we perform an extensive ablation analysis – removing each component from the
network as below, and quantifying the percentage increase in loss versus the
original architecture:
• Gating layers: We ablate by replacing each GLU layer (Eq. (5)) with a
simple linear layer followed by ELU.
• Static covariate encoders: We ablate by setting all context vectors to zero
– i.e. cs =ce =cc =ch =0 – and concatenating all transformed static inputs to
all time-dependent past and future inputs.
• Instance-wise variable selection networks: We ablate by replacing the
softmax outputs of Eq. 6 with trainable coefficients, and removing the net-
works generating the variable selection weights. We retain, however, the
variable-wise GRNs (see Eq. (7)), maintaining a similar amount of non-linear
processing.
• Self-attention layers: We ablate by replacing the attention matrix of the
interpretable multi-head attention layer (Eq. 14) with a matrix of trainable
parameters WA – i.e. Ã(Q, K) = WA , where WA ∈ RN ×N . This prevents
TFT from attending to different input features at different times, helping
evaluation of the importance of instance-wise attention weights.
• Sequence-to-sequence layers for local processing: We ablate by re-
placing the sequence-to-sequence layer of Sec. 4.5.1 with standard positional
encoding used in [17].
Ablated networks are trained for each dataset using the hyperparameters of Table 1. Fig. 3 shows that the effects on both P50 and P90 losses
are similar across all datasets, with all components contributing to performance
improvements on the whole.

In general, the components responsible for capturing temporal relationships – the local processing and self-attention layers – have the largest impact on performance, with P90 loss increases of > 6% on average and > 20% on select datasets
when ablated. The diversity across time series datasets can also be seen from the
differences in the ablation impact of the respective temporal components. Con-
cretely, while local processing is critical in Traffic, Retail and Volatility, lower
post-ablation P50 losses indicate that it can be detrimental in Electricity – with
the self-attention layer playing a more vital role. A possible explanation is that
persistent daily seasonality appears to dominate other temporal relationships in
the Electricity dataset. For this dataset, Table B.4 of Appendix B also shows
that the hour-of-day has the largest variable importance score across all tem-
poral inputs, exceeding even the target (i.e. Power Usage) itself. In contrast to other datasets where past target observations are more significant (e.g. Traffic), direct attention to previous days seems to help in learning daily seasonal patterns in Electricity – with local processing between adjacent time steps being less
necessary. We can account for this by treating the sequence-to-sequence archi-
tecture in the temporal fusion decoder as a hyperparameter to tune, including
an option for simple positional encoding without any local processing.
Static covariate encoders and instance-wise variable selection have the next largest impact – increasing P90 losses by more than 2.6% and 4.1% on average, respectively. The biggest benefits of these are observed for the Electricity dataset, where some of the input features receive very low importance.
Finally, gating layer ablation also shows increases in P90 losses, with a 1.9% increase on average. This is most significant for the Volatility dataset (with a 4.1% P90 loss increase), underlining the benefit of component gating for smaller and noisier datasets.

7. Interpretability Use Cases

Having established the performance benefits of our model, we next demon-


strate how our model design allows for analysis of its individual components
to interpret the general relationships it has learned. We demonstrate three in-
terpretability use cases: (1) examining the importance of each input variable
in prediction, (2) visualizing persistent temporal patterns, and (3) identifying
any regimes or events that lead to significant changes in temporal dynamics. In
contrast to other examples of attention-based interpretability [25, 12, 7] which
zoom in on interesting but instance-specific examples, our methods focus on
ways to aggregate the patterns across the entire dataset – extracting generaliz-
able insights about temporal dynamics.

7.1. Analyzing Variable Importance


We first quantify variable importance by analyzing the variable selection
weights described in Sec. 4.2. Concretely, we aggregate selection weights (i.e. vχt^(j) in Eq. (8)) for each variable across our entire test set, recording the 10th, 50th and 90th percentiles of each sampling distribution. As the Retail dataset

Table 3: Variable importance for the Retail dataset. The 10th , 50th and 90th percentiles of
the variable selection weights are shown, with values larger than 0.1 highlighted in purple.
For static covariates, the largest weights are attributed to variables which uniquely identify
different entities (i.e. item number and store number). For past inputs, past values of the
target (i.e. log sales) are critical as expected, as forecasts are extrapolations of past observa-
tions. For future inputs, promotion periods and national holidays have the greatest influence
on sales forecasts, in line with periods of increased customer spending.

10% 50% 90% 10% 50% 90%


Item Num 0.198 0.230 0.251 Transactions 0.029 0.033 0.037
Store Num 0.152 0.161 0.170 Oil 0.062 0.081 0.105
City 0.094 0.100 0.124 On-promotion 0.072 0.075 0.078
State 0.049 0.060 0.083 Day of Week 0.007 0.007 0.008
Type 0.005 0.006 0.008 Day of Month 0.083 0.089 0.096
Cluster 0.108 0.122 0.133 Month 0.109 0.122 0.136
Family 0.063 0.075 0.079 National Hol 0.131 0.138 0.145
Class 0.148 0.156 0.163 Regional Hol 0.011 0.014 0.018
Perishable 0.084 0.085 0.088 Local Hol 0.056 0.068 0.072
Open 0.027 0.044 0.067
Log Sales 0.304 0.324 0.353
(a) Static Covariates (b) Past Inputs

10% 50% 90%


On-promotion 0.155 0.170 0.182
Day of Week 0.029 0.065 0.089
Day of Month 0.056 0.116 0.138
Month 0.111 0.155 0.240
National Hol 0.145 0.220 0.242
Regional Hol 0.012 0.014 0.060
Local Hol 0.116 0.151 0.239
Open 0.088 0.095 0.097
(c) Future Inputs

(a) Changes in P50 losses across ablation tests

(b) Changes in P90 losses across ablation tests

Figure 3: Results of ablation analysis. Both a) and b) show the impact of ablation on the
P50 and P90 losses respectively. Results per dataset shown on the left, and the range across
datasets shown on the right. While the precise importance of each is dataset-specific, all
components contribute significantly on the whole – with the maximum percentage increase
over all datasets ranging from 3.6% to 23.4% for P50 losses, and similarly from 4.1% to 28.4%
for P90 losses.

contains the full set of available input types (i.e. static metadata, known inputs,
observed inputs and the target), we present the results for its variable impor-
tance analysis in Table 3. We also note similar findings in other datasets, which
are documented in Appendix B.1 for completeness. On the whole, the results
show that the TFT extracts only a subset of key inputs that intuitively play a
significant role in predictions. The analysis of persistent temporal patterns is
often key to understanding the time-dependent relationships present in a given
dataset. For instance, lag models are frequently adopted to study length of time
required for an intervention to take effect [35] – such as the impact of a govern-
ment’s increase in public expenditure on the resultant growth in Gross National
Product [36]. Seasonality models are also commonly used in econometrics to
identify periodic patterns in a target-of-interest [37] and measure the length of
each cycle. From a practical standpoint, model builders can use these insights to
further improve the forecasting model – for instance by increasing the receptive
field to incorporate more history if attention peaks are observed at the start of
the lookback window, or by engineering features to directly incorporate seasonal
effects. As such, using the attention weights present in the self-attention layer of
the temporal fusion decoder, we present a method to identify similar persistent
patterns – by measuring the contributions of features at fixed lags in the past
on forecasts at various horizons. Combining Eq. (14) and (19), we see that the
self-attention layer contains a matrix of attention weights at each forecast time
t – i.e. Ã(φ(t), φ(t)). Multi-head attention outputs at each forecast horizon τ

(i.e. β(t, τ )) can then be described as an attention-weighted sum of lower level
features at each position n:
β(t, τ) = Σ_{n=−k}^{τmax} α(t, n, τ) θ̃(t, n), (27)

where α(t, n, τ ) is the (τ, n)-th element of Ã(φ(t), φ(t)), and θ̃(t, n) is a row
of Θ̃(t) = Θ(t)WV . Due to decoder masking, we also note that α(t, i, j) = 0,
∀i > j. For each forecast horizon τ , the importance of a previous time point
n < τ can hence be determined by analyzing distributions of α(t, n, τ ) across
all time steps and entities.
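The aggregation described above reduces to taking percentiles over the forecast-time axis; a minimal NumPy sketch (the array layout below is our own choice):

```python
import numpy as np

def lag_importance(alpha, percentiles=(10, 50, 90)):
    # alpha: (T, tau_max, N) array of masked attention weights
    # alpha(t, n, tau), collected over all forecast times t (and entities).
    # Returns, for each horizon tau and lag position n, the requested
    # percentiles of the attention weight distribution across t.
    return np.percentile(alpha, percentiles, axis=0)
```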

7.2. Visualizing Persistent Temporal Patterns


Attention weight patterns can be used to shed light on the most important
past time steps that the TFT model bases its decisions on. In contrast to other
traditional and machine learning time series methods, which rely on model-
based specifications for seasonality and lag analysis, the TFT can learn such
patterns from raw training data.
Fig. 4 shows the attention weight patterns across all our test datasets –
with the upper graph plotting the mean along with the 10th , 50th and 90th
percentiles of the attention weights for one-step-ahead forecasts (i.e. α(t, 1, τ ))
over the test set, and the bottom graph plotting the average attention weights
for various horizons (i.e. τ ∈ {5, 10, 15, 20}). We observe that the three datasets
exhibit a seasonal pattern, with clear attention spikes at daily intervals observed
for Electricity and Traffic, and slightly weaker weekly patterns for Retail. For Retail, we also observe a decaying trend pattern, with the last few days
dominating the importance.
No strong persistent patterns were observed for the Volatility dataset – with attention weights distributed equally across all positions on average. This resembles a
moving average filter at the feature level, and – given the high degree of ran-
domness associated with the volatility process – could be useful in extracting
the trend over the entire period by smoothing out high-frequency noise.
TFT learns these persistent temporal patterns from the raw training data
without any human hard-coding. Such a capability is expected to be very useful in building trust with human experts via sanity-checking. Model developers can also use these insights to improve the model, e.g. via targeted feature engineering or data collection.

7.3. Identifying Regimes & Significant Events


Identifying sudden changes in temporal patterns can also be very useful,
as temporary shifts can occur due to the presence of significant regimes or
events. For instance, regime-switching behavior has been widely documented in
financial markets [38], with returns characteristics – such as volatility – being
observed to change abruptly between regimes. As such, identifying such regime changes provides strong insight into the underlying problem and is useful for pinpointing significant events.

(a) Electricity (b) Traffic

(c) Retail (d) Volatility

Figure 4: Persistent temporal patterns across datasets. Clear seasonality observed for the
Electricity, Traffic and Retail datasets, but no strong persistent patterns seen in Volatility
dataset. Upper plot – percentiles of attention weights for one-step-ahead forecast. Lower plot
– average attention weights for forecast at various horizons.

Figure 5: Regime identification for S&P 500 realized volatility. Significant deviations in
attention patterns can be observed around periods of high volatility – corresponding to the
peaks observed in dist(t). We use a threshold of dist(t) > 0.3 to denote significant regimes, as
highlighted in purple. Focusing on periods around the 2008 financial crisis, the top right plot
visualizes α(t, n, 1) midway through the significant regime, compared to the normal regime on
the top left.

Firstly, for a given entity, we define the average attention pattern per forecast horizon as:

ᾱ(n, τ) = Σ_{t=1}^{T} α(t, n, τ)/T, (28)

and then construct ᾱ(τ) = [ᾱ(−k, τ), . . . , ᾱ(τmax, τ)]^T. To compare similarities between attention weight vectors, we use the distance metric proposed by [39]:

κ(p, q) = √(1 − ρ(p, q)), (29)

where ρ(p, q) = Σ_j √(pj qj) is the Bhattacharyya coefficient [40] measuring the overlap between discrete distributions – with pj, qj being elements of probability
vectors p, q respectively. For each entity, significant shifts in temporal dynamics
are then measured using the distance between attention vectors at each point
with the average pattern, aggregated for all horizons as below:
dist(t) = Σ_{τ=1}^{τmax} κ(ᾱ(τ), α(t, τ)) / τmax, (30)

where α(t, τ ) = [α(t, −k, τ ), . . . , α(t, τmax , τ )]T .
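Eqs. (29)–(30) can be sketched directly; a minimal NumPy version (our own naming, with a clamp added for numerical safety), where attention vectors are passed as rows of arrays of shape (τmax, N):

```python
import numpy as np

def kappa(p, q):
    # Eq. (29): distance derived from the Bhattacharyya coefficient rho,
    # which measures the overlap between two discrete distributions.
    rho = np.sum(np.sqrt(p * q))
    return np.sqrt(max(0.0, 1.0 - rho))  # clamp guards against rounding

def dist(alpha_t, alpha_bar):
    # Eq. (30): average, over all horizons tau, of the distance between
    # the attention vector at time t and the entity's mean pattern.
    tau_max = alpha_t.shape[0]
    return sum(kappa(alpha_bar[tau], alpha_t[tau])
               for tau in range(tau_max)) / tau_max
```

Identical attention patterns give dist(t) = 0, while fully disjoint ones give 1, so peaks in dist(t) flag candidate regime shifts.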


Using the volatility dataset, we attempt to analyse regimes by applying our
distance metric to the attention patterns for the S&P 500 index over our train-
ing period (2001 to 2015). Plotting dist(t) against the target (i.e. log realized
volatility) in the bottom chart of Fig. 5, significant deviations in attention pat-
terns can be observed around periods of high volatility (e.g. the 2008 financial
crisis) – corresponding to the peaks observed in dist(t). From the plots, we can

see that TFT appears to alter its behaviour between regimes – placing equal at-
tention across past inputs when volatility is low, while attending more to sharp
trend changes during high volatility periods – suggesting differences in temporal
dynamics learned in each of these cases.

8. Conclusions

We introduce TFT, a novel attention-based deep learning model for in-


terpretable high-performance multi-horizon forecasting. To handle static co-
variates, a priori known inputs, and observed inputs effectively across a wide range of multi-horizon forecasting datasets, TFT uses specialized components.
Specifically, these include: (1) sequence-to-sequence and attention based tempo-
ral processing components that capture time-varying relationships at different
timescales, (2) static covariate encoders that allow the network to condition
temporal forecasts on static metadata, (3) gating components that enable skip-
ping over unnecessary parts of the network, (4) variable selection to pick rel-
evant input features at each time step, and (5) quantile predictions to obtain
output intervals across all prediction horizons. On a wide range of real-world
tasks – on both simple datasets that contain only known inputs and complex
datasets which encompass the full range of possible inputs – we show that TFT
achieves state-of-the-art forecasting performance. Lastly, we investigate the gen-
eral relationships learned by TFT through a series of interpretability use cases
– proposing novel methods to use TFT to (i) analyze important variables for a
given prediction problem, (ii) visualize persistent temporal relationships learned
(e.g. seasonality), and (iii) identify significant regime changes.

9. Acknowledgements

The authors gratefully acknowledge discussions with Yaguang Li, Maggie


Wang, Jeffrey Gu, Minho Jin and Andrew Moore that contributed to the devel-
opment of this paper.

References
[1] J.-H. Böse, et al., Probabilistic demand forecasting at scale, Proc. VLDB Endow. 10 (12)
(2017) 1694–1705.

[2] P. Courty, H. Li, Timing of seasonal sales, The Journal of Business 72 (4) (1999) 545–572.

[3] B. Lim, A. Alaa, M. van der Schaar, Forecasting treatment responses over time using
recurrent marginal structural networks, in: NeurIPS, 2018.

[4] J. Zhang, K. Nawata, Multi-step prediction for influenza outbreak by an adjusted long
short-term memory, Epidemiology and infection 146 (7) (2018).

[5] C. Capistran, C. Constandse, M. Ramos-Francia, Multi-horizon inflation forecasts using


disaggregated data, Economic Modelling 27 (3) (2010) 666 – 677.

[6] S. S. Rangapuram, et al., Deep state space models for time series forecasting, in: NIPS,
2018.

[7] A. Alaa, M. van der Schaar, Attentive state-space modeling of disease progression, in:
NIPS, 2019.

[8] S. Makridakis, E. Spiliotis, V. Assimakopoulos, The M4 competition: 100,000 time series


and 61 forecasting methods, International Journal of Forecasting 36 (1) (2020) 54 – 74.

[9] D. Salinas, V. Flunkert, J. Gasthaus, T. Januschowski, DeepAR: Probabilistic forecasting


with autoregressive recurrent networks, International Journal of Forecasting (2019).

[10] R. Wen, et al., A multi-horizon quantile recurrent forecaster, in: NIPS 2017 Time Series
Workshop, 2017.

[11] C. Fan, et al., Multi-horizon time series forecasting with temporal attention learning, in:
KDD, 2019.

[12] S. Li, et al., Enhancing the locality and breaking the memory bottleneck of transformer
on time series forecasting, in: NeurIPS, 2019.

[13] J. Koutník, K. Greff, F. Gomez, J. Schmidhuber, A clockwork RNN, in: ICML, 2014.

[14] D. Neil, et al., Phased lstm: Accelerating recurrent network training for long or event-
based sequences, in: NIPS, 2016.

[15] M. Ribeiro, et al., "Why should I trust you?": Explaining the predictions of any classifier,
in: KDD, 2016.

[16] S. Lundberg, S.-I. Lee, A unified approach to interpreting model predictions, in: NIPS,
2017.

[17] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser,


I. Polosukhin, Attention is all you need, in: NIPS, 2017.

[18] S. B. Taieb, A. Sorjamaa, G. Bontempi, Multiple-output modeling for multi-step-ahead


time series forecasting, Neurocomputing 73 (10) (2010) 1950 – 1957.

[19] M. Marcellino, J. Stock, M. Watson, A comparison of direct and iterated multistep ar


methods for forecasting macroeconomic time series, Journal of Econometrics 135 (2006)
499–526.

[20] S. Hochreiter, J. Schmidhuber, Long short-term memory, Neural Computation 9 (8)


(1997) 1735–1780.

24
[21] Y. Wang, et al., Deep factors for forecasting, in: ICML, 2019.

[22] F. Wang, M. Jiang, C. Qian, S. Yang, C. Li, H. Zhang, X. Wang, X. Tang, Residual
attention network for image classification, in: CVPR, 2017.

[23] S. O. Arik, T. Pfister, Tabnet: Attentive interpretable tabular learning (2019). arXiv:
1908.07442.

[24] E. Choi, et al., Retain: An interpretable predictive model for healthcare using reverse
time attention mechanism, in: NIPS, 2016.

[25] H. Song, et al., Attend and diagnose: Clinical time series analysis using attention models,
2018.

[26] J. Yoon, S. O. Arik, T. Pfister, Rl-lim: Reinforcement learning-based locally interpretable


modeling (2019). arXiv:1909.12367.

[27] T. Guo, T. Lin, N. Antulov-Fantulin, Exploring interpretable LSTM neural networks


over multi-variable data, in: ICML, 2019.

[28] D.-A. Clevert, T. Unterthiner, S. Hochreiter, Fast and accurate deep network learning
by exponential linear units (ELUs), in: ICLR, 2016.

[29] J. Lei Ba, J. R. Kiros, G. E. Hinton, Layer Normalization, arXiv:1607.06450 (Jul 2016).
arXiv:1607.06450.

[30] Y. Dauphin, A. Fan, M. Auli, D. Grangier, Language modeling with gated convolutional
networks, in: ICML, 2017.

[31] Y. Gal, Z. Ghahramani, A theoretically grounded application of dropout in recurrent


neural networks, in: NIPS, 2016.

[32] H.-F. Yu, N. Rao, I. S. Dhillon, Temporal regularized matrix factorization for high-
dimensional time series prediction, in: NIPS, 2016.

[33] C. Favorita, Corporacion favorita grocery sales forecasting competition (2018).


URL https://www.kaggle.com/c/favorita-grocery-sales-forecasting/

[34] G. Heber, A. Lunde, N. Shephard, K. K. Sheppard, Oxford-man institute’s realized


library (2009).
URL https://realized.oxford-man.ox.ac.uk/

[35] S. Du, G. Song, L. Han, H. Hong, Temporal causal inference with time lag, Neural
Computation 30 (1) (2018) 271–291.

[36] B. Baltagi, Distributed Lags and Dynamic Models, 2008, pp. 129–145.

[37] S. Hylleberg (Ed.), Modelling Seasonality, Oxford University Press, 1992.

[38] A. Ang, A. Timmermann, Regime changes and financial markets, Annual Review of
Financial Economics 4 (1) (2012) 313–337.

[39] D. Comaniciu, V. Ramesh, P. Meer, Kernel-based object tracking, IEEE Transactions on


Pattern Analysis and Machine Intelligence 25 (5) (2003) 564–577.

[40] T. Kailath, The divergence and bhattacharyya distance measures in signal selection,
IEEE Transactions on Communication Technology 15 (1) (1967) 52–60.

[41] E. Giovanis, The turn-of-the-month-effect: Evidence from periodic generalized autore-


gressive conditional heteroskedasticity (pgarch) model, International Journal of Economic
Sciences and Applied Research 7 (2014) 43–61.

25
APPENDIX

Appendix A. Dataset and Training Details


We provide full details on feature pre-processing and train/test splits below, to ensure the reproducibility of our results.
Electricity: Per [9], we use 500k samples taken between 2014-01-01 and 2014-09-01 – using the first 90% for training, and the last 10% as a validation set. Testing is done over the 7 days immediately following the training set – as described in [9, 32]. Given the large differences in magnitude between trajectories, we also apply z-score normalization separately to each entity for real-valued inputs. In line with previous work, we consider the electricity usage, day-of-week, hour-of-day and a time index – i.e. the number of time steps from the first observation – as real-valued inputs, and treat the entity identifier as a categorical variable.
Traffic: Tests on the Traffic dataset are also kept consistent with previous work, using 500k training samples taken before 2008-06-15 as per [9], split in the same way as the Electricity dataset. For testing, we use the 7 days immediately following the training set, and z-score normalization is applied across all entities. For inputs, we take the traffic occupancy, day-of-week, hour-of-day and a time index as real-valued inputs, and the entity identifier as a categorical variable.
Retail: We treat each product number–store number pair as a separate entity, with over 135k entities in total. The training set is made up of 450k samples taken between 2015-01-01 and 2015-12-01, the validation set of 50k samples from the 30 days after the training set, and the test set of all entities over the 30-day horizon following the validation set. We use all inputs from the Kaggle competition. Data is resampled at regular daily intervals, imputing any missing days using the last available observation, and we include an additional 'open' flag to denote whether data is present on a given day. We group national, regional, and local holidays into separate variables. We apply a log-transform to the sales data, and adopt z-score normalization across all entities. We consider log sales, transactions and oil to be real-valued, and the rest to be categorical.
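The daily resampling and last-observation imputation described above can be sketched as follows – a minimal stdlib illustration rather than the actual pipeline code, with a hypothetical function name and data layout:

```python
from datetime import date, timedelta

def resample_daily(observations):
    """Forward-fill a sparse {date: value} mapping onto a regular daily
    grid, emitting an 'open' flag that marks which days were actually
    observed - mirroring the imputation described for the Retail data."""
    days = sorted(observations)
    grid, current, last_value = [], days[0], observations[days[0]]
    while current <= days[-1]:
        observed = current in observations
        if observed:
            last_value = observations[current]
        grid.append((current, last_value, int(observed)))
        current += timedelta(days=1)
    return grid

# A gap on 2015-01-02 is filled with the last available observation:
grid = resample_daily({date(2015, 1, 1): 5.0, date(2015, 1, 3): 7.0})
# -> [(2015-01-01, 5.0, 1), (2015-01-02, 5.0, 0), (2015-01-03, 7.0, 1)]
```

Keeping the 'open' flag alongside the imputed value lets the model distinguish genuine observations from carried-forward ones.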
Volatility: We use data from 2000-01-03 to 2019-06-28 – with the training set consisting of data before 2016, the validation set of data from 2016–2017, and the test set of data from 2018 onwards. For the target, we focus on 5-min sub-sampled realized volatility (i.e. the rv5_ss column), and add daily open-to-close returns as an extra exogenous input. Additional variables are included for the day-of-week, day-of-month, week-of-year, and month – along with a 'region' variable for each index (i.e. Americas, Europe or Asia). Finally, a time index is added to denote the number of days from the first day in the training set. We treat all date-related variables (i.e. day-of-week, day-of-month, week-of-year, and month) and the region as categorical inputs. A log transformation is applied to the target, and all inputs are z-score normalized across all entities.
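The z-score normalization used throughout – per entity for Electricity, across entities elsewhere – can be sketched as below. This is a simplified stdlib version with a helper name of our own; in practice the statistics would be computed on the training split only:

```python
import math
from collections import defaultdict

def zscore_per_entity(samples):
    """Normalise (entity, value) pairs using each entity's own mean and
    standard deviation, so trajectories of very different magnitudes
    (e.g. individual electricity meters) become comparable."""
    groups = defaultdict(list)
    for entity, value in samples:
        groups[entity].append(value)
    stats = {}
    for entity, values in groups.items():
        mean = sum(values) / len(values)
        std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
        stats[entity] = (mean, std if std > 0 else 1.0)  # guard constant series
    return [(e, (v - stats[e][0]) / stats[e][1]) for e, v in samples]

# Entities 'a' and 'b' differ in scale by 10x but normalise identically:
print(zscore_per_entity([("a", 1.0), ("a", 3.0), ("b", 10.0), ("b", 30.0)]))
# -> [('a', -1.0), ('a', 1.0), ('b', -1.0), ('b', 1.0)]
```

Normalising across all entities instead simply means pooling every value into a single group before computing the mean and standard deviation.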

Appendix B. Interpretability Results


Sec. 7 highlights our most prominent findings; we present the remaining results here for completeness.

Appendix B.1. Variable Importance


Table B.4 shows the variable importance scores for the remaining datasets – Electricity, Traffic and Volatility. As these datasets only have one static input each, the network allocates full weight to the entity identifier for Electricity and Traffic, and to the region input for Volatility. We also observe two types of important time-dependent inputs – those related to past values of the target as before, and those related to calendar effects. For instance, the hour-of-day plays a significant role for the Electricity and Traffic datasets, echoing the daily seasonality observed in the next section. In the Volatility dataset, the day-of-month plays a significant role among future inputs – potentially reflecting turn-of-the-month effects [41].

Table B.4: Variable importance scores for the Electricity, Traffic and Volatility datasets. The most significant variable of each input category is highlighted in purple. As before, past values of the target play a significant role – being the top 1 or 2 most significant past inputs across datasets. The role of seasonality can also be seen in Electricity and Traffic, where the past and future values of the hour-of-day are important for forecasts.

(a) Electricity
                  10%     50%     90%
Static
  ID              1.000   1.000   1.000
Past
  Hour of Day     0.437   0.462   0.473
  Day of Week     0.078   0.099   0.151
  Time Index      0.066   0.077   0.092
  Power Usage     0.342   0.359   0.366
Future
  Hour of Day     0.718   0.738   0.739
  Day of Week     0.109   0.124   0.166
  Time Index      0.114   0.137   0.155

(b) Traffic
                  10%     50%     90%
Static
  ID              1.000   1.000   1.000
Past
  Hour of Day     0.285   0.296   0.300
  Day of Week     0.117   0.122   0.124
  Time Index      0.107   0.109   0.111
  Occupancy       0.471   0.473   0.483
Future
  Hour of Day     0.781   0.781   0.781
  Day of Week     0.099   0.100   0.102
  Time Index      0.117   0.119   0.121

(c) Volatility
                          10%     50%     90%
Static
  Region                  1.000   1.000   1.000
Past
  Time Index              0.093   0.098   0.142
  Day of Week             0.003   0.004   0.004
  Day of Month            0.017   0.027   0.028
  Week of Year            0.022   0.057   0.068
  Month                   0.008   0.009   0.011
  Open-to-close Returns   0.078   0.158   0.178
  Realised Vol            0.620   0.647   0.714
Future
  Time Index              0.011   0.014   0.024
  Day of Week             0.019   0.072   0.299
  Day of Month            0.069   0.635   0.913
  Week of Year            0.026   0.060   0.227
  Month                   0.008   0.055   0.713
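The 10%/50%/90% columns above are percentile summaries of importance weights collected across samples. A stdlib sketch of how such a summary could be computed – the linear-interpolation convention and the weight values are illustrative assumptions, not taken from the paper:

```python
def percentile(values, q):
    """Percentile with linear interpolation between order statistics."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)

def summarise_importance(weights):
    """Collapse per-sample importance weights into the (10%, 50%, 90%)
    triple reported in the variable importance tables."""
    return tuple(round(percentile(weights, q), 3) for q in (10, 50, 90))

print(summarise_importance([0.1, 0.2, 0.3, 0.4, 0.5]))
# -> (0.14, 0.3, 0.46)
```

Reporting three percentiles rather than a single mean exposes how stable a variable's importance is across the sampled forecasts.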

Recurrent Neural Filters: Learning Independent
Bayesian Filtering Steps for Time Series Prediction
Bryan Lim, Stefan Zohren and Stephen Roberts
Oxford-Man Institute of Quantitative Finance
Department of Engineering Science
University of Oxford
Oxford, UK
{blim,zohren,sjrob}@robots.ox.ac.uk

Abstract—Despite the recent popularity of deep generative state space models, few comparisons have been made between network architectures and the inference steps of the Bayesian filtering framework – with most models simultaneously approximating both state transition and update steps with a single recurrent neural network (RNN). In this paper, we introduce the Recurrent Neural Filter (RNF), a novel recurrent autoencoder architecture that learns distinct representations for each Bayesian filtering step, captured by a series of encoders and decoders. Testing this on three real-world time series datasets, we demonstrate that the decoupled representations learnt improve the accuracy of one-step-ahead forecasts while providing realistic uncertainty estimates, and also facilitate multistep prediction through the separation of encoder stages.

Index Terms—recurrent neural networks, Bayesian filtering, variational autoencoders, multistep forecasting

I. INTRODUCTION

Bayesian filtering [1] has been extensively used within the domain of time series prediction, with numerous applications across different fields – including target tracking [2], robotics [3], finance [4], and medicine [5]. Performing inference via a series of prediction and update steps [6], Bayesian filters recursively update the posterior distribution of predictions – or the belief state [7] – with the arrival of new data. For many filter models – such as the Kalman filter [8] and the unscented Kalman filter [9] – deterministic functions are used at each step to adjust the sufficient statistics of the belief state, guided by generative models of the data. Each function quantifies the impact of different sources of information on latent state estimates – specifically time evolution and exogenous inputs in the prediction step, and realised observations in the update step. On top of efficient inference and uncertainty estimation, this decomposition of inference steps enables Bayes filters to be deployed in use cases beyond basic one-step-ahead prediction – with simple extensions for multistep prediction [10] and prediction in the presence of missing observations [11].

With the increasing use of deep neural networks for time series prediction, applications of recurrent variational autoencoder (RVAE) architectures have been investigated for forecasting non-linear state space models [12]–[15]. Learning dynamics directly from data, they avoid the need for explicit model specification – overcoming a key limitation in standard Bayes filters. However, these RVAEs focus predominantly on encapsulating the generative form of the state space model – implicitly condensing both state transition and update steps into a single representation learnt by the RVAE decoder – and make it impossible to decouple the Bayes filter steps.

Recent works in deep generative modelling have focused on the use of neural networks to learn independent factors of variation in static datasets – through the encouragement of disentangled representations [16], [17] or by learning causal mechanisms [18], [19]. While a wide range of training procedures and loss functions have been proposed [20], methods in general use dedicated network components to learn distinct interpretable relationships – ranging from orthogonalising latent representations in variational autoencoders [21] to learning independent modules for different causal pathways [18]. By understanding the relationships encapsulated by each component, we can subsequently decouple them for use in related tasks – allowing the learnt mechanisms to generalise to novel domains [18], [22] or to provide building blocks for transfer learning [21].

In this paper, we introduce the Recurrent Neural Filter (RNF) – a novel recurrent autoencoder architecture which aligns network modules (encoders and decoders) with the inference steps of the Bayes filter – making several contributions over standard approaches. Firstly, we propose a new training procedure to encourage independent representations within each module, by directly training intermediate encoders with a common emission decoder. In doing so, we augment the loss function with additional regularisation terms (see Section V), and directly encourage each encoder to learn functions to update the filter's belief state given available information. Furthermore, to encourage the decoupling of encoder stages, we randomly drop out the input dynamics and error correction encoders during training – which can be viewed as artificially introducing missingness to the inputs and observations respectively. Finally, we highlight performance gains for one-step-ahead predictions through experiments on 3 real-world time series datasets, and investigate multistep predictions as a use case for generalising the RNF's decoupled representations to other tasks – demonstrating performance improvements from the recursive application of the state transition encoders alone.
II. RELATED WORK

RVAEs for State Space Modelling: The work of [12] identifies close parallels between RNNs and latent state space models, both consisting of an internal hidden state that drives output forecasts and observations. Using an RVAE architecture described as a variational RNN (VRNN), they build their recognition network (encoder) with RNNs and produce samples for the stochastic hidden state at each time point. Deep Kalman filters (DKFs) [13], [14] take this a step further by allowing for exogenous inputs in their network and incorporating a special KL loss term to penalise state transitions between time steps. Deep Variational Bayes Filters (DVBFs) [15] enhance the interpretability of DKFs by modelling state transitions with parametric – e.g. linear – models, which take in stochastic samples from the recognition model as inputs. In general, while the above models capture the generative modelling aspects of the state space framework, their inference procedure blends both state transition and error correction steps, obliging the recognition model to learn representations for both simultaneously. In contrast, the RNF uses separate neural network components to directly model the Bayes filter steps – leading to improvements in representation learning and enhanced predictive performance in time series applications.

Hybrid Approaches: In [23], the authors take a hybrid approach with the structured variational autoencoder (SVAE), proposing an efficient general inference framework that combines probabilistic graphical models for the latent state with neural network observation models. This is similar in spirit to the Kernel Kalman Filter [24], allowing for predictions to be made on complex observational datasets – such as raw images – by encoding high-dimensional outputs onto a lower-dimensional latent representation modelled with a dynamical systems model. Although SVAEs provide a degree of interpretability to temporal dynamics, they also require a parametric model to be defined for the latent states, which may be challenging for arbitrary time series datasets. The RNF, in comparison, can learn the relationships directly from data, without the need for explicit model specification. The Kalman variational autoencoder (KVAE) [25] extends ideas from the SVAE, modelling the latent state using a linear Gaussian state space model (LGSSM). To allow for non-linear dynamics, the KVAE uses a recognition model to produce time-varying parameters for the LGSSM, weighting a set of K constant parameters using weights generated by a neural network. Deep State Space Models (DSSM) [26] investigate a similar approach within the context of time series prediction, using an RNN to generate parameters of the LGSSM at each time step. While the LGSSM components do allow for the application of the Kalman filter, we note that updates to the time-varying weights from the RNN once again blend the prediction and update steps – making the separation of Bayes filter steps and generalisation to other tasks non-trivial. On the other hand, the RNF naturally supports simple extensions (e.g. multistep prediction) similarly to other Bayes filters – due to the close alignment of the RNF architecture with the Bayes filter steps and the use of decoupled representations across encoders and decoders.

Autoregressive Architectures: An alternative approach to deep generative modelling focuses on the autoregressive factorisation of the joint distribution of observations (i.e. p(y1:T) = ∏t p(yt | y1:t−1)), directly generating the conditional distribution at each step. For instance, WaveNet [27] and Transformer [28], [29] networks use dilated CNNs and attention-based models to build predictive distributions. While successful in speech generation and language applications, these models suffer from several limitations in the context of time series prediction. Firstly, the CNN and attention models require the pre-specification of the amount of relevant history to use in predictions – with the size of the look-back window controlled by the length of the receptive field or extended context – which may be difficult when the data generating process is unknown. Furthermore, they also rely on a discretisation of the output, generating probabilities of occurrence within each discrete interval using a softmax layer. This can create generalisation issues for time series where outputs are unbounded. In contrast, the LSTM cells used in the RNF recognition model remove the need to define a look-back window, and the parametric distributions used for outputs are compatible with unbounded continuous observations.

In other works, the use of RNNs in autoregressive architectures for time series prediction has been explored in DeepAR models [30], where LSTM networks output the Gaussian mean and standard deviation parameters of predictive distributions at each step. We include this as a benchmark in our tests, noting the improvements observed with the RNF through its alignment with the Bayesian filtering paradigm.

Predictive State Representations: Predictive state RNNs (PSRNN) [31]–[33] use an alternative formulation of the Bayes filter, utilising a state representation that corresponds to the statistics of the predictive distribution of future observations. Predictions are made using a two-stage regression approach modelled by their proposed architectures. Compared to alternative approaches, PSRNNs only produce point estimates for their forecasts – lacking the uncertainty bounds from the predictive distributions produced by the RNF.

Non-Parametric State Space Models: Gaussian Process state space models (GP-SSMs) [34], [35] and variational approximations [36] provide an alternative non-parametric approach to forecasting non-linear state space models – modelling hidden states and observation dynamics using GPs. While they have similar benefits to Bayes filters (i.e. predictive uncertainties, natural multistep prediction etc.), inference at each time step has at least an O(T) complexity in the number of past observations – either via sparse GP approximations or Kalman filter formulations [37]. In contrast, the RNF updates its belief state at each time point only with the latest observations and input, making it suitable for real-time prediction on high-frequency datasets.

RNNs for Multistep Prediction: Customised sequence-to-sequence architectures have been explored in [38], [39] for multistep time series prediction, typically predefining the forecast horizon, and using computationally expensive
customised training procedures to improve performance. In contrast, the RNF does not require the use of a separate training procedure for multistep predictions – hence reducing the computational overhead – and does not require the specification of a fixed forecast horizon.

III. PROBLEM DEFINITION

Let yt = [yt(1), ..., yt(O)]^T be a vector of observations, driven by a set of stochastic hidden states xt = [xt(1), ..., xt(J)]^T and exogenous inputs ut = [ut(1), ..., ut(I)]^T. We consider non-linear state space models of the following form:

    yt ∼ Π( f(xt) )                                  (1)
    xt ∼ N( µ(xt−1, ut), Σ(xt−1, ut) )               (2)

where Π is an arbitrary distribution parametrised by a non-linear function f(xt), with µ(·) and Σ(·) being mean and covariance functions respectively.

Bayes filters allow for efficient inference through the use of a belief state, i.e. a posterior distribution of hidden states given past observations y1:t = {y1, ..., yt} and inputs u1:t = {u1, ..., ut}. This is achieved through the maintenance of a set of sufficient statistics θt – e.g. means and covariances θt ∈ {µt, Σt} – which compactly summarise the historical data:

    p(xt | y1:t, u1:t) = bel(xt; θt)                 (3)
                       = N(xt; µt, Σt)               (4)

where bel(·) is a probability distribution function for the belief state.

For filters such as the Kalman filter – and non-linear variants like the unscented Kalman filter [9] – θt is recursively updated through a series of prediction and update steps which take the general form:

    Prediction (State Transition):  θ̃t = φu(θt−1, ut)   (5)
    Update (Error Correction):      θt = φy(θ̃t, yt)     (6)

where φu(·) and φy(·) are non-linear deterministic functions. Forecasts can then be computed using one-step-ahead predictive distributions:

    p(yt | y1:t−1, u1:t) = ∫ p(yt | xt) bel(xt; θ̃t) dxt.   (7)

In certain cases – e.g. with the Kalman filter – the predictive distribution can also be directly parameterised using analytical functions g(·) for the belief state statistics:

    p(yt | y1:t−1, u1:t) = p( yt | g(θ̃t) ).         (8)

When observations are continuous, such as in standard linear Gaussian state space models, yt can be modelled using a Normal distribution – i.e. yt ∼ N( gµ(θ̃t), gΣ(θ̃t) ).

IV. RECURRENT NEURAL FILTER

Recurrent Neural Filters use a series of encoders and decoders to learn independent representations for the Bayesian filtering steps. We investigate two RNF variants as described below, based on Equations (7) and (8) respectively.

Variational Autoencoder Form (VRNF): Firstly, we capture the belief state of Equation (4) using a recurrent VAE-based architecture. At run time, samples of xt are generated from the encoder – approximating the integral of Equation (7) to compute the predictive distribution of yt.

Standard Autoencoder Form (RNF)¹: Much recent work has demonstrated the sensitivity of VAE performance to the choice of prior distribution, with suboptimal priors either having an "over-regularising" effect on the loss function during training [40]–[42], or leading to posterior collapse [43]. As such, we also implement an autoregressive version of the RNF based on Equation (8) – directly feeding encoder latent states into the common emission decoder.

A general architecture diagram for both forms is shown in Figure 1, with the main differences encapsulated within z(s) (see Section IV-A).

A. Network Architecture

First, let st be a latent state that maps to the sufficient statistics θt, which are obtained as outputs from our recognition model. Per Equations (5) and (6), inference at run-time is controlled through the recursive update of st, using a series of Long Short-Term Memory (LSTM) [44] encoders with exponential linear unit (ELU) activations [45].

Encoder: To directly estimate the impact of exogenous inputs on the belief state, the prediction step, Equation (5), is divided into two parts with separate LSTM units φx(·) and φu(·). We use ht to represent all required memory components – i.e. both the output vector and cell state for the standard LSTM – with st being the output of the cell. A third LSTM cell φy(·) is then used for the update step, Equation (6), with the full set of equations below.

    Prediction:
      Propagation:      [s̃t′, h̃t′] = φx(ht−1)       (9)
      Input Dynamics:   [s̃t, h̃t] = φu(h̃t′, ut)     (10)
    Update:
      Error Correction: [st, ht] = φy(h̃t, yt)        (11)

For the variational RNF, the hidden state variable xt is modelled as multivariate Gaussian, given by:

    xt ∼ N( m(s̃t), V(s̃t) )                          (12)
    m(s̃t) = Wm s̃t + bm                              (13)
    V(s̃t) = diag( σ(s̃t) ⊙ σ(s̃t) )                  (14)
    σ(s̃t) = Softplus(Wσ s̃t + bσ),                   (15)

¹ An open-source implementation of the standard RNF can be found at: https://github.com/sjblim/rnf-ijcnn-2020
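To make the prediction/update recursion of Equations (5)–(8) concrete, the scalar linear-Gaussian case admits closed forms – the Kalman filter – where the belief state statistics are just a mean and a variance. A minimal illustrative sketch, with arbitrary parameter values a, b, q, r:

```python
def predict(mean, var, u, a=0.9, b=1.0, q=0.1):
    """Prediction / state transition step, cf. Eq. (5): propagate the
    belief through x_t = a*x_{t-1} + b*u_t + w_t, with w_t ~ N(0, q)."""
    return a * mean + b * u, a * a * var + q

def update(mean, var, y, r=0.5):
    """Update / error correction step, cf. Eq. (6): fold in the
    observation y_t = x_t + v_t, v_t ~ N(0, r), via the Kalman gain."""
    gain = var / (var + r)
    return mean + gain * (y - mean), (1.0 - gain) * var

# One filtering cycle. Before the update, the one-step-ahead predictive
# distribution of y_t (cf. Eq. (8)) is N(prior_mean, prior_var + r).
mean, var = 0.0, 1.0                   # initial belief state
mean, var = predict(mean, var, u=1.0)  # prior belief after the transition
prior_var = var
mean, var = update(mean, var, y=1.2)   # posterior after observing y_t
assert var < prior_var                 # observations tighten the belief
```

In the RNF, φx(·)/φu(·) and φy(·) play these same two roles, but with LSTM updates learnt from data rather than fixed linear-Gaussian algebra.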
Fig. 1. RNF Network Architecture

where W(·), b(·) are the weights/biases of each layer, and ⊙ is an element-wise (Hadamard) product.

For the standard RNF, the encoder state s̃t is directly fed into the emission decoder, leading to the following forms for z̃t = z(s̃t):

    zVRNF(s̃t) = xt,   zRNF(s̃t) = s̃t.               (16)

While the connection to non-linear state-space models facilitates our interpretation of zRNF(s̃t), we note that the standard RNF no longer relies on an explicit generative model for the latent state xt. This potentially allows the standard RNF to learn more complex update rules for non-Gaussian latent states.

Decoder: Given an encoder output z̃t, we use a multilayer perceptron to model the emission function f(·):

    f(z̃t) = Wz2 ELU(Wz1 z̃t + bz1) + bz2.            (17)

This allows us to handle both continuous and binary observations using the output models below:

    yt (continuous) ∼ N( fµ(z̃t), Γ(z̃t) ),          (18)
    yt (binary) ∼ Bernoulli( Sigmoid(f(z̃t)) ),      (19)

where Γ(z̃t) = diag(gσ(z̃t)) is a time-dependent diagonal covariance matrix, and gσ(z̃t) = Softplus(fσ(z̃t)). For continuous yt, the weights Wz1, bz1 are shared between fµ(·) and fσ(·) – i.e. both observation means and covariances are generated from the same encoder hidden layer.

B. Handling Missing Data and Multistep Prediction

From the above, we can see that each encoder learns how specific inputs (i.e. time evolution, exogenous inputs and the target) modify the belief state. As such, in a similar fashion to Bayes filters, we decouple the RNF stages at run-time based on the availability of inputs for prediction – allowing it to handle applications involving missing data or multistep forecasting.

Fig. 2. RNF Configuration with Missing Data

Figure 2 demonstrates how the RNF stages can be combined to accommodate missing data, noting that the colour scheme of the encoders/decoders shown matches that of Figure 1. From the schematic, the propagation encoder – which is responsible for changes to the belief state due to time evolution – is always applied, with the input dynamics and error correction encoders only used when inputs or observations are observed respectively. Where inputs are available, the emission decoder is applied to the input dynamics encoder to generate predictions at each step. Failing that, the decoder is applied to the propagation encoder alone. Multistep forecasts can also be treated as predictions in the absence of inputs or observations, with the encoders used to project the belief state in a similar fashion to missing data.

V. TRAINING METHODOLOGY

Considering the joint probability for a trajectory of length T, we train the standard RNF by minimising the negative
log-likelihood of the observations. For continuous observations, this involves Gaussian likelihoods from Equation (18):

    LRNF(ω, s̃1:T) = − Σ_{t=1}^{T} log p(yt | s̃t),   (20)

    log p(yt | s̃t) = − (1/2) Σ_{j=1}^{J} [ log(2π gσ(j, s̃t)²)
                        + ( (yt(j) − fµ(j, s̃t)) / gσ(j, s̃t) )² ],   (21)

where ω are the weights of the deep neural network, fµ(j, z̃t) is the j-th element of fµ(z̃t), and gσ(j, z̃t) the j-th element of gσ(z̃t).

For the VRNF, we adopt the Stochastic Gradient Variational Bayes (SGVB) estimator of [46] for our VAE evidence lower bound, expressing our loss function as:

    LVRNF(ω, s̃1:T) = Σ_{t=1}^{T} [ (1/L) Σ_{i=1}^{L} log p(yt | xt^(i)(s̃t)) ]
                       − KL( q(x1:T) || p(x1:T) ),   (22)

where L is the number of samples used for calibration, xk^(i)(s̃k) is the i-th sample given the latent state s̃k, and KL(·||·) is the KL divergence term defined based on the priors in Section V-A.

A. VAE Priors for VRNF

Using the generative model for xt in Equation (2), we consider the definition of two priors for the VRNF, as described briefly below. A full definition can be found in Appendix A², which also includes derivations for the KL term used in LVRNF(ω, s̃1:T).

Kalman Filter Prior (VRNF-KF): Considering a linear Gaussian state space form for Equations (1) and (2), we can apply the Kalman filtering equations to obtain distributions for xt at each time step (e.g. p(xt | y1:t, u1:t)). This also lets us analytically define how the means and covariances of the belief state change with different sets of information – aligning the VRNF's encoder stages with the filtering equations.

Neural Network Prior (VRNF-NN): In the spirit of the DKF [13], the analytical equations from the Kalman filter prior above can also be approximated using simple multilayer perceptrons. This would also allow belief state updates to accommodate non-linear state space dynamics, making it a far less restrictive prior model.

B. Encouraging Decoupled Representations

Combined Encoder Training: To improve representation learning, the RNF is trained in a "multi-task" fashion – with each intermediate stage trained to encode latent states for output distributions. This is achieved by applying the same emissions decoder to all encoders during training as indicated in Figure 1, with each encoder/decoder aligned with the Bayesian filtering steps described in Section IV-A. Encoders are then trained jointly using the combined loss function below:

    Lcombined(ω, y1:T, u1:T) = L(ω, s̃1:T) + αx L(ω, s̃′1:T) + αy L(ω, s1:T),   (23)

with the three terms corresponding to the input dynamics, propagation and error correction encoders respectively. As such, the additional stages can be interpreted as regularisation terms for the VRNF or RNF loss functions – which we weight by constants αx and αy to control the relative importance of the intermediate encoder representations. For our main experiments, we place equal importance on all encoders, i.e. αx = αy = 1, to facilitate the subsequent separation of stages for multistep prediction – with a full ablation analysis performed to assess the impact of various α settings during training.

Furthermore, the error correction component φy(·) can also be interpreted as a pure auto-encoding step for the latest observation, recovering distributions p(yt | xt) based on filtered distributions of p(xt | y1:t, u1:t). Given that all stages share the same emissions decoder, this obliges the network to learn representations for st that are able to reconstruct the current observation when it is available.

Introducing Artificial Missingness: Next, to encourage the clean separation of encoder stages for generalisation to other tasks, we break dependencies between the encoders by introducing artificial missingness into the dataset – randomly dropping out inputs and observations with a missingness rate r. As encoders are only applied where data is present (see Figure 2), the input dynamics and error correction encoders are hence randomly skipped over during training – encouraging each encoder to perform regardless of which encoder stage preceded it. This also bears a resemblance to input dropout during training, which we apply to competing benchmarks to ensure comparability.

VI. PERFORMANCE EVALUATION

A. Time Series Datasets

We conduct a series of tests on 3 real-world time series datasets to evaluate performance:
1) Electricity: The public UCI Individual Household Electric Power Consumption Data [47]
2) Volatility: A 30-min realised variance [48] dataset for 30 different stock indices
3) Quote: A high-frequency market microstructure dataset containing Barclays Level-1 quote data from Thomson Reuters Tick History (TRTH)
Details on input/output features and preprocessing are fully documented in Appendix C for reference.

B. Conduct of Experiment

Benchmarks: We compare the VRNF-KF, VRNF-NN and standard RNF against a range of autoregressive and RVAE benchmarks – including the DeepAR Model [30], Deep State

² URL for full paper with appendix: https://arxiv.org/abs/1901.08096
TABLE I
NORMALISED MSES FOR ONE-STEP-AHEAD PREDICTIONS

              DeepAR   DSSM    VRNN    DKF     VRNF-KF  VRNF-NN  RNF
Electricity   0.908    1.000   2.002   0.867   0.861    0.852    0.780*
Volatility    3.956    1.000   0.991   0.982   1.914    1.284    0.976*
Quote         0.998    1.000   3.733   1.001   1.000    1.001    0.997*

TABLE II
COVERAGE PROBABILITY OF ONE-STEP-AHEAD 90% PREDICTION INTERVAL

              DeepAR   DSSM    VRNN    DKF     VRNF-KF  VRNF-NN  RNF
Electricity   0.966    0.964   0.981   0.965   0.320    0.271    0.961*
Volatility    0.997*   0.999   1.000   1.000   1.000    1.000    1.000
Quote         0.997    0.991   0.005   0.998   0.924*   0.992    0.997
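Table II above reports the prediction interval coverage probability (PICP) defined in Equation (24). As a minimal sketch of how this metric can be computed from Monte Carlo forecast samples (the data and function below are illustrative, not the paper's code):

```python
import numpy as np

def picp(samples, y, lower=0.05, upper=0.95):
    """Coverage of the sampled 90% prediction interval, per Equations (24)-(25).

    samples: (L, T) array of Monte Carlo draws, one column per time step.
    y:       (T,) array of realised observations.
    """
    lo = np.quantile(samples, lower, axis=0)  # psi(0.05, t): per-step 5th percentile
    hi = np.quantile(samples, upper, axis=0)  # psi(0.95, t): per-step 95th percentile
    return np.mean((y > lo) & (y < hi))       # fraction of time steps with c_t = 1

# Illustrative check: well-calibrated N(0, 1) forecasts of N(0, 1) data
rng = np.random.default_rng(0)
y = rng.standard_normal(1000)
samples = rng.standard_normal((200, 1000))
print(picp(samples, y))  # close to the nominal 0.90 coverage
```

A PICP well below the nominal 0.90 level (as for the VRNF variants on the Electricity data) indicates overly narrow, miscalibrated intervals.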
TABLE III
NORMALISED MSES FOR MULTISTEP PREDICTIONS WITH BOTH UNKNOWN AND KNOWN INPUTS

Input Type      Dataset      τ    DeepAR  DSSM   VRNN   DKF    VRNF-KF  VRNF-NN  RNF
Unknown Inputs  Electricity  5    3.260   3.308  3.080  2.946  2.607    2.015    1.996*
                             10   4.559   4.771  4.533  4.419  5.467    3.506*   3.587
                             20   6.555   6.827  6.620  6.524  9.817    5.449*   6.098
                Volatility   5    3.945   1.628  0.994  0.986  4.084    1.020    0.967*
                             10   3.960   1.639  0.994  0.985  4.140    1.017    0.967*
                             20   3.955   1.641  0.993  0.983  4.163    1.014    0.966*
                Quote        5    1.000   1.000  1.001  1.000  1.002    0.999    0.998*
                             10   1.000   1.001  1.000  1.001  1.009    1.002    1.000*
                             20   1.000   1.001  1.003  1.001  1.488    1.003    1.000*
Known Inputs    Electricity  5    3.260   3.199  3.045  1.073  1.112    0.877    0.813*
                             10   4.559   4.382  4.470  1.008  1.180    0.882    0.831*
                             20   6.555   6.174  6.514  0.989  1.209    0.884    0.846*
                Volatility   5    3.988   1.615  0.994  0.986  2.645    1.009    0.981*
                             10   3.992   1.620  0.994  0.985  2.652    1.009    0.981*
                             20   3.991   1.627  0.993  0.984  2.652    1.008    0.980*
                Quote        5    1.000   1.000  0.998  1.000  1.001    1.000    0.997*
                             10   1.000   1.000  0.999  1.000  1.003    1.000    0.998*
                             20   1.000   1.000  1.003  1.000  1.009    1.000    0.999*
Space Model (DSSM) [26], Variational RNN (VRNN) [12], and Deep Kalman Filter (DKF) [13].

For multistep prediction, we consider two potential use cases for exogenous inputs: (i) when future inputs are unknown beforehand and imputed using their last observed values, and (ii) when inputs are known in advance and used as given. When models require observations of yt as inputs, we recursively feed outputs from the network as inputs at the next time step. These tweaks allow the benchmarks to be used for multistep prediction without modifying network architectures. For the RNF, we consider the application of the propagation encoder alone for the former case, and a combination of the propagation and input dynamics encoders for the latter – as detailed in Section IV-B.

Metrics: To determine the accuracy of forecasts, we evaluate the mean-squared-error (MSE) for single-step and multistep predictions, normalising each using the MSE of the one-step-ahead forecast for the best autoregressive model (i.e. the DSSM). For multistep forecasts, we measure the average squared error up to the maximum prediction horizon (τ). As observations are 1D continuous variables for all our datasets, we evaluate uncertainty estimates using the prediction interval coverage probability (PICP) of a 90% prediction interval, defined as:

PICP = (1/T) Σ_{t=1}^{T} c_t,    (24)

c_t = 1 if ψ(0.05, t) < yt < ψ(0.95, t), and 0 otherwise,    (25)

where ψ(0.05, t) is the 5th percentile of samples from N(f(xt), Γ).

Training Details: Please refer to Appendix B for full details of network calibration.

C. Results and Discussion

On the whole, the standard RNF demonstrates the best overall performance – improving MSEs in general for one-step-ahead
TABLE IV
NORMALISED MSES FOR ABLATION STUDIES

                         Electricity                     Volatility                      Quote
τ =                      1       5       10      20      1       5       10      20      1       5       10      20
Unknown Inputs  RNF      -       1.996*  3.587*  6.098*  -       0.967*  0.967*  0.966*  -       0.998*  1.000*  1.000*
                RNF-NS   -       2.801   13.409  45.625  -       1.006   1.006   1.005   -       1.137   1.294   1.260
                RNF-IO   -       14.047  14.803  15.414  -       1.377   1.458   1.494   -       1.029   1.042   1.045
Known Inputs    RNF      0.780   0.813   0.831*  0.846*  0.976*  0.981*  0.981*  0.980*  0.997*  0.997*  0.998*  0.999*
                RNF-NS   0.828   0.948   0.997   1.042   0.979   0.983   0.983   0.982   1.001   1.003   1.019   1.026
                RNF-IO   0.770*  0.809*  0.873   0.918   1.012   1.016   1.016   1.015   1.020   1.015   1.023   1.030
and multistep prediction. From the one-step-ahead MSEs in Table I, the RNF improves forecasting accuracy by 19.6% on average across all datasets and benchmarks. These results are also echoed for multistep predictions in Table III, with the RNF beating the majority of baselines for all horizons and datasets. The only exception is the slight out-performance of another RNF variant (the VRNF-NN) on the Electricity dataset with unknown inputs – possibly due to the adoption of a suitable prior for this specific dataset – with the standard RNF coming in a close second. The PICP results of Table II also show that this performance is achieved without sacrificing the quality of uncertainty estimates, with the RNF outputting similar uncertainty intervals compared to other deep generative and autoregressive models. On the whole, this demonstrates the benefits of the proposed training approach for the RNF, which encourages decoupled representations using regularisation terms and skip training.

To measure the benefits of the skip-training approach and proposed regularisation terms, we also perform a simple ablation study and train the RNF without the proposed components. Table IV shows the normalised MSEs for the ablation studies, with one-step and multistep forecasts combined into the same table. Specifically, we test the RNF with no skip training in RNF-NS, and the RNF with only input dynamics outputs included in the loss function (i.e. αx = αy = 0) in RNF-IO. As inputs are always known for one-step-ahead predictions, normalised MSEs for τ = 1 are omitted for unknown inputs. In general, the inclusion of both skip training and regularisation terms improves forecasting performance, particularly in the case of longer-horizon predictions. We observe this from the MSE improvements for all but short-term (τ ∈ {1, 5}) predictions with known inputs, where the RNF-IO performs marginally better on the Electricity dataset. However, the importance of both skip training and regularisation can be seen from the large multistep MSEs of both the RNF-NS and RNF-IO on the Electricity dataset with unknown inputs – which result from error propagation when the input dynamics encoder is removed.

As mentioned in Section IV, the challenges of prior selection for VAE-based methods can be seen from the PICPs in Table II – with small PICPs for VRNF models indicative of miscalibrated distributions in the Electricity data, and the poor MSEs and PICPs for the VRNN indicative of posterior collapse on the Quote data. However, this can also be beneficial when applied to appropriate datasets – as seen from the closeness of the VRNF-KF's PICP to the expected 90% on the Quote data. As such, the autoregressive form of the standard RNF leads to more reliable performance from both a prediction accuracy and uncertainty perspective – doing away with the need to define a prior for xt.

VII. CONCLUSIONS

In this paper, we introduce a novel recurrent autoencoder architecture, which we call the Recurrent Neural Filter (RNF), to learn decoupled representations for the Bayesian filtering steps – consisting of separate encoders for state propagation, input and error correction dynamics, and a common decoder to model emission. Based on experiments with three real-world time series datasets, the direct benefits of the architecture can be seen from the improvements in one-step-ahead predictive performance, while maintaining comparable uncertainty estimates to benchmarks. Due to its modular structure and close alignment with Bayesian filtering steps, we also show the potential to generalise the RNF to similar predictive tasks – as seen from improvements in multistep prediction using extracted state transition encoders.

REFERENCES

[1] J. V. Candy, Bayesian Signal Processing: Classical, Modern and Particle Filtering Methods. New York, NY, USA: Wiley-Interscience, 2009.
[2] A. J. Haug, Bayesian Estimation and Tracking: A Practical Guide. Hoboken, NJ: John Wiley & Sons, 2012.
[3] T. D. Barfoot, State Estimation for Robotics. New York, NY, USA: Cambridge University Press, 2017.
[4] H. Ghosh, B. Gurung, and Prajneshu, "Kalman filter-based modelling and forecasting of stochastic volatility with threshold," Journal of Applied Statistics, vol. 42, no. 3, pp. 492–507, 2015.
[5] R. Sukkar, E. Katz, Y. Zhang, D. Raunig, and B. T. Wyman, "Disease progression modeling using hidden Markov models," in 2012 Annual International Conference of the IEEE Engineering in Medicine and Biology Society, Aug 2012, pp. 2845–2848.
[6] S. Sarkka, Bayesian Filtering and Smoothing. New York, NY, USA: Cambridge University Press, 2013.
[7] S. Thrun, W. Burgard, and D. Fox, Probabilistic Robotics (Intelligent Robotics and Autonomous Agents). The MIT Press, 2005.
[8] R. E. Kalman, "A new approach to linear filtering and prediction problems," Transactions of the ASME–Journal of Basic Engineering, vol. 82, no. Series D, pp. 35–45, 1960.
[9] S. J. Julier and J. K. Uhlmann, "New extension of the Kalman filter to nonlinear systems," in Proc. SPIE, vol. 3068, 1997.
[10] A. Harvey, Forecasting, Structural Time Series Models and the Kalman Filter. Cambridge University Press, 1991.
[11] A. C. Harvey and R. G. Pierse, "Estimating missing observations in economic time series," Journal of the American Statistical Association, vol. 79, no. 385, pp. 125–131, 1984.
[12] J. Chung, K. Kastner, L. Dinh, K. Goel, A. C. Courville, and Y. Bengio, "A recurrent latent variable model for sequential data," in Advances in Neural Information Processing Systems 28 (NIPS 2015), 2015.
[13] R. G. Krishnan, U. Shalit, and D. Sontag, "Deep Kalman Filters," ArXiv e-prints, 2015. [Online]. Available: https://arxiv.org/abs/1511.05121
[14] R. G. Krishnan, U. Shalit, and D. Sontag, "Structured inference networks for nonlinear state space models," in Proceedings of the Thirty-First AAAI Conference on Artificial Intelligence (AAAI 2017), 2017.
[15] M. Karl, M. Soelch, J. Bayer, and P. van der Smagt, "Deep variational Bayes filters: unsupervised learning of state space models from raw data," in International Conference on Learning Representations (ICLR 2017), 2017.
[16] S. Narayanaswamy, T. B. Paige, J.-W. van de Meent, A. Desmaison, N. Goodman, P. Kohli, F. Wood, and P. Torr, "Learning disentangled representations with semi-supervised deep generative models," in Advances in Neural Information Processing Systems 30 (NIPS 2017), 2017, pp. 5925–5935.
[17] H. Kim and A. Mnih, "Disentangling by factorising," in Proceedings of the 35th International Conference on Machine Learning (ICML 2018), 2018.
[18] G. Parascandolo, N. Kilbertus, M. Rojas-Carulla, and B. Schölkopf, "Learning independent causal mechanisms," in Proceedings of the 35th International Conference on Machine Learning (ICML 2018), 2018.
[19] V. Thomas, E. Bengio, W. Fedus, J. Pondard, P. Beaudoin, H. Larochelle, J. Pineau, D. Precup, and Y. Bengio, "Disentangling the independently controllable factors of variation by interacting with the world," in NIPS 2017 Workshop on Learning Disentangled Representations, 2018.
[20] F. Locatello, S. Bauer, M. Lucic, S. Gelly, B. Schölkopf, and O. Bachem, "Challenging common assumptions in the unsupervised learning of disentangled representations," CoRR, vol. abs/1811.12359, 2018. [Online]. Available: https://arxiv.org/abs/1811.12359
[21] I. Higgins, L. Matthey, A. Pal, C. Burgess, X. Glorot, M. Botvinick, S. Mohamed, and A. Lerchner, "beta-VAE: Learning basic visual concepts with a constrained variational framework," in International Conference on Learning Representations (ICLR 2017), 2017.
[22] B. M. Lake, T. D. Ullman, J. B. Tenenbaum, and S. J. Gershman, "Building machines that learn and think like people," Behavioral and Brain Sciences, vol. 40, p. e253, 2017.
[23] M. Johnson, D. K. Duvenaud, A. Wiltschko, R. P. Adams, and S. R. Datta, "Composing graphical models with neural networks for structured representations and fast inference," in Advances in Neural Information Processing Systems 29 (NIPS 2016), 2016.
[24] L. Ralaivola and F. d'Alche Buc, "Time series filtering, smoothing and learning using the kernel Kalman filter," in Proceedings, 2005 IEEE International Joint Conference on Neural Networks, vol. 3, July 2005, pp. 1449–1454.
[25] M. Fraccaro, S. Kamronn, U. Paquet, and O. Winther, "A disentangled recognition and nonlinear dynamics model for unsupervised learning," in Advances in Neural Information Processing Systems 30 (NIPS 2017), 2017.
[26] S. S. Rangapuram, M. W. Seeger, J. Gasthaus, L. Stella, Y. Wang, and T. Januschowski, "Deep state space models for time series forecasting," in Advances in Neural Information Processing Systems 31 (NeurIPS 2018), 2018.
[27] A. van den Oord, S. Dieleman, H. Zen, K. Simonyan, O. Vinyals, A. Graves, N. Kalchbrenner, A. W. Senior, and K. Kavukcuoglu, "WaveNet: A generative model for raw audio," CoRR, vol. abs/1609.03499, 2016.
[28] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin, "Attention is all you need," in Advances in Neural Information Processing Systems 30 (NIPS 2017), 2017.
[29] Z. Dai, Z. Yang, Y. Yang, J. G. Carbonell, Q. V. Le, and R. Salakhutdinov, "Transformer-XL: Attentive language models beyond a fixed-length context," CoRR, vol. abs/1901.02860, 2019. [Online]. Available: https://arxiv.org/abs/1901.02860
[30] V. Flunkert, D. Salinas, and J. Gasthaus, "DeepAR: Probabilistic forecasting with autoregressive recurrent networks," CoRR, vol. abs/1704.04110, 2017. [Online]. Available: https://arxiv.org/abs/1704.04110
[31] K. Choromanski, C. Downey, and B. Boots, "Initialization matters: Orthogonal predictive state recurrent neural networks," in International Conference on Learning Representations (ICLR 2018), 2018.
[32] C. Downey, A. Hefny, B. Boots, G. J. Gordon, and B. Li, "Predictive state recurrent neural networks," in Advances in Neural Information Processing Systems 30 (NIPS 2017), 2017.
[33] A. Venkatraman, N. Rhinehart, W. Sun, L. Pinto, M. Hebert, B. Boots, K. Kitani, and J. Bagnell, "Predictive-state decoders: Encoding the future into recurrent networks," in Advances in Neural Information Processing Systems 30 (NIPS 2017), 2017.
[34] R. Turner, M. Deisenroth, and C. Rasmussen, "State-space inference and learning with Gaussian processes," in Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics (AISTATS 2010), 2010, pp. 868–875.
[35] H. Nickisch, A. Solin, and A. Grigorevskiy, "State space Gaussian processes with non-Gaussian likelihood," in Proceedings of the 35th International Conference on Machine Learning (ICML 2018), 2018, pp. 3789–3798.
[36] A. Doerr, C. Daniel, M. Schiegg, N.-T. Duy, S. Schaal, M. Toussaint, and T. Sebastian, "Probabilistic recurrent state-space models," in Proceedings of the 35th International Conference on Machine Learning (ICML 2018), 2018.
[37] S. Sarkka, A. Solin, and J. Hartikainen, "Spatiotemporal learning via infinite-dimensional Bayesian filtering and smoothing: A look at Gaussian process regression through Kalman filtering," IEEE Signal Processing Magazine, vol. 30, no. 4, pp. 51–61, July 2013.
[38] B. Pérez Orozco, G. Abbati, and S. Roberts, "MOrdReD: Memory-based Ordinal Regression Deep Neural Networks for Time Series Forecasting," CoRR, vol. abs/1803.09704, 2018. [Online]. Available: https://arxiv.org/abs/1803.09704
[39] R. Wen, K. Torkkola, B. Narayanaswamy, and D. Madeka, "A multi-horizon quantile recurrent forecaster," in NIPS 2017 Time Series Workshop, 2017.
[40] H. Takahashi, T. Iwata, Y. Yamanaka, M. Yamada, and S. Yagi, "Variational autoencoder with implicit optimal priors," CoRR, vol. abs/1809.05284, 2018. [Online]. Available: https://arxiv.org/abs/1809.05284
[41] J. M. Tomczak and M. Welling, "VAE with a VampPrior," in Proceedings of the 21st International Conference on Artificial Intelligence and Statistics (AISTATS), 2018.
[42] S. R. Bowman, L. Vilnis, O. Vinyals, A. M. Dai, R. Józefowicz, and S. Bengio, "Generating sentences from a continuous space," CoRR, vol. abs/1511.06349, 2015. [Online]. Available: https://arxiv.org/abs/1511.06349
[43] A. van den Oord, O. Vinyals, and K. Kavukcuoglu, "Neural discrete representation learning," in Advances in Neural Information Processing Systems 30 (NIPS 2017), 2017, pp. 6306–6315.
[44] S. Hochreiter and J. Schmidhuber, "Long short-term memory," Neural Computation, vol. 9, no. 8, pp. 1735–1780, Nov. 1997.
[45] D.-A. Clevert, T. Unterthiner, and S. Hochreiter, "Fast and accurate deep network learning by exponential linear units (ELUs)," in International Conference on Learning Representations (ICLR 2016), 2016.
[46] D. P. Kingma and M. Welling, "Auto-encoding variational Bayes," in International Conference on Learning Representations (ICLR 2014), 2014.
[47] D. Dheeru and E. Karra Taniskidou, "UCI machine learning repository – individual household electric power consumption data set," 2017. [Online]. Available: https://archive.ics.uci.edu/ml/datasets
[48] T. Andersen, T. Bollerslev, F. Diebold, and P. Labys, "Modeling and forecasting realized volatility," Econometrica, vol. 71, no. 2, pp. 579–625, 2003.
[49] A. Yadav, P. Awasthi, N. Naik, and M. R. Ananthasayanam, "A constant gain Kalman filter approach to track maneuvering targets," in 2013 IEEE International Conference on Control Applications (CCA), 2013.
[50] T. G. Andersen and T. Bollerslev, "Intraday periodicity and volatility persistence in financial markets," Journal of Empirical Finance, vol. 4, no. 2, pp. 115–158, 1997.
[51] A. Todd, R. Hayes, P. Beling, and W. Scherer, "Micro-price trading in an order-driven market," in 2014 IEEE Conference on Computational Intelligence for Financial Engineering & Economics (CIFEr), March 2014, pp. 294–297.
[52] Á. Cartea, R. Donnelly, and S. Jaimungal, "Enhancing trading strategies with order book signals," Applied Mathematical Finance, vol. 25, no. 1, pp. 1–35, 2018.
APPENDIX

A. VRNF Priors and Derivation of KL Term

Defining a prior distribution for the VRNF starts with the specification of a model for the distribution of hidden state xt, conditioned on the amount of available information at each encoder to achieve alignment with the VRNF stages. Per the generative model of Equation (2), we model xt as a multivariate normal distribution with a mean and covariance that vary with time and the information present at each encoder, based on the notation below:

Propagation:
p(xt | y1:t−1, u1:t−1) ∼ N(β̃t′, ν̃t′),    (26)

Input Dynamics:
p(xt | y1:t−1, u1:t) ∼ N(β̃t, ν̃t),    (27)

Error Correction:
p(xt | y1:t, u1:t) ∼ N(βt, νt).    (28)

For the various priors defined in this section, we adopt the use of diagonal covariance matrices, defined as:

νt = diag(γt ⊙ γt),    (29)

where γt ∈ R^J is a vector of standard deviation parameters. This approximation helps to reduce the computational complexity associated with matrix multiplications using full covariance matrices, and the O(J²T) memory requirement from storing full covariance matrices for an RNN unrolled across T timesteps.

KL Divergence Term

Considering the application of the input dynamics encoder alone (i.e. αx = αy = 0), the KL divergence between independent conditional multivariate Gaussians at each time step can hence be expressed analytically as:

KL_Input( q(x1:T) || p(x1:T) )    (30)
= E_{q(x1:T)} [ log ( q(x1 | s̃1) / p(x1) ) + Σ_{t=2}^{T} log ( q(xt | s̃t) / p(xt | xt−1, u1:t, y1:t−1) ) ]    (31)
= Σ_{t=1}^{T} Σ_{j=1}^{J} [ log ( γ̃t(j) / σ(j, s̃t) ) + ( σ(j, s̃t)² + (m(j, s̃t) − β̃t(j))² ) / ( 2 γ̃t(j)² ) − 1/2 ],    (32)

where m(j, s̃t), σ(j, s̃t) are the j-th elements of m(s̃t), σ(s̃t) as defined in Equations (14) and (15) respectively. The KL divergence terms are defined similarly for the propagation and error correction encoders, using the means and standard deviations defined above.

Kalman Prior (VRNF-KF)

The use of the Kalman filter relies on the definition of a linear Gaussian state space model, which we specify below:

yt = H xt + et,    (33)
xt = A xt−1 + B ut + εt,    (34)

where H, A, B are constant matrices, and et ∼ N(0, R), εt ∼ N(0, Q) are noise terms with constant noise covariances R and Q.

a) Propagation: Assuming that the input ut is unknown at time t, predictive distributions can still be computed for the hidden state if we have a model for ut. In the simplest case, this can be a normal distribution, i.e. ut ∼ N(c, D) – where c is a constant mean vector and D a constant covariance matrix. Under this model, predictive distributions can be computed as below:

β̃t′ = A βt−1 + B c = A βt−1 + c′,    (35)
ν̃t′ = A νt−1 Aᵀ + B D Bᵀ + Q = A νt−1 Aᵀ + Q′,    (36)

with c′, Q′ collapsing constant terms together into single parameters.

b) Input Dynamics: When inputs are known, the forecasting equations take on a similar form:

β̃t = A βt−1 + B ut,    (37)
ν̃t = A νt−1 Aᵀ + Q.    (38)

Comparing this with the forecasting equations of the propagation step, we can also express the above as functions of β̃t′ and ν̃t′, i.e.:

β̃t = β̃t′ + B ut − c′,    (39)
ν̃t = ν̃t′ − Q′ + Q.    (40)

c) Error Correction: Upon receipt of a new observation, the Kalman filter computes a Kalman gain Kt, using it to correct the belief state as below:

βt = (I − Kt H) β̃t − Kt H yt,    (41)
νt = (I − Kt H) ν̃t,    (42)

where I is an identity matrix and Kt = ν̃t Hᵀ (H ν̃t Hᵀ + R)⁻¹.

Approximations for Efficiency

To avoid the complex memory and space requirements associated with full matrix computations, we make the following approximations in our Kalman filter equations.
d) Constant Kalman Gain: Firstly, as noted in [49], Kalman gain values in stable filters usually tend towards a steady-state value after an initial transient period. We hence fix the Kalman gain at a constant value, and collapse constant coefficients in the error correction equations to give:

βt = K′ β̃t − H′ yt,    (43)
νt = K′ ν̃t,    (44)

where K′ = (I − KH) and H′ = KH.

e) Independent Hidden State Dimensions: Next, we assume that hidden state dimensions are independent of one another, effectively diagonalising the state-related coefficients A = diag(a) and Q = diag(q).

f) Diagonalising Q′, K′: Finally, to allow us to use diagonal covariance matrices throughout our equations, we also diagonalise Q′ = diag(q′) and K′ = diag(k′).

Prior Definition

Using the above definitions and approximations, the Kalman filter prior can hence be expressed in vector form using the equations below:

Propagation:
β̃t′ = a ⊙ m(st−1) + c′,    (45)
ν̃t′ = a ⊙ V(st−1) ⊙ a + q′.    (46)

Input Dynamics:
β̃t = a ⊙ m(st−1) + B ut,    (47)
ν̃t = a ⊙ V(st−1) ⊙ a + q.    (48)

Error Correction:
βt = k′ ⊙ m(s̃t) − H′ yt,    (49)
νt = k′ ⊙ V(s̃t).    (50)

All constant standard deviations are implemented as coefficients wrapped in a softplus layer (e.g. a = softplus(φ)) to prevent the optimiser from converging on invalid negative numbers.

In addition, we note that this form of the input dynamics prior is not conditioned on the propagation encoder outputs, although we could in theory express it in terms of their statistics (i.e. Equations (39) and (40)). This is to avoid converging on negative values for variances, which can be obtained from the subtraction of the positive constant Q′ – although we revisit this form in the next section.

Neural Network Prior (VRNF-NN)

Despite the convenient tractable form of the Kalman filtering equations, this relies on linear state space assumptions which might not be suitable for complex datasets. As such, we also consider the use of multilayer perceptrons (MLP(·)) to approximate the equations described in the previous section, conditioning each on the previous active encoder stage, i.e.:

Propagation:
β̃t′ = MLP_β̃′( m(st−1) ),    (51)
ν̃t′ = MLP_ν̃′( V(st−1) ).    (52)

Input Dynamics:
β̃t = MLP_β̃( m(s̃t), ut ),    (53)
ν̃t = MLP_ν̃( V(s̃t), ut ).    (54)

Error Correction:
βt = MLP_β( m(st), yt ),    (55)
νt = MLP_ν( V(st), yt ).    (56)

Similar to the state transition functions used in [13], this can be interpreted as using MLPs to approximate the true Kalman filter functions for linear datasets, while also permitting the learning of more sophisticated non-linear models. All MLPs defined here use an ELU activation function for their hidden layer, fixing the hidden state size to be J. Furthermore, we use linear output layers for the β MLPs, while passing the outputs of the ν MLPs through a softplus activation function to maintain positivity.

B. Training Procedure for RNF

a) Training Details: During network calibration, trajectories were partitioned into segments of 50 time steps each, which were randomly combined to form minibatches during training. Networks were trained for up to a maximum of 100 epochs or until convergence. For the electricity and volatility datasets, 50 iterations of random search were performed using the grid found in Table V; 20 iterations of random search were used for the quote dataset, as the significantly larger dataset led to longer training times for a given set of hyperparameters.

TABLE V
RANDOM SEARCH GRID FOR HYPERPARAMETER OPTIMISATION

Hyperparameter      Ranges
Dropout Rate        0.0, 0.1, 0.2, 0.3, 0.4, 0.5
State Size          5, 10, 25, 50, 100, 150
Minibatch Size      256, 512, 1024
Learning Rate       0.0001, 0.001, 0.01, 0.1, 1.0
Max Gradient Norm   0.0001, 0.001, 0.01, 0.1, 1.0, 10.0
Missing Rate        0.25, 0.5, 0.75

b) State Sizes: To ensure consistency across all models used, we constrain both the memory state of the RNN and the latent variable modelled to have the same dimensionality – i.e. J = dim(st) = dim(xt) for the RNF. The exception is the DSSM, as the full covariance matrix of the Kalman filter would result in a prohibitive J² memory requirement if left unchecked. As such, we use the constraint that both the RNN and the Kalman filter have the same memory capacity for the DSSM – i.e. J = dim(st) = dim(xt) + dim(xt)².
c) Dropout Application: Across all benchmarks, dropout was applied only onto the memory state of the RNNs (ht) in the standard fashion, and not to the latent states xt. For the LSTM, DeepAR Model and RNF, this corresponds to applying dropout masks to the outputs and internal states of the network. For the VRNN, DKF and DSSM, we apply dropout only to the inputs of the network – in line with [14] – to maintain comparability to the encoder skipping in the VRNFs.

d) Artificial Missingness: Encoder skipping is restricted to only the VRNFs and the standard RNF, controlled by the missing rates defined above.

e) Sample Generation: At prediction time, latent states for the VRNN, DKF and VRNFs are sampled as per the standard VAE case – using L = 1 during training, L = 30 for our validation error and L = 100 at test time. Predictions from the DeepAR Model, DSSM and standard RNF, however, were obtained directly from the mean estimates, with that of the DSSM computed analytically using the Kalman filtering equations. While this differs slightly from the original paper [26], it also leads to improvements in the performance of the DSSM by avoiding sampling errors.

C. Description of Datasets

For the experiments in Section VI, we focus on the use of 3 real-world time series datasets, each containing over a million time steps. These use cases help us evaluate performance for scenarios in which real-time predictions with RNNs are most beneficial – i.e. when the underlying dynamics are highly non-linear and trajectories are long.

Summary

Electricity: The UCI Individual Household Electric Power Consumption Dataset [47] is a time series of 7 different power consumption metrics measured at 1-min intervals for a single household between December 2006 and November 2010 – coming to a total of 2,075,259 time steps over 4 years. In our experiments, we treated active power as the main observation of interest, taking the remainder to be exogenous inputs into the RNNs.

Intraday Volatility: We compute 30-min realised variances [48] for a universe of 30 different stock indices – derived using 1-min index returns subsampled from Thomson Reuters Tick History Level 1 (TRTH L1) quote data. On the whole, the entire dataset contains 1,706,709 measurements across all indices, with each trajectory spanning 17 years on average. Given the strong evidence for the intraday periodicity of returns volatility [50], we also include the time-of-day as an additional exogenous input.

High-Frequency Stock Quotes: This dataset consists of extracted features from TRTH L1 stock quote data for Barclays (BARC.L) – specifically forecasting microprice returns [51] using volume imbalance as an input predictor – comprising a total of 29,321,946 time steps between 03 January 2017 and 29 December 2017. From [52], volume imbalance in the limit order book is a good predictor of the direction (sign) of the next liquidity-taking order, and of the price changes immediately after the arrival of a liquidity-taking order.

Electricity

a) Data Processing: The full trajectory was segmented into 3 portions, with the earliest 60% of measurements used for training, the next 20% as a validation set, and the final 20% as an independent test set – with results reported in Section VI. All datasets were normalised to have zero mean and unit standard deviation, with normalising constants computed using the training set alone.

b) Summary Statistics: A list of summary statistics can be seen in Table VI.

TABLE VI
SUMMARY STATISTICS FOR ELECTRICITY DATASET

                 Mean    S.D.   Min     Max
Active Power*    1.11    1.12   0.08    11.12
Reactive Power   0.12    0.11   0.00    1.39
Intensity        4.73    4.70   0.20    48.40
Voltage          240.32  3.33   223.49  252.72
Sub Metering 1   1.17    6.31   0.00    82.00
Sub Metering 2   1.42    6.26   0.00    78.00
Sub Metering 3   6.04    8.27   0.00    31.00

Intraday Volatility

c) Data Processing: From the 1-min index returns, realised variances were computed as:

rk = ln pk − ln pk−1,
y(t, 30) = Σ_{k=t−30}^{t} rk²,    (57)

where rk is the 1-min index return at time k, ln pk is the log price at k, and y(t, 30) is the 30-min realised variance at time t.

Before computation, the data was cleaned by only considering prices during exchange hours to avoid spurious jumps. In addition, realised variances greater than 10 times the 200-step rolling standard deviation were removed and replaced by their previous values – so as to reduce the impact of outliers.

For the experiments in Section VI, data across all stock indices were grouped together for training and testing – using data prior to 2014 for training, data between 2014–2016 for validation, and data from 2016 to 4 July 2018 for independent testing. Min-max normalisation was applied to the datasets, with time normalised by the maximum trading window of each exchange and realised variances by the max and min values of the training dataset.

d) Stock Index Identifiers (RICs): AEX, AORD, BFX, BSESN, BVLG, BVSP, DJI, FCHI, FTMIB, FTSE, GDAXI, GSPTSE, HSI, IBEX, IXIC, KS11, KSE, MXX, N225, NSEI, OMXC20, OMXHPI, OMXSPI, OSEAX, RUT, SMSI, SPX, SSEC, SSMI, STOXX50E
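As a concrete sketch of the realised variance computation in Equation (57), using a hypothetical 1-min price series:

```python
import numpy as np

def realised_variance(prices, window=30):
    """Rolling realised variance: sum of the last `window` squared 1-min log returns."""
    log_returns = np.diff(np.log(prices))  # r_k = ln p_k - ln p_{k-1}
    return np.convolve(log_returns ** 2, np.ones(window), mode="valid")

# Toy price series with a constant 0.1% log return per minute
prices = np.exp(np.cumsum(0.001 * np.ones(40)))
rv = realised_variance(prices)
print(rv[0])  # ≈ 30 * 0.001**2 = 3e-05
```

Outlier capping (10× the 200-step rolling standard deviation) and min-max scaling would then be applied on top of this series, as described above.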
e) Summary Statistics: A table of summary statistics can be found in Table VII, giving an indication of the general ranges of the trajectories.

TABLE VII
S UMMARY S TATISTICS FOR VOLATILITY DATASET

Mean S. D. Min Max


Realised Variance* 0.0007 0.0017 0.0000 0.1013
Normalised Time 0.43 0.27 0.00 0.97

High-Frequency Stock Quotes

f) Input/Output Definitions: Microprice returns y_t are defined as:

    p_t = \frac{V_a(t)\, p_b(t) + V_b(t)\, p_a(t)}{V_a(t) + V_b(t)}

    y_t = \frac{p_t - p_{t-1}}{p_{t-1}}

where V_b(t) and V_a(t) are the bid and ask volumes at time t respectively, p_b(t) and p_a(t) are the bid/ask prices, and p_t the microprice.

Volume imbalance I_t is then defined as:

    I_t = \frac{V_b(t) - V_a(t)}{V_b(t) + V_a(t)}
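The definitions above translate directly into code; a minimal sketch, with function and argument names of our own choosing:

```python
import numpy as np

def microprice(p_bid, p_ask, v_bid, v_ask):
    """p_t = (V_a * p_b + V_b * p_a) / (V_a + V_b): a volume-weighted mid,
    tilted towards the side of the book with less resting volume."""
    return (v_ask * p_bid + v_bid * p_ask) / (v_ask + v_bid)

def microprice_returns(p):
    """y_t = (p_t - p_{t-1}) / p_{t-1} for a series of microprices."""
    p = np.asarray(p, dtype=float)
    return p[1:] / p[:-1] - 1.0

def volume_imbalance(v_bid, v_ask):
    """I_t = (V_b - V_a) / (V_b + V_a), bounded in [-1, 1] by construction."""
    return (v_bid - v_ask) / (v_bid + v_ask)
```

With equal volumes on both sides the microprice reduces to the mid-price and the imbalance is zero; heavier bid volume pushes the microprice towards the ask and the imbalance towards +1.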
g) Data Processing: From the raw Level 1 (best bid and ask prices and volumes) data from TRTH, we isolate measurements between 08.30 and 16.00 UK time, avoiding the effects of opening and closing auctions in our forecasts. Furthermore, microprice returns were also normalised using an exponentially weighted moving standard deviation with a half-life of 10,000 steps. We note that volume imbalance is by definition restricted to I_t ∈ [−1, 1], and hence does not require additional normalisation. Finally, the data was partitioned with training data from January to June, validation data from June to September, and the remainder for independent testing.
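The return normalisation step described above can be sketched with pandas – a minimal illustration in which the helper name is ours, the 10,000-step half-life follows the text, and the one-step shift is our assumption to keep the scaling strictly backward-looking:

```python
import numpy as np
import pandas as pd

def normalise_returns(returns, half_life=10_000):
    """Scale returns by a trailing exponentially weighted moving standard
    deviation; leading entries are NaN until an estimate is available."""
    r = pd.Series(returns, dtype=float)
    ewm_sd = r.ewm(halflife=half_life).std().shift(1)  # use only past data
    return r / ewm_sd
```

The same construction (with a 60-day span) mirrors the ex-ante volatility scaling used for the futures data later in the thesis.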
h) Summary Statistics: Basic statistics can be found in Table VIII, giving an indication of the range of the different variables.

TABLE VIII
SUMMARY STATISTICS FOR QUOTE DATASET

                       Mean    S. D.   Min       Max
Normalised Returns*    0.00    0.80    -117.72   117.13
Volume Imbalance       0.02    0.48    -1.00     1.00
Chapter 4

Incorporating Domain Knowledge With Hybrid Models

Publications Included

• B. Lim, S. Zohren, S. Roberts. Enhancing Time Series Momentum Strategies Using Deep Neural Networks. Journal of Financial Data Science (JFDS), 2019. Weblink: https://arxiv.org/abs/1904.04912.

• B. Lim, M. van der Schaar. Disease-Atlas: Navigating Disease Trajectories with Deep Learning. Proceedings of the Machine Learning for Healthcare Conference (MLHC), 2018. Weblink: https://arxiv.org/abs/1803.10254.
Enhancing Time Series Momentum Strategies
Using Deep Neural Networks
Bryan Lim, Stefan Zohren, Stephen Roberts

Abstract—While time series momentum [1] is a well-studied phenomenon in finance, common strategies require the explicit definition of both a trend estimator and a position sizing rule. In this paper, we introduce Deep Momentum Networks – a hybrid approach which injects deep learning based trading rules into the volatility scaling framework of time series momentum. The model also simultaneously learns both trend estimation and position sizing in a data-driven manner, with networks directly trained by optimising the Sharpe ratio of the signal. Backtesting on a portfolio of 88 continuous futures contracts, we demonstrate that the Sharpe-optimised LSTM improved traditional methods by more than two times in the absence of transaction costs, and continues outperforming when considering transaction costs of up to 2-3 basis points. To account for more illiquid assets, we also propose a turnover regularisation term which trains the network to factor in costs at run-time.

B. Lim, S. Zohren and S. Roberts are with the Department of Engineering Science and the Oxford-Man Institute of Quantitative Finance, University of Oxford, Oxford, United Kingdom (email: [email protected], [email protected], [email protected]).

I. INTRODUCTION

Momentum as a risk premium in finance has been extensively documented in the academic literature, with evidence of persistent abnormal returns demonstrated across a range of asset classes, prediction horizons and time periods [2, 3, 4]. Based on the philosophy that strong price trends have a tendency to persist, time series momentum strategies are typically designed to increase position sizes with large directional moves and reduce positions at other times. Although the intuition underpinning the strategy is clear, specific implementation details can vary widely between signals – with a plethora of methods available to estimate the magnitude of price trends [5, 6, 4] and map them to actual traded positions [7, 8, 9].

In recent times, deep neural networks have been increasingly used for time series prediction, outperforming traditional benchmarks in applications such as demand forecasting [10], medicine [11] and finance [12]. With the development of modern architectures such as convolutional neural networks (CNNs) and recurrent neural networks (RNNs) [13], deep learning models have been favoured for their ability to build representations of a given dataset [14] – capturing temporal dynamics and cross-sectional relationships in a purely data-driven manner. The adoption of deep neural networks has also been facilitated by powerful open-source frameworks such as TensorFlow [15] and PyTorch [16] – which use automatic differentiation to compute gradients for backpropagation without having to explicitly derive them in advance. In turn, this flexibility has allowed deep neural networks to go beyond standard classification and regression models. For instance, hybrid methods that combine traditional time-series models with neural network components have been observed to outperform pure methods in either category [17] – e.g. the exponential smoothing RNN [18], autoregressive CNNs [19] and Kalman filter variants [20, 21] – while also making outputs easier to interpret by practitioners. Furthermore, these frameworks have also enabled the development of new loss functions for training neural networks, such as adversarial loss functions in generative adversarial networks (GANs) [22].

While numerous papers have investigated the use of machine learning for financial time series prediction, they typically focus on casting the underlying prediction problem as a standard regression or classification task [23, 24, 25, 12, 26, 19, 27] – with regression models forecasting expected returns, and classification models predicting the direction of future price movements. This approach, however, could lead to suboptimal performance in the context of time-series momentum for several reasons. Firstly, sizing positions based on expected returns alone does not take risk characteristics into account – such as the volatility or skew of the predictive
returns distribution – which could inadvertently expose signals to large downside moves. This is particularly relevant as raw momentum strategies without adequate risk adjustments, such as volatility scaling [7], are susceptible to large crashes during periods of market panic [28, 29]. Furthermore, even with volatility scaling – which leads to positively skewed returns distributions and long-option-like behaviour [30, 31] – trend following strategies can place more losing trades than winning ones and still be profitable on the whole – as they size up only into large but infrequent directional moves. As such, [32] argue that the fraction of winning trades is a meaningless metric of performance, given that it cannot be evaluated independently from the trading style of the strategy. Similarly, high classification accuracies may not necessarily translate into positive strategy performance, as profitability also depends on the magnitude of returns in each class. This is also echoed in betting strategies such as the Kelly criterion [33], which requires both win/loss probabilities and betting odds for optimal sizing in binomial games. In light of the deficiencies of standard supervised learning techniques, new loss functions and training methods would need to be explored for position sizing – accounting for trade-offs between risk and reward.

In this paper, we introduce a novel class of hybrid models that combines deep learning-based trading signals with the volatility scaling framework used in time series momentum strategies [8, 1] – which we refer to as the Deep Momentum Networks (DMNs). This improves existing methods from several angles. Firstly, by using deep neural networks to directly generate trading signals, we remove the need to manually specify both the trend estimator and position sizing methodology – allowing them to be learnt directly using modern time series prediction architectures. Secondly, by utilising automatic differentiation in existing backpropagation frameworks, we explicitly optimise networks for risk-adjusted performance metrics, i.e. the Sharpe ratio [34], improving the risk profile of the signal on the whole. Lastly, retaining a consistent framework with other momentum strategies also allows us to retain desirable attributes from previous works – specifically volatility scaling, which plays a critical role in the positive performance of time series momentum strategies [9]. This consistency also helps when making comparisons to existing methods, and facilitates the interpretation of different components of the overall signal by practitioners.

II. RELATED WORKS

A. Classical Momentum Strategies

Momentum strategies are traditionally divided into two categories – namely (multivariate) cross sectional momentum [35, 24] and (univariate) time series momentum [1, 8]. Cross sectional momentum strategies focus on the relative performance of securities against each other, buying relative winners and selling relative losers. By ranking a universe of stocks based on their past return and trading the top decile against the bottom decile, [35] find that securities that recently outperformed their peers over the past 3 to 12 months continue to outperform on average over the next month. The performance of cross sectional momentum has also been shown to be stable across time [36], and across a variety of markets and asset classes [4].

Time series momentum extends the idea to focus on an asset's own past returns, building portfolios comprising all securities under consideration. This was initially proposed by [1], who describe a concrete strategy which uses volatility scaling and trades positions based on the sign of returns over the past year – demonstrating profitability across 58 different liquid instruments individually over 25 years of data. Since then, numerous trading rules have been proposed – with various trend estimation techniques and methods to map them to traded positions. For instance, [6] documents a wide range of linear and non-linear filters to measure trends and a statistic to test for its significance – although methods to size positions with these estimates are not directly discussed. [8] adopt a similar approach to [1], regressing the log price over the past 12 months against time and using the regression coefficient t-statistics to determine the direction of the traded position. While Sharpe ratios were comparable between the two, t-statistic based trend estimation led to a 66% reduction in portfolio turnover and consequently trading costs. More sophisticated trading rules are proposed in [4] and [37], taking volatility-normalised moving average convergence divergence (MACD) indicators
as inputs. Despite the diversity of options, few comparisons have been made between the trading rules themselves, offering little clear evidence or intuitive reasoning to favour one rule over the next. We hence propose the use of deep neural networks to generate these rules directly, avoiding the need for explicit specification. Training them based on risk-adjusted performance metrics, the networks hence learn optimal trading rules directly from the data itself.

B. Deep Learning in Finance

Machine learning has long been used for financial time series prediction, with recent deep learning applications studying mid-price prediction using daily data [26], or using limit order book data in a high frequency trading setting [25, 12, 38]. While a variety of CNN and RNN models have been proposed, they typically frame the forecasting task as a classification problem, demonstrating the improved accuracy of their method in predicting the direction of the next price movement. Trading rules are then manually defined in relation to class probabilities – either by using thresholds on classification probabilities to determine when to initiate positions [26], or incorporating these thresholds into the classification problem itself by dividing price movements into buy, hold and sell classes depending on magnitude [12, 38]. In addition to restricting the universe of strategies to those which rely on high accuracy, further gains might be made by learning trading rules directly from the data and removing the need for manual specification – both of which are addressed in our proposed method.

Deep learning regression methods have also been considered in cross-sectional strategies [23, 24], ranking assets on the basis of expected returns over the next time period. Using a variety of linear, tree-based and neural network models, [23] demonstrate the outperformance of non-linear methods, with deep neural networks – specifically 3-layer multilayer perceptrons (MLPs) – having the best out-of-sample predictive R². Machine learning portfolios were then built by ranking stocks on a monthly basis using model predictions, with the best strategy coming from a 4-layer MLP that trades the top decile against the bottom decile of predictions. In other works, [24] adopt a similar approach using autoencoder and denoising autoencoder architectures, incorporating volatility scaling into their model as well. While the results with basic deep neural networks are promising, they do not consider more modern architectures for time series prediction, such as the LSTM [39] and WaveNet [40] architectures which we evaluate for the DMN. Moreover, to the best of our knowledge, our paper is the first to consider the use of deep learning within the context of time series momentum strategies – opening up possibilities in an alternate class of signals.

Popularised by the success of DeepMind's AlphaGo Zero [41], deep reinforcement learning (RL) has also gained much attention in recent times – prized for its ability to recommend path-dependent actions in dynamic environments. RL is particularly of interest within the context of optimal execution and automated hedging [42, 43] for example, where actions taken can have an impact on future states of the world (e.g. market impact). However, deep RL methods generally require a realistic simulation environment (for Q-learning or policy gradient methods), or model of the world (for model-based RL) to provide feedback to agents during training – both of which are difficult to obtain in practice.

III. STRATEGY DEFINITION

Adopting the terminology of [8], the combined returns of a time series momentum (TSMOM) strategy can be expressed as below – characterised by a trading rule or signal X_t ∈ [−1, 1]:

    r_{t,t+1}^{\mathrm{TSMOM}} = \frac{1}{N_t} \sum_{i=1}^{N_t} X_t^{(i)} \, \frac{\sigma_{\mathrm{tgt}}}{\sigma_t^{(i)}} \, r_{t,t+1}^{(i)} .    (1)

Here r_{t,t+1}^{\mathrm{TSMOM}} is the realised return of the strategy from day t to t + 1, N_t is the number of included assets at t, and r_{t,t+1}^{(i)} is the one-day return of asset i. We set the annualised volatility target σ_tgt to be 15% and scale asset returns with an ex-ante volatility estimate σ_t^{(i)} – computed using an exponentially weighted moving standard deviation with a 60-day span on r_{t,t+1}^{(i)}.

A. Standard Trading Rules

In traditional financial time series momentum strategies, the construction of a trading signal X_t is typically divided into two steps: 1) estimating
future trends based on past information, and 2) computing the actual positions to hold. We illustrate this in this section using two examples from the academic literature [1, 4], which we also include as benchmarks into our tests.

Moskowitz et al. 2012 [1]: In their original paper on time series momentum, a simple trading rule is adopted as below:

    Trend Estimation:  Y_t^{(i)} = r_{t-252,t}^{(i)}    (2)
    Position Sizing:   X_t^{(i)} = \mathrm{sgn}(Y_t^{(i)})    (3)

This broadly uses the past year's returns as a trend estimate for the next time step – taking a maximum long position when the expected trend is positive (i.e. sgn(r_{t-252,t}^{(i)})) and a maximum short position when negative.

Baz et al. 2015 [4]: In practice, more sophisticated methods can be used to compute Y_t^{(i)} and X_t^{(i)} – such as the model of [4] described below:

    Trend Estimation:  Y_t^{(i)} = \frac{q_t^{(i)}}{\mathrm{std}(z_{t-252:t}^{(i)})}    (4)
    q_t^{(i)} = \mathrm{MACD}(i, t, S, L) \, / \, \mathrm{std}(p_{t-63:t}^{(i)})    (5)
    \mathrm{MACD}(i, t, S, L) = m(i, S) - m(i, L).    (6)

Here std(p_{t-63:t}^{(i)}) is the 63-day rolling standard deviation of asset i prices p_{t-63:t}^{(i)} = [p_{t-63}^{(i)}, ..., p_t^{(i)}], and m(i, S) is the exponentially weighted moving average of asset i prices with a time-scale S that translates into a half-life of HL = log(0.5)/log(1 − 1/S). The moving average convergence divergence (MACD) signal is defined in relation to a short and a long time-scale S and L respectively.

The volatility-normalised MACD signal hence measures the strength of the trend, which is then translated into a position size as below:

    Position Sizing:  X_t^{(i)} = \phi(Y_t^{(i)}),    (7)

where \phi(y) = \frac{y \exp(-y^2/4)}{0.89}. Plotting φ(y) in Exhibit 1, we can see that positions are increased until |Y_t^{(i)}| = √2 ≈ 1.41, before decreasing back to zero for larger moves. This allows the signal to reduce positions in instances where assets are overbought or oversold – defined to be when |q_t^{(i)}| is observed to be larger than 1.41 times its past year's standard deviation.

Exhibit 1: Position Sizing Function φ(y)

Increasing the complexity even further, multiple signals with different time-scales can also be averaged to give a final position:

    \tilde{Y}_t^{(i)} = \sum_{k=1}^{3} Y_t^{(i)}(S_k, L_k),    (8)

where Y_t^{(i)}(S_k, L_k) is as per Equation (4) with explicitly defined short and long time-scales – using S_k ∈ {8, 16, 32} and L_k ∈ {24, 48, 96} as defined in [4].

B. Machine Learning Extensions

As can be seen from Section III-A, many explicit design decisions are required to define a sophisticated time series momentum strategy. We hence start by considering how machine learning methods can be used to learn these relationships directly from data – alleviating the need for manual specification.

Standard Supervised Learning: In line with numerous previous works (see Section II-B), we can cast trend estimation as a standard regression or binary classification problem, with outputs:

    Trend Estimation:  Y_t^{(i)} = f\left( u_t^{(i)}; \theta \right),    (9)
where f(·) is the output of the machine learning model, which takes in a vector of input features u_t^{(i)} and model parameters θ to generate predictions. Taking volatility-normalised returns as targets, the following mean-squared error and binary cross-entropy losses can be used for training:

    L_{\mathrm{reg}}(\theta) = \frac{1}{M} \sum_{\Omega} \left( Y_t^{(i)} - \frac{r_{t,t+1}^{(i)}}{\sigma_t^{(i)}} \right)^2    (10)

    L_{\mathrm{binary}}(\theta) = -\frac{1}{M} \sum_{\Omega} \left[ I \log\left(Y_t^{(i)}\right) + (1 - I) \log\left(1 - Y_t^{(i)}\right) \right],    (11)

where \Omega = \left\{ \left(Y_1^{(1)}, r_{1,2}^{(1)}/\sigma_1^{(1)}\right), \ldots, \left(Y_{T-1}^{(N)}, r_{T-1,T}^{(N)}/\sigma_{T-1}^{(N)}\right) \right\} is the set of all M possible prediction and target tuples across all N assets and T time steps. For the binary classification case, I is the indicator function I\left(r_{t,t+1}^{(i)}/\sigma_t^{(i)} > 0\right) – making Y_t^{(i)} the estimated probability of a positive return.

This still leaves us to specify how trend estimates map to positions, and we do so using a similar form to Equation (3):

    Position Sizing:
    Regression:      X_t^{(i)} = \mathrm{sgn}(Y_t^{(i)})    (12)
    Classification:  X_t^{(i)} = \mathrm{sgn}(Y_t^{(i)} - 0.5)    (13)

As such, we take a maximum long position when the expected returns are positive in the regression case, or when the probability of a positive return is greater than 0.5 in the classification case.

Direct Outputs: An alternative approach is to use machine learning models to generate positions directly – simultaneously learning both trend estimation and position sizing in the same function, i.e.:

    Direct Outputs:  X_t^{(i)} = f\left( u_t^{(i)}; \theta \right).    (14)

Given the lack of direct information on the optimal positions to hold at each step – which is required to produce labels for standard regression and classification models – calibration would hence need to be performed by directly optimising performance metrics. Specifically, we focus on optimising the average return and the annualised Sharpe ratio via the loss functions below:

    L_{\mathrm{returns}}(\theta) = -\mu_R = -\frac{1}{M} \sum_{\Omega} R(i, t) = -\frac{1}{M} \sum_{\Omega} X_t^{(i)} \, \frac{\sigma_{\mathrm{tgt}}}{\sigma_t^{(i)}} \, r_{t,t+1}^{(i)}    (15)

    L_{\mathrm{sharpe}}(\theta) = -\frac{\mu_R \times \sqrt{252}}{\sqrt{\left( \sum_{\Omega} R(i, t)^2 \right)/M - \mu_R^2}}    (16)

where μ_R is the average return over Ω, and R(i, t) is the return captured by the trading rule for asset i at time t.

IV. DEEP MOMENTUM NETWORKS

In this section, we examine a variety of architectures that can be used in Deep Momentum Networks – all of which can be easily reconfigured to generate the predictions described in Section III-B. This is achieved by implementing the models using the Keras API in TensorFlow [15], where output activation functions can be flexibly interchanged to generate predictions of different types (e.g. expected returns, binary probabilities, or direct positions). Arbitrary loss functions can also be defined for direct outputs, with gradients for backpropagation being easily computed using the built-in libraries for automatic differentiation.

A. Network Architectures

Lasso Regression: In the simplest case, a standard linear model could be used to generate predictions as below:

    Z_t^{(i)} = g\left( \mathbf{w}^T u_{t-\tau:t}^{(i)} + b \right),    (17)

where Z_t^{(i)} ∈ {X_t^{(i)}, Y_t^{(i)}} depending on the prediction task, w is a weight vector for the linear model, and b is a bias term. Here g(·) is an activation function which depends on the specific prediction type – linear for standard regression, sigmoid for binary classification, and the tanh-function for direct outputs.

Additional regularisation is also provided during training by augmenting the various loss functions to include an additional L1 regulariser as below:
    \tilde{L}(\theta) = L(\theta) + \alpha \|\mathbf{w}\|_1,    (18)

where L(θ) corresponds to one of the loss functions described in Section III-B, ||w||_1 is the L1 norm of w, and α is a constant term which we treat as an additional hyperparameter. To incorporate recent history into predictions as well, we concatenate inputs over the past τ days into a single input vector – i.e. u_{t-\tau:t}^{(i)} = [u_{t-\tau}^{(i)T}, \ldots, u_t^{(i)T}]^T. This was fixed to be τ = 5 days for tests in Section V.

Multilayer Perceptron (MLP): Increasing the degree of model complexity slightly, a 2-layer neural network can be used to incorporate non-linear effects:

    h_t^{(i)} = \tanh\left( W_h u_{t-\tau:t}^{(i)} + b_h \right)    (19)
    Z_t^{(i)} = g\left( W_z h_t^{(i)} + b_z \right),    (20)

where h_t^{(i)} is the hidden state of the MLP using an internal tanh activation function, tanh(·), and W. and b. are layer weight matrices and biases respectively.

WaveNet: More modern techniques such as convolutional neural networks (CNNs) have been used in the domain of time series prediction – particularly in the form of autoregressive architectures, e.g. [19]. These typically take the form of 1D causal convolutions, sliding convolutional filters across time to extract useful representations which are then aggregated in higher layers of the network. To increase the size of the receptive field – or the length of history fed into the CNN – dilated CNNs such as WaveNet [40] have been proposed, which skip over inputs at intermediate levels with a predetermined dilation rate. This allows the network to effectively increase the amount of historical information used by the CNN without a large increase in computational cost. Let us consider a dilated convolutional layer with residual connections, which takes the form below:

    \psi(u) = \underbrace{\tanh(W u) \odot \sigma(V u)}_{\text{Gated Activation}} + \underbrace{A u + b}_{\text{Skip Connection}} .    (21)

Here W and V are weight matrices associated with the gated activation function, and A and b are the weights and biases used to transform u to match the dimensionality of the layer outputs for the skip connection. The equations for the WaveNet architecture used in our investigations can then be expressed as:

    s_{\mathrm{weekly}}^{(i)}(t) = \psi\left( u_{t-5:t}^{(i)} \right)    (22)

    s_{\mathrm{monthly}}^{(i)}(t) = \psi\left( \left[ s_{\mathrm{weekly}}^{(i)}(t);\; s_{\mathrm{weekly}}^{(i)}(t-5);\; s_{\mathrm{weekly}}^{(i)}(t-10);\; s_{\mathrm{weekly}}^{(i)}(t-15) \right] \right)    (23)

    s_{\mathrm{quarterly}}^{(i)}(t) = \psi\left( \left[ s_{\mathrm{monthly}}^{(i)}(t);\; s_{\mathrm{monthly}}^{(i)}(t-21);\; s_{\mathrm{monthly}}^{(i)}(t-42) \right] \right).    (24)

Here each intermediate layer s_\cdot^{(i)}(t) aggregates representations at weekly, monthly and quarterly frequencies respectively. Intermediate layers are then concatenated at each layer before passing through a 2-layer MLP to generate outputs, i.e.:

    s_t^{(i)} = \left[ s_{\mathrm{weekly}}^{(i)}(t);\; s_{\mathrm{monthly}}^{(i)}(t);\; s_{\mathrm{quarterly}}^{(i)}(t) \right]    (25)
    h_t^{(i)} = \tanh\left( W_h s_t^{(i)} + b_h \right)    (26)
    Z_t^{(i)} = g\left( W_z h_t^{(i)} + b_z \right).    (27)

State sizes for the intermediate layers s_{\mathrm{weekly}}^{(i)}(t), s_{\mathrm{monthly}}^{(i)}(t), s_{\mathrm{quarterly}}^{(i)}(t) and the MLP hidden state h_t^{(i)} are fixed to be the same, allowing us to use a single hyperparameter to define the architecture. To independently evaluate the performance of CNN and RNN architectures, the above also excludes the LSTM block (i.e. the context stack) described in [40], focusing purely on the merits of the dilated CNN model.

Long Short-term Memory (LSTM): Traditionally used in sequence prediction for natural language processing, recurrent neural networks – specifically long short-term memory (LSTM) architectures [39] – have been increasingly used in time series prediction
tasks. The equations for the LSTM in our model are provided below:

    f_t^{(i)} = \sigma\left( W_f u_t^{(i)} + V_f h_{t-1}^{(i)} + b_f \right)    (28)
    i_t^{(i)} = \sigma\left( W_i u_t^{(i)} + V_i h_{t-1}^{(i)} + b_i \right)    (29)
    o_t^{(i)} = \sigma\left( W_o u_t^{(i)} + V_o h_{t-1}^{(i)} + b_o \right)    (30)
    c_t^{(i)} = f_t^{(i)} \odot c_{t-1}^{(i)} + i_t^{(i)} \odot \tanh\left( W_c u_t^{(i)} + V_c h_{t-1}^{(i)} + b_c \right)    (31)
    h_t^{(i)} = o_t^{(i)} \odot \tanh\left( c_t^{(i)} \right)    (32)
    Z_t^{(i)} = g\left( W_z h_t^{(i)} + b_z \right),    (33)

where ⊙ is the Hadamard (element-wise) product, σ(.) is the sigmoid activation function, W. and V. are weight matrices for the different layers, f_t^{(i)}, i_t^{(i)}, o_t^{(i)} correspond to the forget, input and output gates respectively, c_t^{(i)} is the cell state, and h_t^{(i)} is the hidden state of the LSTM. From these equations, we can see that the LSTM uses the cell state as a compact summary of past information, controlling memory retention with the forget gate and incorporating new information via the input gate. As such, the LSTM is able to learn representations of long-term relationships relevant to the prediction task – sequentially updating its internal memory states with new observations at each step.

B. Training Details

Model calibration was undertaken using minibatch stochastic gradient descent with the Adam optimiser [44], based on the loss functions defined in Section III-B. Backpropagation was performed up to a maximum of 100 training epochs using 90% of a given block of training data, with the most recent 10% retained as a validation dataset. Validation data is then used to determine convergence – with early stopping triggered when the validation loss has not improved for 25 epochs – and to identify the optimal model across hyperparameter settings. Hyperparameter optimisation was conducted using 50 iterations of random search, with full details provided in Appendix B. For additional information on the deep neural network calibration, please refer to [13].

Dropout regularisation [45] was a key feature to avoid overfitting in the neural network models – with dropout rates included as hyperparameters during training. This was applied to the inputs and hidden state for the MLP, as well as the inputs, Equation (22), and outputs, Equation (26), of the convolutional layers in the WaveNet architecture. For the LSTM, we adopted the same dropout masks as in [46] – applying dropout to the RNN inputs, recurrent states and outputs.

V. PERFORMANCE EVALUATION

A. Overview of Dataset

The predictive performance of the different architectures was evaluated via a backtest using 88 ratio-adjusted continuous futures contracts downloaded from the Pinnacle Data Corp CLC Database [47]. These contracts spanned across a variety of asset classes – including commodities, fixed income and currency futures – and contained prices from 1990 to 2015. A full breakdown of the dataset can be found in Appendix A.

B. Backtest Description

Throughout our backtest, the models were recalibrated from scratch every 5 years – re-running the entire hyperparameter optimisation procedure using all data available up to the recalibration point. Model weights were then fixed for signals generated over the next 5-year period, ensuring that tests were performed out-of-sample.

For the Deep Momentum Networks, we incorporate a series of useful features adopted by standard time series momentum strategies in Section III-A to generate predictions at each step:

1) Normalised Returns – Returns over the past day, 1-month, 3-month, 6-month and 1-year periods are used, normalised by a measure of daily volatility scaled to an appropriate time scale. For instance, normalised annual returns were taken to be r_{t-252,t}^{(i)} / (\sigma_t^{(i)} \sqrt{252}).
2) MACD Indicators – We also include the MACD indicators – i.e. trend estimates Y_t^{(i)} – as in Equation (4), using the same short time-scales S_k ∈ {8, 16, 32} and long time-scales L_k ∈ {24, 48, 96}.

For comparisons against traditional time series momentum strategies, we also incorporate the following reference benchmarks:
1) Long Only with Volatility Scaling (X_t^{(i)} = 1)
2) Sgn(Returns) – Moskowitz et al. 2012 [1]
3) MACD Signal – Baz et al. 2015 [4]

Finally, performance was judged based on the following metrics:

1) Profitability – Expected returns (E[Returns]) and the percentage of positive returns observed across the test period.
2) Risk – Daily volatility (Vol.), downside deviation and the maximum drawdown (MDD) of the overall portfolio.
3) Performance Ratios – Risk-adjusted performance was measured by the Sharpe ratio (E[Returns]/Vol.), Sortino ratio (E[Returns]/Downside Deviation) and Calmar ratio (E[Returns]/MDD), as well as the average profit over the average loss (Ave. P/Ave. L).

C. Results and Discussion

Aggregating the out-of-sample predictions from 1995 to 2015, we compute performance metrics for both the strategy returns based on Equation (1) (Exhibit 2), as well as that for portfolios with an additional layer of volatility scaling – which brings overall strategy returns to match the 15% volatility target (Exhibit 3). Given the large differences in returns volatility seen in Exhibit 2, this rescaling also helps to facilitate comparisons between the cumulative returns of different strategies – which are plotted for various loss functions in Exhibit 4. We note that strategy returns in this section are computed in the absence of transaction costs, allowing us to focus on the raw predictive ability of the models themselves. The impact of transaction costs is explored further in Section VI, where we undertake a deeper analysis of signal turnover. More detailed results can also be found in Appendix C, which echo the findings below.

Focusing on the raw signal outputs, the Sharpe ratio-optimised LSTM outperforms all benchmarks as expected, improving on the best neural network model (the Sharpe-optimised MLP) by 44% and the best reference benchmark (Sgn(Returns)) by more than two times. In conjunction with Sharpe ratio improvements to both the linear and MLP models, this highlights the benefits of using models which capture non-linear relationships, and have access to more time history via an internal memory state.

Additional model complexity, however, does not necessarily lead to better predictive performance, as demonstrated by the underperformance of WaveNet compared to both the reference benchmarks and simple linear models. Part of this can be attributed to the difficulties in tuning models with multiple design parameters – for instance, better results could possibly be achieved by using alternative dilation rates, numbers of convolutional layers, and hidden state sizes in Equations (22) to (24) for the WaveNet. In contrast, only a single design parameter is sufficient to specify the hidden state size in both the MLP and LSTM models. Analysing the relative performance within each model class, we can see that models which directly generate positions perform the best – demonstrating the benefits of simultaneously learning both trend estimation and position sizing functions. In addition, with the exception of a slight decrease in the MLP, Sharpe-optimised models outperform returns-optimised ones, with standard regression and classification benchmarks taking third and fourth place respectively.

From Exhibit 3, while the addition of volatility scaling at the portfolio level improved performance ratios on the whole, it had a larger beneficial effect on machine learning models compared to the reference benchmarks – propelling Sharpe-optimised MLPs to outperform returns-optimised ones, and even leading to Sharpe-optimised linear models beating reference benchmarks. From a risk perspective, we can see that both volatility and downside deviation also become a lot more comparable, with the former hovering close to 15.5% and the latter around 10%. However, Sharpe-optimised LSTMs still retained the lowest MDD across all models, with superior risk-adjusted performance ratios across the board. Referring to the cumulative returns plots for the rescaled portfolios in Exhibit 4, the benefits of direct outputs with Sharpe ratio optimisation can also be observed – with larger cumulative returns observed for linear, MLP and LSTM models compared to the reference benchmarks. Furthermore, we note the general underperformance of models which use standard regression and classification methods for trend estimation – hinting at the difficulties faced in selecting an appropriate position sizing function, and in optimising models to generate positions without accounting for risk. This is particularly relevant for binary classification
Exhibit 2: Performance Metrics – Raw Signal Outputs
E[Return] | Vol. | Downside Deviation | MDD | Sharpe | Sortino | Calmar | % of +ve Returns | Ave. P / Ave. L
Reference
Long Only 0.039 0.052 0.035 0.167 0.738 1.086 0.230 53.8% 0.970
Sgn(Returns) 0.054 0.046 0.032 0.083 1.192 1.708 0.653 54.8% 1.011
MACD 0.030 0.031 0.022 0.081 0.976 1.356 0.371 53.9% 1.015
Linear
Sharpe 0.041 0.038 0.028 0.119 1.094 1.462 0.348 54.9% 0.997
Ave. Returns 0.047 0.045 0.031 0.164 1.048 1.500 0.287 53.9% 1.022
MSE 0.049 0.047 0.032 0.164 1.038 1.522 0.298 54.3% 1.000
Binary 0.013 0.044 0.030 0.167 0.295 0.433 0.078 50.6% 1.028
MLP
Sharpe 0.044 0.031 0.025 0.154 1.383 1.731 0.283 56.0% 1.024
Ave. Returns 0.064* 0.043 0.030 0.161 1.492 2.123 0.399 55.6% 1.031
MSE 0.039 0.046 0.032 0.166 0.844 1.224 0.232 52.7% 1.035
Binary 0.003 0.042 0.028 0.233 0.080 0.120 0.014 50.8% 0.981
WaveNet
Sharpe 0.030 0.035 0.026 0.101 0.854 1.167 0.299 53.5% 1.008
Ave. Returns 0.032 0.040 0.028 0.113 0.788 1.145 0.281 53.8% 0.980
MSE 0.022 0.042 0.028 0.134 0.536 0.786 0.166 52.4% 0.994
Binary 0.000 0.043 0.029 0.313 0.011 0.016 0.001 50.2% 0.995
LSTM
Sharpe 0.045 0.016* 0.011* 0.021* 2.804* 3.993* 2.177* 59.6%* 1.102*
Ave. Returns 0.054 0.046 0.033 0.164 1.165 1.645 0.326 54.8% 1.003
MSE 0.031 0.046 0.032 0.163 0.669 0.959 0.189 52.8% 1.003
Binary 0.012 0.039 0.026 0.255 0.300 0.454 0.046 51.0% 1.012

Exhibit 3: Performance Metrics – Rescaled to Target Volatility

E[Return] | Vol. | Downside Deviation | MDD | Sharpe | Sortino | Calmar | % of +ve Returns | Ave. P / Ave. L
Reference
Long Only 0.117 0.154 0.102 0.431 0.759 1.141 0.271 53.8% 0.973
Sgn(Returns) 0.215 0.154 0.102 0.264 1.392 2.108 0.815 54.8% 1.041
MACD 0.172 0.155 0.106 0.317 1.111 1.622 0.543 53.9% 1.031
Linear
Sharpe 0.232 0.155 0.103 0.303 1.496 2.254 0.765 54.9% 1.056
Ave. Returns 0.189 0.154 0.100 0.372 1.225 1.893 0.507 53.9% 1.047
MSE 0.186 0.154 0.099* 0.365 1.211 1.889 0.509 54.3% 1.025
Binary 0.051 0.155 0.103 0.558 0.332 0.496 0.092 50.6% 1.033
MLP
Sharpe 0.312 0.154 0.102 0.335 2.017 3.042 0.930 56.0% 1.104
Ave. Returns 0.266 0.154 0.099* 0.354 1.731 2.674 0.752 55.6% 1.065
MSE 0.156 0.154 0.099* 0.371 1.017 1.582 0.422 52.7% 1.062
Binary 0.017 0.154 0.102 0.661 0.108 0.162 0.025 50.8% 0.986
WaveNet
Sharpe 0.148 0.155 0.103 0.349 0.956 1.429 0.424 53.5% 1.018
Ave. Returns 0.136 0.154 0.101 0.356 0.881 1.346 0.381 53.8% 0.993
MSE 0.084 0.153* 0.101 0.459 0.550 0.837 0.184 52.4% 0.995
Binary 0.007 0.155 0.103 0.779 0.045 0.068 0.009 50.2% 1.001
LSTM
Sharpe 0.451* 0.155 0.105 0.209* 2.907* 4.290* 2.159* 59.6%* 1.113*
Ave. Returns 0.208 0.154 0.102 0.365 1.349 2.045 0.568 54.8% 1.028
MSE 0.121 0.154 0.100 0.362 0.791 1.211 0.335 52.8% 1.020
Binary 0.075 0.155 0.099* 0.682 0.486 0.762 0.110 51.0% 1.043
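The portfolio-level volatility scaling behind Exhibit 3 amounts to rescaling strategy returns to the 15% annualised target. A minimal sketch follows; the ex-post rescaling of realised returns and the 252-day annualisation convention are assumptions for illustration:

```python
import numpy as np

def rescale_to_target_vol(strategy_returns, vol_target=0.15, trading_days=252):
    """Rescale a daily strategy returns series so that its realised
    annualised volatility matches the target (here 15%).

    An illustrative ex-post sketch of the portfolio-level scaling step.
    """
    realised_vol = np.std(strategy_returns) * np.sqrt(trading_days)
    return strategy_returns * (vol_target / realised_vol)
```

As noted in the text, this rescaling makes the volatility and downside deviation of the different strategies directly comparable, so that differences in the remaining metrics reflect signal quality rather than leverage.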
Exhibit 4: Cumulative Returns - Rescaled to Target Volatility

(a) Sharpe Ratio (b) Average Returns

(c) MSE (d) Binary

methods, which produce relatively flat equity lines and generally underperform the reference benchmarks. Some of these poor results can be explained by the implicit decision threshold adopted. From the percentage of positive returns captured in Exhibit 3, most binary classification models have about a 50% accuracy which, while expected of a classifier with a 0.5 probability threshold, is far below the accuracies seen in the other benchmarks. Furthermore, performance is made worse by the fact that the model's magnitude of gains versus losses (Ave. P / Ave. L) is much smaller than that of competing methods – with average loss magnitudes even outweighing profits for the MLP classifier (Ave. P / Ave. L = 0.986). As such, these observations lend support to the direct generation of position sizes with machine learning methods, given the multiple considerations (e.g. decision thresholds and profit/loss magnitudes) that would be required to incorporate standard supervised learning methods into a profitable trading strategy.

Strategy performance could also be aided by diversification across a range of assets, particularly when the correlation between signals is low. Hence, to evaluate the raw quality of the underlying signal, we investigate the performance constituents of the time series momentum portfolios – using box plots for a variety of performance metrics, plotting the minimum, lower quartile, median, upper quartile and maximum values across individual futures contracts. We present in Exhibit 5 plots of one metric per category in Section V-B; similar results for the other performance ratios are documented in Appendix C. In general, the Sharpe ratio plots in Exhibit 5a echo previous findings, with direct output methods performing better than indirect trend estimation models. However, as seen in Exhibit 5c, this is mainly attributable to a significant reduction in signal volatility for the Sharpe-optimised methods, despite a comparable range of average returns in Exhibit 5b. The benefits of retaining the volatility scaling can also be observed, with individual signal volatility capped near the target across all methods – even with a naive sgn(.) position sizer. As such, volatility scaling, direct outputs and Sharpe ratio optimisation were all key to the performance gains in Deep Momentum Networks.
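Since the direct-output models are calibrated by maximising the Sharpe ratio, the corresponding training loss can be sketched as the negative Sharpe ratio of realised strategy returns. The NumPy form below is an illustrative assumption of how such a loss might be written; in practice it would be expressed in an autodiff framework (e.g. TensorFlow or PyTorch) so that gradients reach the network weights:

```python
import numpy as np

def sharpe_loss(positions, next_returns, eps=1e-9):
    """Negative (per-period) Sharpe ratio of realised strategy returns.

    Maximising the Sharpe ratio equals minimising this loss. The
    E[R^2] - E[R]^2 variance form keeps the expression easy to port to an
    autodiff framework; eps guards against a zero denominator.
    """
    strategy_returns = positions * next_returns
    mean = np.mean(strategy_returns)
    variance = np.mean(strategy_returns ** 2) - mean ** 2
    return -mean / np.sqrt(variance + eps)
```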
Exhibit 5: Performance Across Individual Assets

(a) Sharpe Ratio

(b) Average Returns

(c) Volatility
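The metrics reported per asset in Exhibit 5 (and in aggregate in Exhibits 2–3) can be computed from a daily returns series as follows. This is a sketch in which the annualisation conventions (252 trading days, square-root-of-time scaling) are assumptions for illustration:

```python
import numpy as np

def performance_ratios(returns, trading_days=252):
    """Compute the performance metrics used in Exhibits 2-5 for a daily
    returns series: Sharpe, Sortino, Calmar, % of positive returns, and
    the average-profit-to-average-loss ratio."""
    ann_ret = np.mean(returns) * trading_days
    vol = np.std(returns) * np.sqrt(trading_days)
    downside = np.std(returns[returns < 0]) * np.sqrt(trading_days)
    # Maximum drawdown of the compounded cumulative return curve
    curve = np.cumprod(1 + returns)
    mdd = np.max(1 - curve / np.maximum.accumulate(curve))
    return {
        "sharpe": ann_ret / vol,
        "sortino": ann_ret / downside,
        "calmar": ann_ret / mdd,
        "pct_positive": np.mean(returns > 0),
        "ave_p_over_ave_l": np.mean(returns[returns > 0])
                            / np.abs(np.mean(returns[returns < 0])),
    }
```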

VI. TURNOVER ANALYSIS

To investigate how transaction costs affect strategy performance, we first analyse the daily position changes of the signal – characterised for asset i by the daily turnover \zeta_t^{(i)} as defined in [8]:

\zeta_t^{(i)} = \sigma_{\mathrm{tgt}} \left| \frac{X_t^{(i)}}{\sigma_t^{(i)}} - \frac{X_{t-1}^{(i)}}{\sigma_{t-1}^{(i)}} \right|, (34)

which is broadly proportional to the volume of asset i traded on day t with reference to the updated portfolio weights.

Exhibit 6a shows the average strategy turnover across all assets from 1995 to 2015, focusing on positions generated by the raw signal outputs. As the box plots are charted on a logarithmic scale, we note that while the machine learning-based models have a similar turnover to each other, they also trade significantly more than the reference benchmarks – approximately 10 times more compared to the Long Only benchmark. This is also reflected in Exhibit 6b, which compares the average daily returns against the average daily turnover – with ratios from machine learning models lying close to the x-axis.

To concretely quantify the impact of transaction costs on performance, we also compute the ex-cost Sharpe ratios – using the rebalancing costs defined in [8] to adjust our returns for a variety of transaction cost assumptions. For the results in Exhibit 7, the top of each bar chart marks the maximum cost-free Sharpe ratio of the strategy, with each coloured block denoting the Sharpe ratio reduction for the corresponding cost assumption. In line with the turnover analysis, the reference benchmarks demonstrate the most resilience to high transaction costs (up to 5bps), with the profitability across most machine learning models persisting only up to 4bps. However, we still obtain higher cost-adjusted Sharpe ratios with the Sharpe-optimised LSTM for up to 2-3bps, demonstrating its suitability for trading more liquid instruments.

A. Turnover Regularisation

One simple way to account for transaction costs is to use cost-adjusted returns \tilde{r}_{t,t+1}^{\mathrm{TSMOM}} directly during training, augmenting the strategy returns defined in Equation (1) as below:

\tilde{r}_{t,t+1}^{\mathrm{TSMOM}} = \frac{\sigma_{\mathrm{tgt}}}{N_t} \sum_{i=1}^{N_t} \left( \frac{X_t^{(i)}}{\sigma_t^{(i)}} \, r_{t,t+1}^{(i)} - c \left| \frac{X_t^{(i)}}{\sigma_t^{(i)}} - \frac{X_{t-1}^{(i)}}{\sigma_{t-1}^{(i)}} \right| \right), (35)

where c is a constant reflecting transaction cost assumptions. As such, using \tilde{r}_{t,t+1}^{\mathrm{TSMOM}} in Sharpe ratio loss functions during training corresponds to optimising ex-cost risk-adjusted returns, and the term c \left| X_t^{(i)}/\sigma_t^{(i)} - X_{t-1}^{(i)}/\sigma_{t-1}^{(i)} \right| can also be interpreted as a regularisation penalty on turnover.

Given that the Sharpe-optimised LSTM is still profitable in the presence of small transaction costs, we seek to quantify the effectiveness of turnover regularisation when costs are prohibitively high – considering the extreme case where c = 10bps in our investigation. Tests were focused on the Sharpe-optimised LSTM with and without the turnover regulariser (LSTM + Reg. for the former) – including the additional portfolio-level volatility scaling to bring signal volatilities to the same level. Based on the results in Exhibit 8, we can see that the turnover regularisation does help improve the LSTM in the presence of large costs, leading to slightly better performance ratios when compared to the reference benchmarks.

VII. CONCLUSIONS

We introduce Deep Momentum Networks – a hybrid class of deep learning models which retain the volatility scaling framework of time series momentum strategies while using deep neural networks to output position-targeting trading signals. Two approaches to position generation were evaluated here. Firstly, we cast trend estimation as a standard supervised learning problem – using machine learning models to forecast the expected asset returns or the probability of a positive return at the next time step – and apply a simple maximum long/short trading rule based on the direction of the next return. Secondly, trading rules were directly generated as outputs from the model, which we calibrate by maximising the Sharpe ratio or average strategy return. Testing this on a universe of continuous futures contracts, we demonstrate clear improvements in risk-adjusted performance by calibrating models with the Sharpe
Exhibit 6: Turnover Analysis

(a) Average Strategy Turnover

(b) Average Returns / Average Turnover
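The turnover measure of Equation (34) underlying Exhibit 6, and the cost-adjusted returns of Equation (35), can be sketched as follows. The array shapes and the default cost/target-volatility values here are illustrative assumptions:

```python
import numpy as np

def daily_turnover(X, sigma, sigma_tgt=0.15):
    """Equation (34): sigma_tgt * |X_t/sigma_t - X_{t-1}/sigma_{t-1}|.

    X and sigma are 1-D arrays of daily positions and ex-ante volatility
    estimates for a single asset."""
    scaled = sigma_tgt * X / sigma
    return np.abs(np.diff(scaled))

def cost_adjusted_returns(X, r_next, sigma, c=0.0005, sigma_tgt=0.15):
    """Equation (35): cost-adjusted TSMOM returns, averaged over assets.

    X, r_next, sigma have shape (T, N); c is the transaction cost
    (0.0005 = 5bps here, an illustrative value)."""
    scaled_pos = X / sigma
    gross = scaled_pos[1:] * r_next[1:]                   # scaled position x next return
    turnover = np.abs(scaled_pos[1:] - scaled_pos[:-1])   # position change, penalised by c
    return sigma_tgt * np.mean(gross - c * turnover, axis=1)
```

Feeding `cost_adjusted_returns` into a Sharpe-style loss during training is what turns the cost term into the turnover regulariser discussed in Section VI-A.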

Exhibit 7: Impact of Transaction Costs on Sharpe Ratio

Exhibit 8: Performance Metrics with Transaction Costs (c = 10bps)

E[Return] | Vol. | Downside Deviation | MDD | Sharpe | Sortino | Calmar | % of +ve Returns | Ave. P / Ave. L

Long Only 0.097 0.154* 0.103 0.482 0.628 0.942 0.201 53.3% 0.970
Sgn(Returns) 0.133 0.154* 0.102* 0.373 0.861 1.296 0.356 53.3% 1.011
MACD 0.111 0.155 0.106 0.472 0.719 1.047 0.236 52.5% 1.020*
LSTM -0.833 0.157 0.114 1.000 -5.313 -7.310 -0.833 33.9% 0.793
LSTM + Reg. 0.141* 0.154* 0.102* 0.371* 0.912* 1.379* 0.379* 53.4%* 1.014
ratio – where the LSTM model achieved the best results. Incorporating transaction costs, the Sharpe-optimised LSTM outperforms benchmarks up to 2-3 basis points of costs, demonstrating its suitability for trading more liquid assets. To accommodate high-cost settings, we introduce a turnover regulariser to use during training, which was shown to be effective even in extreme scenarios (i.e. c = 10bps).

Future work includes extensions of the framework presented here to incorporate ways to deal better with non-stationarity in the data, such as using the recently introduced Recurrent Neural Filters [48]. Another direction of future work focuses on the study of time series momentum at the microstructure level.

VIII. ACKNOWLEDGEMENTS

We would like to thank Anthony Ledford, James Powrie and Thomas Flury for their interesting comments, as well as the Oxford-Man Institute of Quantitative Finance for financial support.

REFERENCES

[1] T. J. Moskowitz, Y. H. Ooi, and L. H. Pedersen, "Time series momentum," Journal of Financial Economics, vol. 104, no. 2, pp. 228–250, 2012, Special Issue on Investor Sentiment.
[2] B. Hurst, Y. H. Ooi, and L. H. Pedersen, "A century of evidence on trend-following investing," The Journal of Portfolio Management, vol. 44, no. 1, pp. 15–29, 2017.
[3] Y. Lempérière, C. Deremble, P. Seager, M. Potters, and J.-P. Bouchaud, "Two centuries of trend following," Journal of Investment Strategies, vol. 3, no. 3, pp. 41–61, 2014.
[4] J. Baz, N. Granger, C. R. Harvey, N. Le Roux, and S. Rattray, "Dissecting investment strategies in the cross section and time series," SSRN, 2015. [Online]. Available: https://ssrn.com/abstract=2695101
[5] A. Levine and L. H. Pedersen, "Which trend is your friend," Financial Analysts Journal, vol. 72, no. 3, 2016.
[6] B. Bruder, T.-L. Dao, J.-C. Richard, and T. Roncalli, "Trend filtering methods for momentum strategies," SSRN, 2013. [Online]. Available: https://ssrn.com/abstract=2289097
[7] A. Y. Kim, Y. Tse, and J. K. Wald, "Time series momentum and volatility scaling," Journal of Financial Markets, vol. 30, pp. 103–124, 2016.
[8] N. Baltas and R. Kosowski, "Demystifying time-series momentum strategies: Volatility estimators, trading rules and pairwise correlations," SSRN, 2017. [Online]. Available: https://ssrn.com/abstract=2140091
[9] C. R. Harvey, E. Hoyle, R. Korgaonkar, S. Rattray, M. Sargaison, and O. van Hemert, "The impact of volatility targeting," SSRN, 2018. [Online]. Available: https://ssrn.com/abstract=3175538
[10] N. Laptev, J. Yosinski, L. E. Li, and S. Smyl, "Time-series extreme event forecasting with neural networks at Uber," in Time Series Workshop – International Conference on Machine Learning (ICML), 2017.
[11] B. Lim and M. van der Schaar, "Disease-Atlas: Navigating disease trajectories using deep learning," in Proceedings of the 3rd Machine Learning for Healthcare Conference (MLHC), ser. Proceedings of Machine Learning Research, vol. 85, 2018, pp. 137–160.
[12] Z. Zhang, S. Zohren, and S. Roberts, "DeepLOB: Deep convolutional neural networks for limit order books," IEEE Transactions on Signal Processing, 2019.
[13] I. Goodfellow, Y. Bengio, and A. Courville, Deep Learning. MIT Press, 2016, http://www.deeplearningbook.org.
[14] Y. Bengio, A. Courville, and P. Vincent, "Representation learning: A review and new perspectives," IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 35, no. 8, pp. 1798–1828, 2013.
[15] M. Abadi et al., "TensorFlow: Large-scale machine learning on heterogeneous systems," 2015, software available from tensorflow.org. [Online]. Available: https://www.tensorflow.org/
[16] A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison, L. Antiga, and A. Lerer, "Automatic differentiation in PyTorch," in Autodiff Workshop – Conference on Neural Information Processing Systems (NIPS), 2017.
[17] S. Makridakis, E. Spiliotis, and V. Assimakopoulos, "The M4 competition: Results, findings, conclusion and way forward," International Journal of Forecasting, vol. 34, no. 4, pp. 802–808, 2018.
[18] S. Smyl, J. Ranganathan, and A. Pasqua. (2018) M4 forecasting competition: Introducing a new hybrid ES-RNN model. [Online]. Available: https://eng.uber.com/m4-forecasting-competition/
[19] M. Binkowski, G. Marti, and P. Donnat, "Autoregressive convolutional neural networks for asynchronous time series," in Proceedings of the 35th International Conference on Machine Learning, ser. Proceedings of Machine Learning Research, vol. 80, 2018, pp. 580–589.
[20] S. S. Rangapuram, M. W. Seeger, J. Gasthaus, L. Stella, Y. Wang, and T. Januschowski, "Deep state space models for time series forecasting," in Advances in Neural Information Processing Systems 31 (NeurIPS), 2018.
[21] M. Fraccaro, S. Kamronn, U. Paquet, and O. Winther, "A disentangled recognition and nonlinear dynamics model for unsupervised learning," in Advances in Neural Information Processing Systems 30 (NIPS), 2017.
[22] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio, "Generative adversarial nets," in Advances in Neural Information Processing Systems 27 (NIPS), 2014.
[23] S. Gu, B. T. Kelly, and D. Xiu, "Empirical asset pricing via machine learning," Chicago Booth Research Paper No. 18-04; 31st Australasian Finance and Banking Conference 2018, 2017. [Online]. Available: https://ssrn.com/abstract=3159577
[24] S. Kim, "Enhancing the momentum strategy through deep regression," Quantitative Finance, vol. 0, no. 0, pp. 1–13, 2019.
[25] J. Sirignano and R. Cont, "Universal features of price formation in financial markets: Perspectives from deep learning," SSRN, 2018. [Online]. Available: https://ssrn.com/abstract=3141294
[26] S. Ghoshal and S. Roberts, "Thresholded ConvNet ensembles: Neural networks for technical forecasting," in Data Science in Fintech Workshop – Conference on Knowledge Discovery and Data Mining (KDD), 2018.
[27] W. Bao, J. Yue, and Y. Rao, "A deep learning framework for financial time series using stacked autoencoders and long-short term memory," PLOS ONE, vol. 12, no. 7, pp. 1–24, 2017.
[28] P. Barroso and P. Santa-Clara, "Momentum has its moments," Journal of Financial Economics, vol. 116, no. 1, pp. 111–120, 2015.
[29] K. Daniel and T. J. Moskowitz, "Momentum crashes," Journal of Financial Economics, vol. 122, no. 2, pp. 221–247, 2016.
[30] R. Martins and D. Zou, "Momentum strategies offer a positive point of skew," Risk Magazine, 2012.
[31] P. Jusselin, E. Lezmi, H. Malongo, C. Masselin, T. Roncalli, and T.-L. Dao, "Understanding the momentum risk premium: An in-depth journey through trend-following strategies," SSRN, 2017. [Online]. Available: https://ssrn.com/abstract=3042173
[32] M. Potters and J.-P. Bouchaud, "Trend followers lose more than they gain," Wilmott Magazine, 2016.
[33] L. M. Rotando and E. O. Thorp, "The Kelly criterion and the stock market," The American Mathematical Monthly, vol. 99, no. 10, pp. 922–931, 1992.
[34] W. F. Sharpe, "The Sharpe ratio," The Journal of Portfolio Management, vol. 21, no. 1, pp. 49–58, 1994.
[35] N. Jegadeesh and S. Titman, "Returns to buying winners and selling losers: Implications for stock market efficiency," The Journal of Finance, vol. 48, no. 1, pp. 65–91, 1993.
[36] ——, "Profitability of momentum strategies: An evaluation of alternative explanations," The Journal of Finance, vol. 56, no. 2, pp. 699–720, 2001.
[37] J. Rohrbach, S. Suremann, and J. Osterrieder, "Momentum and trend following trading strategies for currencies revisited – combining academia and industry," SSRN, 2017. [Online]. Available: https://ssrn.com/abstract=2949379
[38] Z. Zhang, S. Zohren, and S. Roberts, "BDLOB: Bayesian deep convolutional neural networks for limit order books," in Bayesian Deep Learning Workshop – Conference on Neural Information Processing Systems (NeurIPS), 2018.
[39] S. Hochreiter and J. Schmidhuber, "Long short-term memory," Neural Computation, vol. 9, no. 8, pp. 1735–1780, 1997.
[40] A. van den Oord, S. Dieleman, H. Zen, K. Simonyan, O. Vinyals, A. Graves, N. Kalchbrenner, A. W. Senior, and K. Kavukcuoglu, "WaveNet: A generative model for raw audio," CoRR, vol. abs/1609.03499, 2016.
[41] D. Silver, J. Schrittwieser, K. Simonyan, I. Antonoglou, A. Huang, A. Guez, T. Hubert, L. Baker, M. Lai, A. Bolton, Y. Chen, T. Lillicrap, F. Hui, L. Sifre, G. van den Driessche, T. Graepel, and D. Hassabis, "Mastering the game of Go without human knowledge," Nature, vol. 550, pp. 354–359, 2017.
[42] P. N. Kolm and G. Ritter, "Dynamic replication and hedging: A reinforcement learning approach," The Journal of Financial Data Science, vol. 1, no. 1, pp. 159–171, 2019.
[43] H. Bühler, L. Gonon, J. Teichmann, and B. Wood, "Deep hedging," arXiv e-prints, p. arXiv:1802.03042, 2018.
[44] D. Kingma and J. Ba, "Adam: A method for stochastic optimization," in International Conference on Learning Representations (ICLR), 2015.
[45] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov, "Dropout: A simple way to prevent neural networks from overfitting," Journal of Machine Learning Research, vol. 15, pp. 1929–1958, 2014.
[46] Y. Gal and Z. Ghahramani, "A theoretically grounded application of dropout in recurrent neural networks," in Advances in Neural Information Processing Systems 29 (NIPS), 2016.
[47] "Pinnacle Data Corp. CLC Database," https://pinnacledata2.com/clc.html.
[48] B. Lim, S. Zohren, and S. Roberts, "Recurrent neural filters: Learning independent Bayesian filtering steps for time series prediction," arXiv e-prints, p. arXiv:1901.08096, 2019.
APPENDIX

A. Dataset Details

From the full 98 ratio-adjusted continuous futures contracts in the Pinnacle Data Corp CLC Database, we extract the 88 which have less than 10% of their data missing – with a breakdown by asset class below:

1) Commodities:

Identifier Description
BC BRENT CRUDE OIL, composite
BG BRENT GASOIL, comp.
BO SOYBEAN OIL
CC COCOA
CL CRUDE OIL
CT COTTON #2
C CORN
DA MILK III, Comp.
FC FEEDER CATTLE
GC GOLD (COMMEX)
GI GOLDMAN SAKS C. I.
HG COPPER
HO HEATING OIL #2
JO ORANGE JUICE
KC COFFEE
KW WHEAT, KC
LB LUMBER
LC LIVE CATTLE
LH LIVE HOGS
MW WHEAT, MINN
NG NATURAL GAS
NR ROUGH RICE
O OATS
PA PALLADIUM
PL PLATINUM
RB RBOB GASOLINE
SB SUGAR #11
SI SILVER (COMMEX)
SM SOYBEAN MEAL
S SOYBEANS
W WHEAT, CBOT
ZA PALLADIUM, electronic
ZB RBOB, Electronic
ZC CORN, Electronic
ZF FEEDER CATTLE, Electronic
ZG GOLD, Electronic
ZH HEATING OIL, electronic
ZI SILVER, Electronic
ZK COPPER, electronic
ZL SOYBEAN OIL, Electronic
ZM SOYBEAN MEAL, Electronic
ZN NATURAL GAS, electronic
ZO OATS, Electronic
ZP PLATINUM, electronic
ZR ROUGH RICE, Electronic
ZS SOYBEANS, Electronic
ZT LIVE CATTLE, Electronic
ZU CRUDE OIL, Electronic
ZW WHEAT, Electronic
ZZ LEAN HOGS, Electronic

2) Equities:

Identifier Description
AX GERMAN DAX INDEX
CA CAC40 INDEX
EN NASDAQ, MINI
ER RUSSELL 2000, MINI
ES S & P 500, MINI
HS HANG SENG
LX FTSE 100 INDEX
MD S&P 400 (Mini electronic)
SC S & P 500, composite
SP S & P 500, day session
XU DOW JONES EUROSTOXX50
XX DOW JONES STOXX 50
YM Mini Dow Jones ($5.00)

3) Fixed Income:

Identifier Description
AP AUSTRALIAN PRICE INDEX
DT EURO BOND (BUND)
FA T-NOTE, 5yr day session
FB T-NOTE, 5yr composite
GS GILT, LONG BOND
TA T-NOTE, 10yr day session
TD T-NOTES, 2yr day session
TU T-NOTES, 2yr composite
TY T-NOTE, 10yr composite
UA T-BONDS, day session
UB EURO BOBL
US T-BONDS, composite

4) FX:

Identifier Description
AD AUSTRALIAN $$, day session
AN AUSTRALIAN $$, composite
BN BRITISH POUND, composite
CB CANADIAN 10YR BOND
CN CANADIAN $$, composite
DX US DOLLAR INDEX
FN EURO, composite
FX EURO, day session
JN JAPANESE YEN, composite
MP MEXICAN PESO
NK NIKKEI INDEX
SF SWISS FRANC, day session
SN SWISS FRANC, composite

To reduce the impact of outliers, we also winsorise the data by capping/flooring it to be within 5 times its exponentially weighted moving (EWM) standard deviation of its EWM average – computed using a 252-day half-life.
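The winsorisation step described above can be sketched with pandas. This is an illustrative sketch; applying it independently to each contract's series, and the handling of the first observations (where the EWM standard deviation is undefined), are assumptions:

```python
import pandas as pd

def winsorise(series, halflife=252, n_std=5):
    """Cap/floor a daily series to within n_std exponentially weighted
    moving (EWM) standard deviations of its EWM mean, using a 252-day
    half-life as described above."""
    ewm = series.ewm(halflife=halflife)
    mean, std = ewm.mean(), ewm.std()
    return series.clip(lower=mean - n_std * std, upper=mean + n_std * std)
```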
Exhibit 9: Hyperparameter Search Range
Hyperparameters Random Search Grid Notes
Dropout Rate 0.1, 0.2, 0.3, 0.4, 0.5 Neural Networks Only
Hidden Layer Size 5, 10, 20, 40, 80 Neural Networks Only
Minibatch Size 256, 512, 1024, 2048
Learning Rate 10^-5, 10^-4, 10^-3, 10^-2, 10^-1, 10^0
Max Gradient Norm 10^-4, 10^-3, 10^-2, 10^-1, 10^0, 10^1
L1 Regularisation Weight (α) 10^-5, 10^-4, 10^-3, 10^-2, 10^-1 Lasso Regression Only
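To make the search procedure concrete, drawing random configurations from the Exhibit 9 grid can be sketched as below. Here `evaluate` is a hypothetical callback (not part of the paper) that would train a model with the given configuration and return its validation score:

```python
import random

# Search grid from Exhibit 9 (neural network models)
GRID = {
    "dropout_rate": [0.1, 0.2, 0.3, 0.4, 0.5],
    "hidden_layer_size": [5, 10, 20, 40, 80],
    "minibatch_size": [256, 512, 1024, 2048],
    "learning_rate": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0],
    "max_gradient_norm": [1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0],
}

def random_search(evaluate, n_iter=50, seed=0):
    """Draw n_iter random configurations from the grid and keep the best
    according to `evaluate` (higher is better)."""
    rng = random.Random(seed)
    best_cfg, best_score = None, float("-inf")
    for _ in range(n_iter):
        cfg = {k: rng.choice(v) for k, v in GRID.items()}
        score = evaluate(cfg)
        if score > best_score:
            best_cfg, best_score = cfg, score
    return best_cfg, best_score
```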

B. Hyperparameter Optimisation

Hyperparameter optimisation was performed using 50 iterations of random search, with the full search grid documented in Exhibit 9, and with the models fully recalibrated every 5 years using all available data up to that point. For LSTM-based models, time series were subdivided into trajectories of 63 time steps (≈ 3 months), with the LSTM unrolled across the length of the trajectory during backpropagation.

C. Additional Results

In addition to the selected results in Section V, we also present a full list of results for completeness – echoing the key findings reported in the discussion. Detailed descriptions of the plots and tables can be found below:

1) Cross-Validation Performance: The testing procedure in Section V can also be interpreted as a cross-validation approach – splitting the original dataset into six 5-year blocks (1990-2015), calibrating using an expanding window of data, and testing out-of-sample on the next block outside the training set. As such, for consistency with the machine learning literature, we present our results in a cross-validation format as well – reporting the average value across all blocks ± 2 standard deviations. Furthermore, this also gives an indication of how signal performance varies across the various time periods.
• Exhibit 10 – Cross-validation results for raw signal outputs.
• Exhibit 11 – Cross-validation results for signals which have been rescaled to target volatility at the portfolio level.

2) Metrics Across Individual Assets: We also provide additional plots of other risk metrics and performance ratios across individual assets, as described below:
• Exhibit 12 – Full performance ratios (Sharpe, Sortino, Calmar).
• Exhibit 13 – Box plots including additional risk metrics (volatility, downside deviation and maximum drawdown). We note that our findings for the other risk metrics are similar to those for volatility – with Sharpe-optimised models lowering risk across different methods.
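The expanding-window cross-validation scheme above can be sketched as follows. The exact block boundaries, and the fact that the first block serves only as initial training data (giving out-of-sample tests over 1995-2015), are illustrative assumptions:

```python
def expanding_window_blocks(start=1990, end=2015, block_years=5):
    """Return (train_period, test_period) year ranges for an
    expanding-window evaluation: calibrate on all data up to a block
    boundary, then test out-of-sample on the next 5-year block."""
    splits = []
    for test_start in range(start + block_years, end, block_years):
        splits.append(((start, test_start), (test_start, test_start + block_years)))
    return splits
```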
Exhibit 10: Cross-Validation Performance – Raw Signal Outputs
E[Return] | Vol. | Downside Deviation | MDD
Reference
Long Only 0.043 ± 0.028 0.054 ± 0.016 0.037 ± 0.013 0.116 ± 0.091
Sgn(Returns) 0.047 ± 0.051 0.046 ± 0.012 0.032 ± 0.007 0.067 ± 0.041
MACD 0.026 ± 0.032 0.032 ± 0.008 0.023 ± 0.007 0.054 ± 0.048
Linear
Sharpe 0.034 ± 0.030 0.039 ± 0.028 0.028 ± 0.020 0.072 ± 0.096
Ave. Returns 0.033 ± 0.031 0.046 ± 0.025 0.031 ± 0.018 0.110 ± 0.114
MSE 0.047 ± 0.038 0.049 ± 0.019 0.033 ± 0.013 0.100 ± 0.121
Binary 0.012 ± 0.028 0.045 ± 0.011 0.031 ± 0.009 0.109 ± 0.045
MLP
Sharpe 0.038 ± 0.027 0.030 ± 0.041 0.021 ± 0.028 0.062 ± 0.160
Ave. Returns 0.056 ± 0.046* 0.044 ± 0.024 0.030 ± 0.017 0.075 ± 0.150
MSE 0.037 ± 0.051 0.048 ± 0.021 0.032 ± 0.015 0.109 ± 0.134
Binary -0.004 ± 0.028 0.042 ± 0.007 0.028 ± 0.006 0.111 ± 0.079
WaveNet
Sharpe 0.030 ± 0.030 0.038 ± 0.019 0.027 ± 0.015 0.069 ± 0.055
Ave. Returns 0.034 ± 0.043 0.042 ± 0.003 0.030 ± 0.003 0.088 ± 0.062
MSE 0.024 ± 0.046 0.043 ± 0.010 0.030 ± 0.010 0.102 ± 0.056
Binary -0.009 ± 0.023 0.043 ± 0.008 0.030 ± 0.008 0.159 ± 0.107
LSTM
Sharpe 0.045 ± 0.030 0.017 ± 0.004* 0.012 ± 0.003* 0.019 ± 0.005*
Ave. Returns 0.045 ± 0.050 0.048 ± 0.018 0.034 ± 0.011 0.104 ± 0.119
MSE 0.023 ± 0.037 0.048 ± 0.022 0.033 ± 0.017 0.116 ± 0.082
Binary -0.005 ± 0.088 0.042 ± 0.003 0.027 ± 0.006 0.151 ± 0.211

Sharpe | Sortino | Calmar | Fraction of +ve Returns | Ave. P / Ave. L

Reference
Long Only 0.839 ± 0.786 1.258 ± 1.262 0.420 ± 0.490 0.546 ± 0.025 0.956 ± 0.135
Sgn(Returns) 1.045 ± 1.230 1.528 ± 1.966 0.864 ± 1.539 0.543 ± 0.061 1.002 ± 0.067
MACD 0.839 ± 1.208 1.208 ± 1.817 0.625 ± 1.033 0.532 ± 0.030 1.016 ± 0.079
Linear
Sharpe 1.025 ± 1.530 1.451 ± 2.154 0.800 ± 1.772 0.544 ± 0.042 1.000 ± 0.100
Ave. Returns 0.757 ± 0.833 1.150 ± 1.378 0.397 ± 0.686 0.530 ± 0.009 1.005 ± 0.100
MSE 1.012 ± 1.126 1.532 ± 1.811 0.708 ± 1.433 0.540 ± 0.024 1.008 ± 0.096
Binary 0.288 ± 0.729 0.434 ± 1.113 0.123 ± 0.313 0.506 ± 0.027 1.024 ± 0.051
MLP
Sharpe 1.669 ± 2.332 2.420 ± 3.443 1.665 ± 2.738 0.554 ± 0.063 1.069 ± 0.151
Ave. Returns 1.415 ± 1.781 2.127 ± 2.996 1.520 ± 2.761 0.553 ± 0.043 1.022 ± 0.134
MSE 0.821 ± 1.334 1.270 ± 2.160 0.652 ± 1.684 0.525 ± 0.025 1.036 ± 0.127
Binary -0.099 ± 0.648 -0.180 ± 0.956 -0.013 ± 0.304 0.500 ± 0.042 0.986 ± 0.064
WaveNet
Sharpe 0.780 ± 0.538 1.118 ± 0.854 0.477 ± 0.610 0.535 ± 0.022 0.990 ± 0.094
Ave. Returns 0.809 ± 1.113 1.160 ± 1.615 0.501 ± 1.036 0.543 ± 0.059 0.963 ± 0.069
MSE 0.513 ± 0.991 0.744 ± 1.477 0.276 ± 0.509 0.527 ± 0.033 0.979 ± 0.077
Binary -0.220 ± 0.523 -0.329 ± 0.768 -0.043 ± 0.145 0.499 ± 0.011 0.969 ± 0.043
LSTM
Sharpe 2.781 ± 2.081* 3.978 ± 3.160* 2.488 ± 1.921* 0.593 ± 0.054* 1.104 ± 0.199*
Ave. Returns 0.961 ± 1.268 1.397 ± 1.926 0.679 ± 1.552 0.547 ± 0.039 0.972 ± 0.118
MSE 0.451 ± 0.526 0.668 ± 0.812 0.184 ± 0.170 0.520 ± 0.026 0.996 ± 0.048
Binary -0.114 ± 2.147 -0.191 ± 3.435 0.227 ± 1.241 0.495 ± 0.077 1.002 ± 0.077
Exhibit 11: Cross-Validation Performance – Rescaled to Target Volatility
E[Return] | Vol. | Downside Deviation | MDD
Reference
Long Only 0.131 ± 0.142 0.154 ± 0.001 0.104 ± 0.014 0.304 ± 0.113
Sgn(Returns) 0.186 ± 0.184 0.154 ± 0.002 0.101 ± 0.012 0.194 ± 0.126
MACD 0.140 ± 0.166 0.154 ± 0.002 0.105 ± 0.010 0.243 ± 0.129
Linear
Sharpe 0.182 ± 0.273 0.155 ± 0.003 0.105 ± 0.007 0.232 ± 0.175
Ave. Returns 0.127 ± 0.141 0.154 ± 0.003 0.101 ± 0.009 0.318 ± 0.177
MSE 0.170 ± 0.189 0.154 ± 0.003 0.099 ± 0.006* 0.256 ± 0.221
Binary 0.049 ± 0.170 0.155 ± 0.002 0.104 ± 0.013 0.351 ± 0.114
MLP
Sharpe 0.271 ± 0.375 0.154 ± 0.008 0.104 ± 0.000 0.186 ± 0.259
Ave. Returns 0.233 ± 0.270 0.154 ± 0.003 0.101 ± 0.010 0.194 ± 0.277
MSE 0.148 ± 0.178 0.154 ± 0.003 0.100 ± 0.009 0.268 ± 0.262
Binary -0.011 ± 0.117 0.154 ± 0.002 0.102 ± 0.018 0.377 ± 0.221
WaveNet
Sharpe 0.131 ± 0.103 0.154 ± 0.002 0.104 ± 0.009 0.254 ± 0.164
Ave. Returns 0.142 ± 0.196 0.154 ± 0.002 0.103 ± 0.003 0.262 ± 0.204
MSE 0.087 ± 0.150 0.153 ± 0.003* 0.101 ± 0.009 0.307 ± 0.247
Binary -0.030 ± 0.099 0.155 ± 0.001 0.105 ± 0.006 0.485 ± 0.283
LSTM
Sharpe 0.435 ± 0.342* 0.155 ± 0.002 0.108 ± 0.012 0.164 ± 0.077*
Ave. Returns 0.157 ± 0.202 0.153 ± 0.002* 0.102 ± 0.011 0.285 ± 0.196
MSE 0.087 ± 0.091 0.154 ± 0.003 0.100 ± 0.006 0.310 ± 0.130
Binary -0.008 ± 0.332 0.155 ± 0.002 0.100 ± 0.009 0.428 ± 0.495

Sharpe | Sortino | Calmar | Fraction of +ve Returns | Ave. P / Ave. L

Reference
Long Only 0.847 ± 0.915 1.287 ± 1.475 0.445 ± 0.579 0.546 ± 0.025 0.958 ± 0.164
Sgn(Returns) 1.213 ± 1.205 1.856 ± 1.944 1.098 ± 1.658 0.543 ± 0.061 1.028 ± 0.070
MACD 0.911 ± 1.086 1.361 ± 1.733 0.643 ± 0.958 0.532 ± 0.030 1.023 ± 0.074
Linear
Sharpe 1.176 ± 1.772 1.752 ± 2.615 1.060 ± 2.376 0.544 ± 0.042 1.025 ± 0.139
Ave. Returns 0.826 ± 0.914 1.287 ± 1.504 0.471 ± 0.777 0.530 ± 0.009 1.016 ± 0.116
MSE 1.101 ± 1.220 1.729 ± 2.037 0.890 ± 1.787 0.540 ± 0.024 1.022 ± 0.116
Binary 0.321 ± 1.105 0.509 ± 1.720 0.169 ± 0.585 0.506 ± 0.027 1.031 ± 0.127
MLP
Sharpe 1.757 ± 2.405 2.623 ± 3.626 2.091 ± 3.474 0.554 ± 0.063 1.085 ± 0.176
Ave. Returns 1.516 ± 1.764 2.336 ± 2.923 1.771 ± 2.889 0.553 ± 0.043 1.038 ± 0.141
MSE 0.960 ± 1.163 1.510 ± 1.927 0.864 ± 1.999 0.525 ± 0.025 1.059 ± 0.103
Binary -0.071 ± 0.756 -0.140 ± 1.133 0.000 ± 0.372 0.500 ± 0.042 0.991 ± 0.072
WaveNet
Sharpe 0.849 ± 0.663 1.270 ± 1.060 0.575 ± 0.680 0.535 ± 0.022 1.000 ± 0.121
Ave. Returns 0.920 ± 1.271 1.376 ± 1.915 0.738 ± 1.591 0.543 ± 0.059 0.979 ± 0.066
MSE 0.565 ± 0.972 0.854 ± 1.513 0.364 ± 0.665 0.527 ± 0.033 0.986 ± 0.081
Binary -0.196 ± 0.641 -0.298 ± 0.946 -0.044 ± 0.206 0.499 ± 0.011 0.974 ± 0.067
LSTM
Sharpe 2.803 ± 2.195* 4.084 ± 3.469* 2.887 ± 3.030* 0.593 ± 0.054* 1.106 ± 0.216*
Ave. Returns 1.023 ± 1.312 1.564 ± 2.131 0.706 ± 1.440 0.547 ± 0.039 0.980 ± 0.127
MSE 0.563 ± 0.580 0.865 ± 0.901 0.284 ± 0.269 0.520 ± 0.026 1.014 ± 0.016
Binary -0.050 ± 2.152 -0.122 ± 3.381 0.190 ± 1.152 0.495 ± 0.077 1.012 ± 0.048

Exhibit 12: Performance Ratios Across Individual Assets: (a) Sharpe Ratio, (b) Sortino Ratio, (c) Calmar Ratio

Exhibit 13: Reward vs Risk Across Individual Assets: (a) Expected Returns, (b) Volatility, (c) Downside Deviation, (d) Max. Drawdown


Chapter 5

Learning Time-Dependent Causal Effects

Publications Included
• B. Lim, A. Alaa, M. van der Schaar. Forecasting Treatment Responses Over Time Using
Recurrent Marginal Structural Networks. Advances in Neural Information Processing
Systems (NeurIPS), 2018. Weblink: https://papers.nips.cc/paper/
7977-forecasting-treatment-responses-over-time-using-recurrent-marginal-
structural-networks.

Forecasting Treatment Responses Over Time Using
Recurrent Marginal Structural Networks

Bryan Lim
Department of Engineering Science
University of Oxford
[email protected]

Ahmed Alaa
Electrical Engineering Department
University of California, Los Angeles
[email protected]

Mihaela van der Schaar


University of Oxford
and The Alan Turing Institute
[email protected]

Abstract
Electronic health records provide a rich source of data for machine learning meth-
ods to learn dynamic treatment responses over time. However, any direct estimation
is hampered by the presence of time-dependent confounding, where actions taken
are dependent on time-varying variables related to the outcome of interest. Drawing
inspiration from marginal structural models, a class of methods in epidemiology
which use propensity weighting to adjust for time-dependent confounders, we
introduce the Recurrent Marginal Structural Network - a sequence-to-sequence
architecture for forecasting a patient’s expected response to a series of planned treat-
ments. Using simulations of a state-of-the-art pharmacokinetic-pharmacodynamic
(PK-PD) model of tumor growth [12], we demonstrate the ability of our network
to accurately learn unbiased treatment responses from observational data – even
under changes in the policy of treatment assignments – and performance gains over
benchmarks.

1 Introduction
With the increasing prevalence of electronic health records, there has been much interest in the use
of machine learning to estimate treatment effects directly from observational data [13, 41, 44, 2].
These records, collected over time as part of regular follow-ups, provide a more cost-effective method
to gather insights on the effectiveness of past treatment regimens. While the majority of previous
work focuses on the effects of interventions at a single point in time, observational data also captures
information on complex time-dependent treatment scenarios, such as where the efficacy of treatments
changes over time (e.g. drug resistance in cancer patients [40]), or where patients receive multiple
interventions administered at different points in time (e.g. joint prescriptions of chemotherapy and
radiotherapy [12]). As such, the ability to accurately estimate treatment effects over time would allow
doctors to determine both the treatments to prescribe and the optimal time at which to administer
them.
However, straightforward estimation in observational studies is hampered by the presence of time-
dependent confounders, arising in cases where interventions are contingent on biomarkers whose
values are affected by past treatments. For example, asthma rescue drugs provide short-term rapid
improvements to lung function measures, but are usually prescribed to patients with reduced lung
function scores. As such, naïve methods can lead to the incorrect conclusion that the medication
reduces lung function scores, contrary to the actual treatment effect [26]. Furthermore, [23] show

32nd Conference on Neural Information Processing Systems (NeurIPS 2018), Montréal, Canada.
that the standard adjustments for causal inference, e.g. stratification, matching and propensity scoring
[16], can introduce bias into the estimation in the presence of time-dependent confounding.
Marginal structural models (MSMs) are a class of methods commonly used in epidemiology to
estimate time-dependent effects of exposure while adjusting for time-dependent confounders [15, 24,
19, 14]. Using the probability of a treatment assignment, conditioned on past exposures and covariate
history, MSMs typically adopt inverse probability of treatment weighting (IPTW) to correct for bias
in standard regression methods [22], re-constructing a ‘pseudo-population’ from the observational
dataset similar to that of a randomized clinical trial. However, the effectiveness of bias correction
is dependent on a correct specification of the conditional probability of treatment assignment, which
is difficult to do in practice given the complexity of treatment planning. In standard MSMs, IPTWs
are produced using pooled logistic regression, which makes strong assumptions on the form of the
conditional probability distribution. This also requires a separate set of coefficients to be estimated
per time-step, and hence many models to be estimated for long trajectories.
In this paper, we propose a new deep learning model - which we refer to as Recurrent Marginal
Structural Networks - to directly learn time-dependent treatment responses from observational data,
based on the marginal structural modeling framework. Our key contributions are as follows:
Multi-step Prediction Using Sequence-to-sequence Architecture To forecast treatment re-
sponses at multiple time horizons in the future, we propose a new RNN architecture for multi-step
prediction based on sequence-to-sequence architectures in natural language processing [36]. This
comprises two halves, 1) an encoder RNN which learns representations for the patient’s current clini-
cal state, and 2) a decoder which is initialized using the encoder’s final memory state and computes
forward predictions given the intended treatment assignments. At run time, the R-MSN also allows
for prediction horizons to be flexibly adjusted to match the intended treatment duration, by expanding
or contracting the number of decoder units in the sequence-to-sequence model.
Scenario Analysis for Complex Treatment Regimens Treatment planning in clinical settings is
often based on the interaction of numerous variables - including 1) the desired outcomes for a patient
(e.g. survival improvement or comorbidity risk reduction), 2) the treatments to assign (e.g. binary
interventions or continuous dosages), and 3) the length of treatment affected by both number and
duration of interventions. The R-MSN naturally encapsulates this by using multi-input/output RNNs,
which can be configured to have multiple treatments and targets of different forms (e.g. continuous
or discrete). Different sequences of treatments can also be evaluated using the sequence-to-sequence
architecture of the network. Moreover, given the susceptibility of IPTWs to model misspecification,
the R-MSN uses Long-short Term Memory units (LSTMs) to compute the probabilities required
for propensity weighting. Combining these aspects together, the R-MSN is able to help clinicians
evaluate the projected outcome of a complex treatment scenario – providing timely clinical decision
support and helping them customize a treatment regimen to the patient. An example of scenario
analysis for different cancer treatment regimens is shown in Figure 1, with the expected response of
tumor growth to no treatment, chemotherapy and radiotherapy shown.

Figure 1: Forecasting Tumor Growth Under Multiple Treatment Scenarios

2 Related Works

Given the diversity of literature on causal inference, we focus on works associated with time-
dependent treatment responses and deep learning here, with a wider survey in Appendix A.

G-computation and Structural Models. Counterfactual inference under time-dependent con-
founding has been extensively studied in the epidemiology literature, particularly in the seminal
works of Robins [30, 31, 16]. Methods in this area can be categorized into 3 groups: models based
on the G-computation formula, structural nested mean models, and marginal structural models [8].
While all these models provide strong theoretical foundations on the adjustments for time-dependent
confounding, their prediction models are typically based on linear or logistic regression. These
models would be misspecified when either the outcomes or the treatment policy exhibit complex
dependencies on the covariate history.
Potential Outcomes with Longitudinal Data. Bayesian nonparametric models have been pro-
posed to estimate the effects of both single [32, 33, 43, 34] and joint treatment assignments [35]
over time. These methods use Gaussian processes (GPs) to model the baseline progression, which
can estimate the treatment effects at multiple points in the future. However, some limitations do
exist. Firstly, to aid in calibration, most Bayesian methods make strong assumptions on model
structure - such as 1) independent baseline progression and treatment response components [43, 35],
and 2) the lack of heterogeneous effects, by either omitting baseline covariates (e.g. genetic or
demographic information) [34, 33] or incorporating them as linear components [43, 35]. Recurrent
neural networks (RNNs) avoid the need for any explicit model specifications, with the networks
learning these relationships directly from the data. Secondly, inference with Bayesian models can
be computationally complex, making them difficult to scale. This arises from the use of Markov
Chain-Monte Carlo sampling for g-computation, and the use of sparse GPs that have at least O(N M²)
complexity, where N and M are the number of observations and inducing points respectively [39].
From this perspective, RNNs have the benefit of scalability and update their internal states with new
observations as they arrive. Lastly, apart from [35] which we evaluate in Section 5, existing models
do not consider treatment responses for combined interventions and multiple targets. This is handled
naturally in our network by using multi-input/multi-output RNN architectures.
Deep Learning for Causal Inference. Deep learning has also been used to estimate individualized
treatment effects for a single intervention at a fixed time, using instrumental variable approaches [13],
generative adversarial networks [44] and multi-task architectures [3]. To the best of our knowledge,
ours is the first deep learning method for time-dependent effects and establishes a framework to use
existing RNN architectures for treatment response estimation.

3 Problem Definition

Let Yt,i = [Yt,i (1), . . . , Yt,i (Ωy )] be a vector of Ωy observed outcomes for patient i at time t, At,i =
[At,i (1), . . . , At,i (Ωa )] a vector of the actual treatments administered, Lt,i = [Lt,i (1), . . . , Lt,i (Ωl )] time-
dependent covariates and Xi = [Xi (1), . . . , Xi (Ωv )] patient-specific static features. For notational
simplicity, we will omit the subscript i going forward unless explicitly required.
Treatment Responses Over Time Determining an individual’s response to a prescribed treatment
can be characterized as learning a function g(.) for the expected outcomes over a prediction horizon
τ, given an intended course of treatment and past observations, i.e.:

$$\mathbb{E}\big[\, Y_{t+\tau} \mid \mathbf{a}(t, \tau - 1), \bar{H}_t \,\big] = g\big(\tau, \mathbf{a}(t, \tau - 1), \bar{H}_t\big) \qquad (1)$$

where g(.) represents a generic, possibly non-linear, function, a(t, τ − 1) = (at , . . . at+τ −1 ) is an
intended sequence of treatments ak from the current time until just before the outcome is observed,
and H̄t = (L̄t , Āt−1 , X) is the patient’s history with covariates L̄t = (L1 , . . . , Lt ) and actions
Āt−1 = (A1 , . . . At−1 ).
Inverse Probability of Treatment Weighting Inverse probability of treatment weighting has been
extensively studied in marginal structural modeling as an adjustment for time-dependent confounding
[22, 16, 15, 24, 26], with extensions to joint treatment assignments [19], censored observations
[14] and continuous dosages [10]. We list the key results for our problem below, with a more
thorough discussion in Appendix B.
The stabilized weights for joint treatment assignments [21] can be expressed as:

$$\mathrm{SW}(t, \tau) = \prod_{n=t}^{t+\tau} \frac{f(A_n \mid \bar{A}_{n-1})}{f(A_n \mid \bar{H}_n)} = \prod_{n=t}^{t+\tau} \frac{\prod_{k=1}^{\Omega_a} f(A_n(k) \mid \bar{A}_{n-1})}{\prod_{k=1}^{\Omega_a} f(A_n(k) \mid \bar{H}_n)} \qquad (2)$$

where f (.) is the probability mass function for discrete treatment applications, or the probability
density function when continuous dosages are used [10]. We also note that H̄n contains both past
treatments Ān−1 and potential confounders L̄n . To account for censoring, we used the additional
stabilized weights below:
$$\mathrm{SW}^*(t, \tau) = \prod_{n=t}^{t+\tau} \frac{f(C_n = 0 \mid T > n, \bar{A}_{n-1})}{f(C_n = 0 \mid T > n, \bar{L}_{n-1}, \bar{A}_{n-1}, X)} \qquad (3)$$
where Cn = 1 denotes right censoring of the trajectory, and T is the time at which censoring occurs.
We also adopt the additional steps for stabilization proposed in [42], truncating stabilized weights at
their 1st and 99th percentile values, and normalizing weights by their mean for a fixed prediction
horizon, i.e. $\widetilde{\mathrm{SW}}_i(t, \tau) = \mathrm{SW}_i(t, \tau) \big/ \big( \sum_{i=1}^{I} \sum_{t=1}^{T_i} \mathrm{SW}_i(t, \tau) / N \big)$, where $I$ is the total number of
patients, $T_i$ is the length of the patient's trajectory and $N$ the total number of observations. Stabilized
weights are then used to weight the loss contributions of each training observation, expressed in
squared-error terms below for continuous predictions:

$$e(i, t, \tau) = \widetilde{\mathrm{SW}}_i(t, \tau - 1) \times \widetilde{\mathrm{SW}}^{*}_i(t, \tau - 1) \times \big\| Y_{t+\tau, i} - g\big(\tau, \mathbf{a}(t, \tau - 1), \bar{H}_t\big) \big\|^2 \qquad (4)$$
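As a concrete illustration of the truncation and normalization steps and of Equation 4, the NumPy sketch below implements them; the lognormal toy weights and the array shapes are illustrative assumptions, not part of the original method:

```python
import numpy as np

def stabilize(sw, lower=1.0, upper=99.0):
    """Truncate stabilized weights at their 1st/99th percentile values,
    then normalize by their mean (the steps adopted from [42])."""
    lo, hi = np.percentile(sw, [lower, upper])
    clipped = np.clip(sw, lo, hi)
    return clipped / clipped.mean()

def weighted_error(y_true, y_pred, sw_treat, sw_cens):
    """Per-observation loss contribution e(i, t, tau) of Equation 4:
    the squared error scaled by treatment and censoring weights."""
    sq_err = np.sum((y_true - y_pred) ** 2, axis=-1)
    return sw_treat * sw_cens * sq_err

# Toy illustration: 1,000 synthetic observations with lognormal raw weights.
rng = np.random.default_rng(0)
sw_treat = stabilize(rng.lognormal(sigma=1.5, size=1000))
sw_cens = stabilize(rng.lognormal(sigma=0.5, size=1000))
y, y_hat = rng.normal(size=(1000, 1)), rng.normal(size=(1000, 1))
losses = weighted_error(y, y_hat, sw_treat, sw_cens)
```

Truncation bounds the influence of extreme propensity ratios at the cost of a small bias, while mean normalization keeps the effective scale of the loss comparable across prediction horizons.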

4 Recurrent Marginal Structural Networks


An MSM can be subdivided into two submodels, one modeling the IPTWs and the other estimating
the treatment response itself. Adopting this framework, we use two sets of deep neural networks to
build a Recurrent Marginal Structural Network (R-MSN) - 1) a set of propensity networks to compute
treatment probabilities used for IPTW, and 2) a prediction network used to determine the treatment
response for a given set of planned interventions. Additional details on the algorithm can be found in
Appendix E, with the source code uploaded onto GitHub1 .

4.1 Propensity Networks


From Equations 2 and 3, we can see that 4 key probability functions are required to calculate the
stabilized weights. In all instances, probabilities are conditioned on the history of past observations
(Ān−1 and H̄n ), making RNNs natural candidates to learn these functions.
Each probability function is parameterized with a different LSTM – collectively referred to as
propensity networks – with action probabilities f(An | · ) generated jointly by a set of multi-target
LSTMs and censoring probabilities f(Cn = 0 | · ) by single-output LSTMs. This also accounts
for possible correlations between treatment assignments, for instance in treatment regimens where
complementary drugs are prescribed together to combat different aspects of the same disease.
The flexibility of RNN architectures also allows for the modeling of treatment assignments with
different forms. In simple cases with discrete treatment assignments, a standard LSTM with a sigmoid
output layer can be used for binary treatment probabilities or a softmax layer for categorical ones.
More complex architectures, such as variational RNNs [6], can be used to compute probabilities
when treatments map to continuous dosages. To calculate the binary probabilities in the experiments
in Section 5, LSTMs were fitted with tanh state activations and sigmoid outputs.
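To make the weighting concrete, the hedged sketch below shows how the per-step stabilized weight of Equation 2 would be assembled from the sigmoid outputs of the two sets of propensity networks for binary joint treatments; the array names and probability values are hypothetical:

```python
import numpy as np

def per_step_sw(a, p_num, p_den):
    """SW(n, 0) for binary joint treatments (Equation 2, single step):
    a     -- (T, K) observed assignments A_n(k) in {0, 1}
    p_num -- (T, K) numerator-network outputs for f(A_n(k) = 1 | past actions)
    p_den -- (T, K) denominator-network outputs for f(A_n(k) = 1 | full history)
    Returns the length-T vector of per-step stabilized weights."""
    lik_num = np.where(a == 1, p_num, 1.0 - p_num)  # Bernoulli likelihood of the
    lik_den = np.where(a == 1, p_den, 1.0 - p_den)  # action actually taken
    return np.prod(lik_num / lik_den, axis=-1)      # product over the K treatments

# One time step, two treatments (e.g. chemo- and radiotherapy).
a = np.array([[1, 0]])
p_num = np.array([[0.5, 0.5]])
p_den = np.array([[0.8, 0.4]])
sw = per_step_sw(a, p_num, p_den)
```

In this toy case the weight is (0.5/0.8) × (0.5/0.6) for an observation that received the first treatment but not the second.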

4.2 Prediction Network


The prediction network focuses on forecasting the treatment response of a patient, with time-
dependent confounding accounted for using IPTWs from the propensity networks. Although standard
RNNs can be used for one-step-ahead forecasts, actual treatments plans can be considerably more
complex, with varying durations and number of interventions depending on the condition of the
patient. To remove any restrictions on the prediction horizon or number of planned interventions, we
propose the sequence-to-sequence architecture depicted in Figure 2. One key difference between
our model and standard sequence-to-sequence models (e.g. [36]) is that the last unit of the encoder is also used
in making predictions for the first time step, in addition to the decoder units at further horizons. This
allows the R-MSN to use all available information in making predictions, including the covariates
1 https://github.com/sjblim/rmsn_nips_2018

Figure 2: R-MSN Architecture for Multi-step Treatment Response Prediction

available at the current time step t. For the continuous predictions in Section 5, we used Exponential
Linear Unit (ELU [7]) state activations and a linear output layer.
Encoder The goal of the encoder is to learn good representations for the patient’s current clinical
state, and we do so with a standard LSTM that makes one-step-ahead predictions of the outcome
(Ŷt+1 ) given observations of covariates and actual treatments. At the current follow-up time t, the
encoder is also used in forecasting the expected response at t + 1, as the latest covariate measurements
Lt are available to be fed into the LSTM along with the first planned treatment assignment.
Decoder While multi-step prediction can be performed by recursively feeding outputs into the
inputs at the next time step, this would require output predictions for all covariates, with a high
degree of accuracy to reduce error propagation through the network. Given that often only a small
subset of treatment outcomes are of interest, it would be desirable to forecast treatment responses on the
basis of planned future actions alone. As such, the purpose of the decoder is to propagate the encoder
representation forwards in time - using only the proposed treatment assignments and avoiding the
need to forecast input covariates. This is achieved by training another LSTM that accepts only actions
as inputs, but initializing the internal memory state of the first LSTM in the decoder sequence (zt )
using encoder representations. To allow for different state sizes in the encoder and decoder, encoder
internal states (ht ) are passed through a single network layer with ELU activations, i.e. the memory
adapter, before initializing the decoder. As the network is made up of LSTM units, the internal
states here refer to the concatenation of the cell and hidden states [17] of the LSTM.
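A minimal NumPy sketch of the memory adapter follows; the layer sizes and random parameters are illustrative assumptions (in the R-MSN the adapter weights are learned jointly with the decoder):

```python
import numpy as np

def elu(x):
    """Exponential Linear Unit activation."""
    return np.where(x > 0, x, np.exp(x) - 1.0)

def memory_adapter(cell, hidden, W, b):
    """Map the encoder's concatenated [cell; hidden] LSTM state through a
    single ELU layer, then split the result to initialize the decoder's
    (possibly differently sized) cell and hidden states."""
    state = np.concatenate([cell, hidden], axis=-1)
    z = elu(state @ W + b)
    return np.split(z, 2, axis=-1)  # (decoder cell state, decoder hidden state)

# Encoder state size 8 (cell + hidden = 16), decoder state size 6 (cell + hidden = 12).
rng = np.random.default_rng(1)
W, b = rng.normal(size=(16, 12)), np.zeros(12)
dec_cell, dec_hidden = memory_adapter(rng.normal(size=8), rng.normal(size=8), W, b)
```

The adapter is what lets the encoder and decoder use different state sizes, since the decoder's LSTM cannot otherwise consume the encoder's state directly.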

4.3 Training Procedure

The training procedure for R-MSNs can be subdivided into the 3 training steps shown in Figure 3 -
starting with the propensity networks, followed by the encoder, and ending with the decoder.

(a) Step 1: Propensity Network Training (b) Step 2: Encoder Training

(c) Step 3: Decoder Training

Figure 3: Training Procedure for R-MSNs

Step 1: Propensity Network Training From Figure 3(a), each propensity network is first trained
to estimate the probability of the treatment assigned at each time step, which is combined to compute
SW(t, 0) and SW∗(t, 0) at each time step. Stabilized weights for longer horizons can then be
obtained from their cumulative product, i.e. $\mathrm{SW}(t, \tau) = \prod_{j=0}^{\tau} \mathrm{SW}(t + j, 0)$. For tests in Section 5,
propensity networks were trained using standard binary cross entropy loss, with treatment assignments
and censoring treated as binary observations.
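The cumulative-product step above can be sketched directly (0-indexed here, with hypothetical per-step weights):

```python
import numpy as np

def sw_multi_horizon(sw_one_step, t, tau):
    """SW(t, tau) = prod_{j=0}^{tau} SW(t + j, 0): multi-step stabilized
    weights as the running product of the per-step weights from Step 1."""
    return np.prod(sw_one_step[t : t + tau + 1])

sw_one_step = np.array([2.0, 0.5, 3.0, 1.0])  # hypothetical SW(t, 0) values
```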
Step 2: Encoder Training Next, encoder and decoder training were divided into separate steps -
accelerating learning by first training the encoder to learn representations of the patient’s clinical
state and then using the decoder to extrapolate them according to the intended treatment plan. As
such, the encoder was trained to forecast standard one-step-ahead treatment response according to the
structure in Figure 3(b), using all available information on treatments and covariates until the current
time step. Upon completion, the encoder was used to perform a feed-forward pass over the training
and validation data, extracting the internal states ht for the final training step. As tests in Section 5
were performed for continuous outcomes, we express the loss function for the encoder as a weighted
mean-squared error loss (Lencoder in Equation 5), although we note that this approach is compatible
with other loss functions, e.g. cross entropy for discrete outcomes.
Step 3: Decoder Training Finally, the decoder and memory adapter were trained together
based on the format in Figure 3(c). For a given patient, observations were batched into shorter
sequences of up to τmax steps, such that each sequence commencing at time t is made up of
[ht , {At+1 , . . . , At+τmax −1 }, {Yt+2 , . . . , Yt+τmax }]. These were compiled for all patient-times
and randomly grouped into minibatches to be used for backpropagation through time. For continuous
predictions, the loss function for the decoder (Ldecoder ) can also be found in Equation 5.
$$\mathcal{L}_{\mathrm{encoder}} = \sum_{i=1}^{I} \sum_{t=1}^{T_i} e(i, t, 1) \qquad\qquad \mathcal{L}_{\mathrm{decoder}} = \sum_{i=1}^{I} \sum_{t=1}^{T_i} \sum_{\tau=2}^{\min(T_i - t,\, \tau_{\max})} e(i, t, \tau) \qquad (5)$$
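The Step-3 batching can be sketched as below; the 0-indexed slicing and the array sizes are illustrative assumptions (the paper's notation is 1-indexed), and the real implementation would also carry the stabilized weights alongside each sequence:

```python
import numpy as np

def decoder_sequences(h, a, y, tau_max):
    """Batch one patient's trajectory for decoder training (Step 3):
    each sequence starting at t pairs the encoder state h_t with up to
    tau_max - 1 future actions and the outcomes they lead to, i.e. a
    0-indexed sketch of [h_t, {A_{t+1}, ...}, {Y_{t+2}, ...}]."""
    seqs = []
    for t in range(len(y) - 2):
        end = min(t + tau_max, len(y) - 1)
        seqs.append((h[t], a[t + 1 : end], y[t + 2 : end + 1]))
    return seqs

# Illustrative sizes: 6 time steps, state size 3, 2 treatments, 1 outcome.
T, tau_max = 6, 4
h = np.zeros((T, 3))   # extracted encoder states
a = np.zeros((T, 2))   # planned treatments
y = np.zeros((T, 1))   # observed outcomes
seqs = decoder_sequences(h, a, y, tau_max)
```

Each tuple pairs the extracted encoder state with the future actions the decoder conditions on and the outcomes it is trained to predict.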
i=1 t=1 i=1 t=1 τ =2

5 Experiments With Cancer Growth Simulation Model

5.1 Simulation Details


As confounding effects in real-world datasets are unknown a priori, methods for treatment response
estimation are often evaluated using data simulations, where treatment application policies are
explicitly modeled [34, 33, 35]. To ensure that our tests are fully reproducible and realistic from
a medical perspective, we adopt the pharmacokinetic-pharmacodynamic (PK-PD) model of [12]
- the state-of-the-art in treatment response modeling for non-small cell lung cancer patients. The model
features key characteristics present in actual lung cancer treatments, such as combined effects of
chemo- and radiotherapy, cell repopulation after treatment, death/recovery of patients, and different
starting distributions of tumor sizes based on the stage of cancer at diagnosis. On the whole, PK-
PD models allow clinicians to explore hypotheses around dose-response relationships and propose
optimal treatment schedules [5, 29, 11, 9, 1]. While we refer readers to [12] for the finer details of the
model, such as specific priors used, we examine the overall structure of the model below to illustrate
treatment-response relationships and how time-dependent confounding is introduced.
PK-PD Model for Tumor Dynamics We use a discrete-time model for tumor volume V (t), where
t is the number of days since diagnosis:

$$V(t) = \Big( 1 + \underbrace{\rho \log\!\big(\tfrac{K}{V(t-1)}\big)}_{\text{Tumor Growth}} - \underbrace{\beta_c C(t)}_{\text{Chemotherapy}} - \underbrace{\big(\alpha d(t) + \beta d(t)^2\big)}_{\text{Radiation}} + \underbrace{e_t}_{\text{Noise}} \Big)\, V(t-1) \qquad (6)$$

where ρ, K, βc , α, β are model parameters sampled for each patient according to prior distributions
in [12]. A Gaussian noise term et ∼ N (0, 0.012 ) was added to account for randomness in the growth
of the tumor. d(t) is the dose of radiation applied at t, while drug concentration C(t) is modeled
according to an exponential decay with a half life of 1 day, i.e.:
C(t) = C̃(t) + C(t − 1)/2 (7)
where C̃(t) is a new continuous dose of chemotherapy drugs applied at time t. To account for
heterogeneous effects, we added static features to the simulation model by randomly subclassing

6
patients into 3 different groups, with each patient having a group label Si ∈ {1, 2, 3}. This represents
specific characteristics which affect the patient’s response to chemotherapy and radiotherapy (e.g.
due to genetic factors [4]), which augment the prior means of βc and α according to:
$$\mu'_{\beta_c}(i) = \begin{cases} 1.1\,\mu_{\beta_c}, & \text{if } S_i = 3 \\ \mu_{\beta_c}, & \text{otherwise} \end{cases} \qquad\quad \mu'_{\alpha}(i) = \begin{cases} 1.1\,\mu_{\alpha}, & \text{if } S_i = 1 \\ \mu_{\alpha}, & \text{otherwise} \end{cases} \qquad (8)$$

where µ∗ are the mean parameters of [12], and µ′∗ (i) those used to simulate patient i. We note that
the value of β is set in relation to α, i.e. α/β = 10, and would also be adjusted accordingly by Si .
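A hedged sketch of one simulated path under Equations 6 and 7 is below; all parameter values are placeholders rather than the priors of [12], and the censoring mechanisms are omitted:

```python
import numpy as np

def simulate_tumor(v0, rho, K, beta_c, alpha, beta, chemo, radio, rng):
    """One sample path of the discrete-time PK-PD model. Each day applies
    Equation 7 (drug concentration with a 1-day half life) and Equation 6
    (growth, chemo, radiation and noise terms, all multiplying V(t-1))."""
    v, conc = [v0], 0.0
    for t in range(1, len(chemo)):
        conc = chemo[t] + conc / 2.0                 # Equation 7
        d = radio[t]
        growth = rho * np.log(K / v[-1])             # tumor growth term
        noise = rng.normal(0.0, 0.01)                # e_t ~ N(0, 0.01^2)
        v.append((1.0 + growth - beta_c * conc
                  - (alpha * d + beta * d ** 2) + noise) * v[-1])
    return np.array(v)

# Placeholder parameters (with alpha / beta = 10, as noted above), no treatment.
rng = np.random.default_rng(2)
days = 60
path = simulate_tumor(v0=1.0, rho=0.005, K=30.0, beta_c=0.03,
                      alpha=0.1, beta=0.01, chemo=np.zeros(days),
                      radio=np.zeros(days), rng=rng)
```

Without treatment the multiplicative growth term dominates the small noise, so the simulated volume drifts upwards toward the carrying capacity K.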
Censoring Mechanisms Patient censoring is incorporated by modeling 1) death when tumor
diameters reach Dmax = 13 cm (or a volume of Vmax = 1150 cm3 assuming perfectly spherical
tumors), 2) recovery determined by a Bernoulli process with recovery probability pt = exp(−Vt ),
and 3) termination of observations after 60 days (administrative censoring).
Treatment Assignment Policy To introduce time-dependent confounders, we assume that
chemotherapy prescriptions Ac (t) ∈ {0, 1} and radiotherapy prescriptions Ad (t) ∈ {0, 1} are
Bernoulli random variables, with probabilities pc (t) and pd (t) respectively that are functions of the
tumor diameter:
$$p_c(t) = \sigma\!\left( \frac{\gamma_c}{D_{\max}} \big( \bar{D}(t) - \theta_c \big) \right) \qquad\quad p_d(t) = \sigma\!\left( \frac{\gamma_d}{D_{\max}} \big( \bar{D}(t) - \theta_d \big) \right) \qquad (9)$$
where D̄(t) is the average tumor diameter over the last 15 days, σ(.) is the sigmoid activation function,
and θ∗ and γ∗ are constant parameters. θ∗ is fixed such that θc = θd = Dmax /2, giving the model
a 0.5 probability of treatment application when the tumor is half its maximum size. When
treatments are applied, i.e. Ac (t) or Ad (t) is 1, chemotherapy is assumed to be administered in
5.0 mg/m3 doses of Vinblastine, and radiotherapy in 2.0 Gy fractions. γ∗ also controls the degree of
time-dependent confounding - starting with no confounding at γ∗ = 0, as treatment assignments are
independent of the response variable, and increasing as γ∗ becomes larger.
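The assignment policy of Equation 9 is a one-liner; the sketch below also checks the two properties noted above (probability 0.5 at half the maximum diameter, and no dependence on tumor size when γ = 0):

```python
import numpy as np

def assignment_prob(d_bar, gamma, theta, d_max=13.0):
    """Equation 9: treatment probability as a sigmoid of the average tumor
    diameter over the last 15 days; gamma sets the degree of confounding."""
    return 1.0 / (1.0 + np.exp(-(gamma / d_max) * (d_bar - theta)))

theta = 13.0 / 2.0  # theta_c = theta_d = D_max / 2
```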

5.2 Benchmarks

We evaluate the performance of R-MSNs against MSMs and Bayesian nonparametric models,
focusing on its effectiveness in estimating unbiased treatment responses and its multi-step prediction
performance. An overview of the models tested is summarized below:
Standard Marginal Structural Models (MSM) For the MSMs used in our investigations, we
adopt similar approximations to [19, 14], encoding historical actions via the cumulative sum of applied
treatments, e.g. $\mathrm{cum}(\bar{a}_c(t-1)) = \sum_{k=1}^{t-1} a_c(k)$, and covariate history using the previous observed
value V (t − 1). The exact forms of the propensity and prediction models are in Appendix D.
Bayesian Treatment Response Curves (BTRC) We also benchmark our performance against the
model of [35] - the state-of-the-art in forecasting multistep treatment responses for joint therapies
with multiple outcomes. Given that the simulation model only has one target outcome, we also
consider a simpler variant of the model without “shared” components, denoting this as the reduced
BTRC (R-BTRC) model. This reduced parametrization was found to improve convergence during
training, and additional details on calibration can be found in Appendix G.
Recurrent Marginal Structural Networks (R-MSN) R-MSNs were designed according to the
description in Section 4, with full details on training and hyperparameters in Appendix F. To evaluate
the effectiveness of the propensity networks, we also trained prediction networks using the IPTWs
from the MSM, including this as an additional benchmark in Section 5.3 (Seq2Seq + Logistic).

5.3 Performance Evaluations

Time-Dependent Confounding Adjustments To investigate how well models learn unbiased
treatment responses from observational data, we trained all models on simulations with γc = γd = 10
(biased policy) and examine the root-mean-squared errors (RMSEs) of one-step-ahead predictions as
γ∗ is reduced. Both γ∗ parameters were set to be equal in this section for simplicity, i.e. γc = γd = γ.
Using the simulation model in Section 5.1, we simulated 10,000 paths to be used for model training,
1,000 for validation data used in hyperparameter optimization, and another 1,000 for out-of-sample

Figure 4: Normalized RMSEs for One-Step-Ahead Predictions

testing. For linear and MSM models, which do not have hyperparameters to optimize, we combined
both training and validation datasets for model calibration.
Figure 4 shows the RMSE values of various models at different values of γ, with RMSEs normalized
with Vmax and reported in percentage terms. Here, we focus on the main comparisons of interest
– 1) linear models to provide a baseline on performance, 2) linear vs MSMs to evaluate traditional
methods for IPTWs, 3) Seq2Seq + logistic IPTWs vs MSMs for the benefits of the Seq2Seq model,
4) R-MSN vs Seq2Seq + logistic to determine the improvements of our model and RNN-estimated
IPTWs, and 5) BTRC/R-BTRC to benchmark against state-of-the-art methods. Additional results are
also documented in Appendix C for reference.
From the graph, R-MSNs displayed the lowest RMSEs across all values of γ, decreasing slightly
from a normalized RMSE of 1.02% at γ = 10 to 0.92% at γ = 0. Focusing on RMSEs at γ = 0,
R-MSNs improve on MSMs by 80.9% and R-BTRCs by 66.1%, demonstrating their effectiveness in learning
unbiased treatment responses from confounded data. The propensity networks also improve unbiased
treatment estimates by 78.7% (R-MSN vs. Seq2Seq + Logistic), indicating the benefits of more
flexible models for IPTW estimation. While the IPTWs of MSMs do provide small gains for linear
models, linear models still exhibit the largest unbiased RMSE across all benchmarks - highlighting
the limitations of linear models in estimating complex treatment responses. Bayesian models also
perform consistently across γ, with normalized RMSEs for R-BTRC decreasing from 2.09% at γ = 10
to 1.91% at γ = 0, but were also observed to slightly underperform linear models on the training
data itself. Part of this can potentially be attributed to model misspecification in the BTRC, which
assumes that treatment responses are linear, time-invariant and independent of the baseline progression.
The differences in modeling assumptions can be seen from Equation 6, where chemotherapy and
radiotherapy contributions are modeled as multiplicative with V (t). This highlights the benefits of
the data-driven nature of the R-MSN, which can flexibly learn treatment response models of different
types.
Multi-step Prediction Performance To evaluate the benefits of the sequence-to-sequence archi-
tecture, we report the normalized RMSEs for multi-step prediction in Table 1, using the best model of
each category (R-MSN, MSM and R-BTRC). Once again, the R-MSN outperforms benchmarks for
all timesteps, beating MSMs by 61% on the training policy and 95% for the unbiased one. While the
R-BTRC does show improvements over MSMs for the unbiased treatment response, we also observe
a slight underperformance versus MSMs on the training policy itself, highlighting the advantages of
R-MSNs.

6 Conclusions

This paper introduces Recurrent Marginal Structural Networks - a novel learning approach for
predicting unbiased treatment responses over time, grounded in the framework of marginal structural
models. Networks are subdivided into two parts, a set of propensity networks to accurately compute
the IPTWs, and a sequence-to-sequence architecture to predict responses using only a planned
sequence of future actions. Using tests on a medically realistic simulation model, the R-MSN
demonstrated performance improvements over traditional methods in epidemiology and the state-of-
the-art models for joint treatment response prediction over multiple timesteps.

Table 1: Normalized RMSE for Various Prediction Horizons τ

                                       τ=1     τ=2     τ=3     τ=4     τ=5    Ave. % Decrease in RMSE vs MSMs
Training Policy          MSM          1.67%   2.51%   3.12%   3.64%   4.09%   -
(γc = 10, γd = 10)       R-BTRC       2.09%   2.85%   3.50%   4.07%   4.58%   -32% (↑ RMSE)
                         R-MSN        1.02%   1.80%   1.90%   2.11%   2.46%   +61%
Unbiased Assignment      MSM          4.84%   5.29%   5.51%   5.65%   5.84%   -
(γc = 0, γd = 0)         R-BTRC       1.91%   2.74%   3.34%   3.75%   4.08%   +66%
                         R-MSN        0.92%   1.38%   1.30%   1.22%   1.14%   +95%
Unbiased Radiotherapy    MSM          3.85%   4.03%   4.32%   4.60%   4.91%   -
(γc = 10, γd = 0)        R-BTRC       1.74%   1.68%   2.14%   2.54%   2.91%   +74%
                         R-MSN        1.08%   1.66%   1.83%   1.98%   2.14%   +84%
Unbiased Chemotherapy    MSM          1.84%   2.65%   3.09%   3.44%   3.83%   -
(γc = 0, γd = 10)        R-BTRC       1.16%   2.45%   2.97%   3.34%   3.64%   +20%
                         R-MSN        0.65%   1.13%   1.05%   1.17%   1.31%   +87%

Acknowledgments
This research was supported by the Oxford-Man Institute of Quantitative Finance, the US Office of
Naval Research (ONR), and the Alan Turing Institute.

References
[1] Optimizing drug regimens in cancer chemotherapy: a simulation study using a pk–pd model. Computers in
Biology and Medicine, 31(3):157 – 172, 2001. Goal-Oriented Model-Based drug Regimens.
[2] Ahmed M. Alaa and Mihaela van der Schaar. Bayesian inference of individualized treatment effects
using multi-task gaussian processes. In Proceedings of the thirty-first Conference on Neural Information
Processing Systems, (NIPS), 2017.
[3] Ahmed M. Alaa, Michael Weisz, and Mihaela van der Schaar. Deep counterfactual networks with propensity
dropout. In Proceedings of the 34th International Conference on Machine Learning (ICML), 2017.
[4] H. Bartsch, H. Dally, O. Popanda, A. Risch, and P. Schmezer. Genetic risk profiles for cancer susceptibility
and therapy response. Recent Results Cancer Res., 174:19–36, 2007.
[5] Letizia Carrara, Silvia Maria Lavezzi, Elisa Borella, Giuseppe De Nicolao, Paolo Magni, and Italo Poggesi.
Current mathematical models for cancer drug discovery. Expert Opinion on Drug Discovery, 12(8):785–799,
2017.
[6] Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron Courville, and Yoshua Bengio. A
recurrent latent variable model for sequential data. In Proceedings of the 28th International Conference on
Neural Information Processing Systems - Volume 2, NIPS’15, pages 2980–2988, Cambridge, MA, USA,
2015.
[7] Djork-Arné Clevert, Thomas Unterthiner, and Sepp Hochreiter. Fast and accurate deep network learning by
exponential linear units (ELUs). CoRR, abs/1511.07289, 2015.
[8] R.M. Daniel, S.N. Cousens, B.L. De Stavola, M. G. Kenward, and J. A. C. Sterne. Methods for dealing with
time-dependent confounding. Statistics in Medicine, 32(9):1584–1618, 2012.
[9] Mould DR, Walz A-C, Lave T, Gibbs JP, and Frame B. Developing exposure/response models for anticancer
drug treatment: Special considerations. CPT: Pharmacometrics & Systems Pharmacology, 4(1):12–27.
[10] Peter H. Egger and Maximilian von Ehrlich. Generalized propensity scores for multiple continuous
treatment variables. Economics Letters, 119(1):32 – 34, 2013.
[11] M. J. Eigenmann, N. Frances, T. Lavé, and A.-C. Walz. Pkpd modeling of acquired resistance to anti-cancer
drug treatment. Journal of Pharmacokinetics and Pharmacodynamics, 44(6):617–630, 2017.
[12] Changran Geng, Harald Paganetti, and Clemens Grassberger. Prediction of treatment response for combined
chemo- and radiation therapy for non-small cell lung cancer patients using a bio-mathematical model.
Scientific Reports, 7, 2017.
[13] Jason Hartford, Greg Lewis, Kevin Leyton-Brown, and Matt Taddy. Deep IV: A flexible approach for
counterfactual prediction. In Proceedings of the 34th International Conference on Machine Learning
(ICML), 2017.
[14] Miguel A. Hernan, Babette Brumback, and James M. Robins. Marginal structural models to estimate
the joint causal effect of nonrandomized treatments. Journal of the American Statistical Association,
96(454):440–448, 2001.
[15] Miguel A. Hernan and James M. Robins. Marginal structural models to estimate the causal effect of
zidovudine on the survival of HIV-positive men. Epidemiology, 11(5):561–570, 2000.
[16] MA Hernán and JM Robins. Causal Inference. Chapman & Hall/CRC, 2018.
[17] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Comput., 9(8):1735–1780,
November 1997.
[18] William Hoiles and Mihaela van der Schaar. A non-parametric learning method for confidently estimating
patient’s clinical state and dynamics. In Proceedings of the twenty-ninth Conference on Neural Information
Processing Systems, (NIPS), 2016.
[19] Chanelle J. Howe, Stephen R. Cole, Shruti H. Mehta, and Gregory D. Kirk. Estimating the effects of
multiple time-varying exposures using joint marginal structural models: alcohol consumption, injection
drug use, and hiv acquisition. Epidemiology, 23(4):574–582, 2012.
[20] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In International
Conference on Learning Representations (ICLR), 2015.
[21] Clovis Lusivika-Nzinga, Hana Selinger-Leneman, Sophie Grabar, Dominique Costagliola, and Fabrice
Carrat. Performance of the marginal structural cox model for estimating individual and joined effects of
treatments given in combination. BMC Medical Research Methodology, 17(1):160, Dec 2017.
[22] Mohammad Ali Mansournia, Mahyar Etminan, Goodarz Danaei, Jay S Kaufman, and Gary Collins. A
primer on inverse probability of treatment weighting and marginal structural models. Emerging Adulthood,
4(1):40–59, 2016.

[23] Mohammad Ali Mansournia, Mahyar Etminan, Goodarz Danaei, Jay S Kaufman, and Gary Collins.
Handling time varying confounding in observational research. BMJ, 359, 2017.
[24] Mohammad Alia Mansournia, Goodarzc Danaei, Mohammad Hosseind Forouzanfar, Mahmoodb Mah-
moodi, Mohsene Jamali, Nasrinf Mansournia, and Kazemb Mohammad. Effect of physical activity on
functional performance and knee pain in patients with osteoarthritis : Analysis with marginal structural
models. Epidemiology, 23(4):631–640, 2012.
[25] Charles E McCulloch and Shayle R. Searle. Generalized, Linear and Mixed Models. Wiley, New York,
2001.
[26] Kathleen M. Mortimer, Romain Neugebauer, Mark van der Laan, and Ira B. Tager. An application of model-
fitting procedures for marginal structural models. American Journal of Epidemiology, 162(4):382–388,
2005.
[27] Onur Atan, James Jordon, and Mihaela van der Schaar. Deep-treat: Learning optimal personalized
treatments from observational data using neural networks. In AAAI, 2018.
[28] Onur Atan, William Zame, and Mihaela van der Schaar. Constructing effective personalized policies
using counterfactual inference from biased data sets with many features. Machine Learning, 2018.
[29] Kyungsoo Park. A review of modeling approaches to predict drug response in clinical oncology. Yonsei
Medical Journal, 58(1):1–8, 2017.
[30] Thomas S. Richardson and Andrea Rotnitzky. Causal etiology of the research of james m. robins. Statistical
Science, 29(4):459–484, 2014.
[31] James M. Robins, Miguel Ángel Hernán, and Babette Brumback. Marginal structural models and causal
inference in epidemiology. Epidemiology, 11(5):550–560, 2000.
[32] Jason Roy, Kirsten J. Lum, and Michael J. Daniels. A bayesian nonparametric approach to marginal
structural models for point treatments and a continuous or survival outcome. Biostatistics, 18(1):32–47,
2017.
[33] Peter Schulam and Suchi Saria. Reliable decision support using counterfactual models. In Proceedings of
the thirty-first Conference on Neural Information Processing Systems, (NIPS), 2017.
[34] Ricardo Silva. Observational-interventional priors for dose-response learning. In Proceedings of the
Thirtieth Conference on Neural Information Processing Systems, (NIPS), 2016.
[35] Hossein Soleimani, Adarsh Subbaswamy, and Suchi Saria. Treatment-response models for counterfactual
reasoning with continuous-time, continuous-valued interventions. In Proceedings of the Thirty-Third
Conference on Uncertainty in Artificial Intelligence (UAI), 2017.
[36] Ilya Sutskever, Oriol Vinyals, and Quoc V. Le. Sequence to sequence learning with neural networks. In
Proceedings of the twenty-seventh Conference on Neural Information Processing Systems, (NIPS), 2014.
[37] Adith Swaminathan and Thorsten Joachims. Batch learning from logged bandit feedback through counter-
factual risk minimization. Journal of Machine Learning Research, 16:1731–1755, 2015.
[38] Adith Swaminathan and Thorsten Joachims. Counterfactual risk minimization: Learning from logged
bandit feedback. CoRR, abs/1502.02362, 2015.
[39] Michalis K. Titsias. Variational learning of inducing variables in sparse gaussian processes. In Proceedings
of the Twelfth International Conference on Artificial Intelligence and Statistics (AISTATS), 2009.
[40] Panagiotis J. Vlachostergios and Bishoy M. Faltas. Treatment resistance in urothelial carcinoma: an
evolutionary perspective. Nature Reviews Clinical Oncology, 2018.
[41] Stefan Wager and Susan Athey. Estimation and inference of heterogeneous treatment effects using random
forests. Journal of the American Statistical Association, 2017.
[42] Yongling Xiao, Michal Abrahamowicz, and Erica Moodie. Accuracy of conventional and marginal
structural cox model estimators: A simulation study. The International Journal of Biostatistics, 6(2), 2010.
[43] Yanbo Xu, Yanxun Xu, and Suchi Saria. A non-parametric bayesian approach for estimating treatment-
response curves from sparse time series. In Proceedings of the 1st Machine Learning for Healthcare
Conference (MLHC), 2016.
[44] Jinsung Yoon, James Jordon, and Mihaela van der Schaar. GANITE: Estimation of individualized treatment
effects using generative adversarial nets. In International Conference on Learning Representations (ICLR),
2018.

Appendix
A Extended Related Works

Potential Outcomes with Cross-sectional Data. A simpler instantiation of the problem is to
estimate the effect of a treatment applied to subjects in a (static) cross-sectional dataset. This problem
has recently attracted a lot of attention in the machine learning community, and various interesting
ideas were proposed to account for selection bias [3, 41, 44]. Unfortunately, most of these works cast
the treatment effect estimation problem as one of learning under "covariate shift", where the goal is
to learn a model for the outcomes that generalizes well to a population where treatments are randomly
assigned to the subjects. Because of the sequential nature of the treatment assignment process in our
setup, estimating treatment effects under time-dependent confounding cannot be similarly framed as
a covariate shift problem, and hence the ideas developed in those works cannot be straightforwardly
applied to our setup.
Off-policy Evaluation. A closely related problem in the area of reinforcement learning is the
problem of off-policy evaluation using retrospective observational data, also known as "logged bandit
feedback" [38, 37, 18, 27, 28]. In this problem, the goal is to use sequences of states, actions and
rewards generated by a decision-maker that operates under an unknown policy in order to estimate
the expected reward of a given policy. In our setting, we focus on estimating a trajectory of outcomes
given an application of a treatment (or a sequence of treatments) rather than estimating the average
reward of a policy, and hence the "counterfactual risk minimization" framework in [37] would not
result in optimal estimates in our setup. However, our learning model – with a different objective
function – can be applied to the problem of off-policy evaluation.

B Background on Marginal Structural Models

In this section, we summarize the key relevant points from the seminal paper of Robins [31]. Without
loss of generality, we consider the case of univariate treatments, response variables and baseline
covariates here for simplicity.
Marginal structural models are typically considered in the context of follow-up studies, for example
in patients with HIV [31]. Time in the study is typically measured in relation to a fixed starting point,
such as the first follow-up date or time of diagnosis (i.e. t = 1). In such settings, marginal structural
models are used to measure the average treatment effect conditioned on a series of potential actions
and baseline covariate V taken at the start of the study, expressed in the form:

E[Y_τ | a_1, . . . , a_τ, V] = r(a_1, . . . , a_τ, V; Θ)    (10)

where r(.) is a generic, typically linear, function with parameters Θ.


Time-Dependent Confounding. A full description of time-varying confounding can be found in
[23], with formal definitions in [14]. Time-dependent confounding in observational studies arises
as confounders have values which change over time - for example in cases where treatments are
moderated based on the patient's response. A causal graph for a 2-step study can be found in Figure 5,
where U denotes unmeasured factors. Note that U0, U1 do not have arrows to action assignments,
reflecting the assumption of no unmeasured confounding.
Inverse Probability of Treatment Weighting From [31], under assumptions of no unmeasured
confounding, positivity, and correct model specification, the stabilized IPTWs can be expressed as:

SW(τ) = ∏_{n=0}^{τ} f(A_n | Ā_{n−1}) / f(A_n | Ā_{n−1}, L̄_n, X)    (11)

noting that V is defined to be a subset of L_0 in [31]. Informally, the denominator is the conditional
probability of a treatment assignment given past observations of treatment assignments and
covariates, and the numerator is that given past treatment assignments alone, with the stabilized
weights representing the incremental adjustment between the two.

Figure 5: Causal Graph of Time-dependent Confounding for 2-step Study

In real clinical settings, it is often desirable to determine the treatment response in relation to the
current follow up time, given past information. As such, we consider trajectories in relation to the
last follow-up time t, retaining the form of the stabilized weights of the MSM and using all past
observations, i.e.
SW(t, τ) = ∏_{n=t}^{t+τ} f(A_n | Ā_{n−1}) / f(A_n | Ā_{n−1}, L̄_n, X)    (12)
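As a concrete illustration of the product above, the following is a minimal NumPy sketch (a hypothetical helper, not the paper's code) that turns fitted numerator and denominator propensity probabilities for one trajectory into a stabilized weight:

```python
import numpy as np

def stabilized_weight(p_num, p_denom, t, tau):
    """SW(t, tau): product over n = t .. t+tau of the probability ratios.

    p_num[n]   approximates f(A_n | past treatments)              (numerator)
    p_denom[n] approximates f(A_n | past treatments, covariates)  (denominator)
    Both are probabilities assigned to the treatment actually received at step n.
    """
    ratios = p_num[t:t + tau + 1] / p_denom[t:t + tau + 1]
    return float(np.prod(ratios))

# Toy probabilities: the covariate-aware model is more confident than the
# marginal one, so the weight shrinks below 1 for this trajectory.
p_num = np.array([0.5, 0.5, 0.5, 0.5])
p_denom = np.array([0.8, 0.7, 0.6, 0.9])
w = stabilized_weight(p_num, p_denom, t=1, tau=2)
```

In practice the two probability sequences would come from the fitted propensity models described in the following appendices.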

C Additional Results for Experiments with Cancer Growth Simulation

Table 2 documents the full list of comparisons for one-step-ahead predictions when tested across various
γ, using different combinations of prediction and IPTW models.

Table 2: One-step-ahead Prediction Performance for Models Calibrated on γ = 10

γ =                        0       1       2       3       4       5
Linear (No IPTWs)        5.55%   4.81%   4.09%   3.44%   2.86%   2.42%
MSM                      4.84%   4.19%   3.56%   3.00%   2.51%   2.15%
MSM (LSTM IPTWs)         4.26%   3.68%   3.13%   2.64%   2.22%   1.95%
Seq2Seq (No IPTWs)       1.52%   1.39%   1.28%   1.23%   1.17%   1.17%
Seq2Seq (Logistic IPTWs) 4.34%   3.28%   2.42%   1.83%   1.43%   1.23%
R-MSN                    0.92%   0.89%   0.85%   0.84%   0.79%   0.84%
BTRC                     2.73%   2.59%   2.42%   2.28%   2.11%   2.07%
R-BTRC                   1.91%   1.81%   1.72%   1.69%   1.63%   1.71%

γ =                        6       7       8       9       10
Linear (No IPTWs)        2.09%   1.80%   1.70%   1.65%   1.66%
MSM                      1.90%   1.68%   1.64%   1.64%   1.67%
MSM (LSTM IPTWs)         1.77%   1.61%   1.62%   1.64%   1.69%
Seq2Seq (No IPTWs)       1.13%   1.07%   1.07%   1.09%   1.08%
Seq2Seq (Logistic IPTWs) 1.12%   1.04%   1.04%   1.05%   1.04%
R-MSN                    0.88%   0.88%   0.94%   1.00%   1.02%
BTRC                     2.05%   2.01%   2.10%   2.16%   2.19%
R-BTRC                   1.77%   1.79%   1.93%   2.02%   2.09%

D Marginal Structural Models for Cancer Simulation
The probabilities required for the IPTWs of the standard MSM in Section 5.2 can be described using
the logistic regression models below:

f(A_t(k) | Ā_t) = σ( ω_1^{(k)} Σ_{n=0}^{t} Ā_c(n−1) + ω_2^{(k)} Σ_{n=0}^{t} Ā_d(n−1) )    (13)

f(A_t(k) | H̄_t) = σ( ω_5^{(k)} Σ_{n=0}^{t} Ā_c(n−1) + ω_6^{(k)} Σ_{n=0}^{t} Ā_d(n−1)
                     + ω_7^{(k)} V(t) + ω_8^{(k)} V(t−1) + ω_9^{(k)} S )    (14)

f(C_t = 0 | T > n, Ā_{n−1}) = σ( ω_10 Σ_{n=0}^{t} Ā_c(n−1) + ω_11 Σ_{n=0}^{t} Ā_d(n−1) )    (15)

f(C_t = 0 | T > n, L̄_{n−1}, Ā_{n−1}, X) = σ( ω_12 Σ_{n=0}^{t} Ā_c(n−1) + ω_13 Σ_{n=0}^{t} Ā_d(n−1)
                                             + ω_14 V(t−1) + ω_15 S )    (16)

where σ(.) is the sigmoid function and ω∗ are regression coefficients.
The regression model for prediction is given by:

g(τ, a(t, τ−1), H̄_t) = β_1 Σ_{n=0}^{t} Ā_c(n−1) + β_2 Σ_{n=0}^{t} Ā_d(n−1)
                       + β_3 V(t) + β_4 V(t−1) + β_5 S    (17)

E Algorithm Description for R-MSNs


To provide additional clarity on the relationship between the propensity networks and the Seq2Seq
model, the pseudocode in Algorithm 1 describes the training process mentioned in Section 4.3.
We first define the functions r_(·)( · ; θ_(·)) to be RNN outputs given a vector of weights and
hyperparameters θ_(·). We refer the reader to Section 3 for more information on the functions in the
MSM framework approximated by RNNs.

Propensity Networks  Components of the propensity networks are used to compute the IPTWs
SW(t, τ) and SW*(t, τ), as defined in Equations 2 and 3 respectively. The probabilities in the
numerators and denominators are taken to be the outputs of the propensity networks as below:

f(A_n | Ā_{n−1}) = r_A1(A_n | Ā_{n−1}; θ_A1)    (18)
f(A_n | H̄_n) = r_A2(A_n | H̄_n; θ_A2)    (19)
f(C_n = 0 | T > n, Ā_{n−1}) = r_C1(Ā_{n−1}; θ_C1)    (20)
f(C_n = 0 | T > n, L̄_{n−1}, Ā_{n−1}, X) = r_C2(L̄_{n−1}, Ā_{n−1}, X; θ_C2)    (21)

Encoder  The encoder is defined in a similar fashion below, with an additional function
r̃_E(L̄_t, Ā_t, X; θ_E) to output the internal states of the LSTM. The encoder also computes the one-
step-ahead predictions, i.e. g(1, a(t, 0), H̄_t) as per Equation 1, which is used to define the prediction
error e(i, t, 1) and encoder loss L_encoder – i.e. Equations 4 and 5 respectively.

g(1, a(t, 0), H̄_t) = r_E(L̄_t, Ā_t, X; θ_E)    (22)
h_t = r̃_E(L̄_t, Ā_t, X; θ_E)    (23)

Decoder  The decoder then uses the seq2seq architecture to project encoder states h_t forwards in
time, incorporating planned future actions a_{t+τ}. This is also combined with the IPTWs to define the
decoder loss L_decoder in Equation 5.

g(τ, a(t, τ−1), H̄_t) = r_D(h_t, a_{t+1}, . . . , a_{t+τ}; θ_D), ∀τ > 1    (24)
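A minimal PyTorch sketch of this encoder/decoder split is shown below. It is illustrative only: the thesis implementation differs in detail (for example, it uses a memory adapter to map encoder states into the decoder's state size, which we sidestep here by assuming equal state sizes):

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """LSTM over past covariates and actions; emits one-step-ahead predictions."""
    def __init__(self, n_inputs, state_size):
        super().__init__()
        self.lstm = nn.LSTM(n_inputs, state_size, batch_first=True)
        self.head = nn.Linear(state_size, 1)

    def forward(self, x):  # x: [batch, time, covariates + actions]
        out, state = self.lstm(x)
        return self.head(out), state  # predictions and (h, c) for the decoder

class Decoder(nn.Module):
    """LSTM driven only by planned future actions, seeded with encoder state."""
    def __init__(self, n_actions, state_size):
        super().__init__()
        self.lstm = nn.LSTM(n_actions, state_size, batch_first=True)
        self.head = nn.Linear(state_size, 1)

    def forward(self, planned_actions, init_state):
        out, _ = self.lstm(planned_actions, init_state)
        return self.head(out)  # multi-step treatment response predictions

enc = Encoder(n_inputs=4, state_size=8)
dec = Decoder(n_actions=2, state_size=8)
x = torch.randn(3, 10, 4)        # past covariates and actions
a_future = torch.randn(3, 5, 2)  # planned future actions
y_one_step, state = enc(x)
y_future = dec(a_future, state)  # [batch, horizon, 1]
```

In training, the one-step and multi-step prediction errors would be reweighted by the IPTWs when forming the encoder and decoder losses.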

Algorithm 1 Training Process for R-MSN

Input: Training/Validation Data L̄_{1:T}, Ā_{1:T}, X
Output: Neural network weights and hyperparameters for:
  1) SW(t, τ) networks: θ_A1, θ_A2
  2) SW*(t, τ) networks: θ_C1, θ_C2
  3) Encoder network: θ_E
  4) Decoder network: θ_D

 1: Step 1: Fit Propensity Networks
 2: θ_A1 ← optimize( Σ_{n,i} binary_x_entropy(r_A1(A_n(i) | Ā_{n−1}(i); θ_A1), A_n(i)) )
 3: θ_A2 ← optimize( Σ_{n,i} binary_x_entropy(r_A2(A_n(i) | H̄_n(i); θ_A2), A_n(i)) )
 4: θ_C1 ← optimize( Σ_{n,i} binary_x_entropy(r_C1(Ā_{n−1}(i); θ_C1), C_n(i)) )
 5: θ_C2 ← optimize( Σ_{n,i} binary_x_entropy(r_C2(L̄_{n−1}(i), Ā_{n−1}(i), X(i); θ_C2), C_n(i)) )
 6: Step 2: Generate IPTWs
 7: for patient i = 1 to I do
 8:   for t = 1 to T do
 9:     for τ = 1 to τ_max do
10:       SW_i(t, τ) ← ∏_{n=t}^{t+τ} r_A1(A_n(i) | Ā_{n−1}(i); θ_A1) / r_A2(A_n(i) | H̄_n(i); θ_A2)
11:       SW*_i(t, τ) ← ∏_{n=t}^{t+τ} r_C1(Ā_{n−1}(i); θ_C1) / r_C2(L̄_{n−1}(i), Ā_{n−1}(i), X(i); θ_C2)
12:     end for
13:   end for
14: end for
15: Step 3: Fit Encoder
16: θ_E ← optimize( L_encoder ), as per Equation 5a
17: Step 4: Compute Encoder States  {Used to Initialize Decoder}
18: for patient i = 1 to I do
19:   for t = 1 to T do
20:     h_t(i) ← r̃_E(L̄_t(i), Ā_t(i), X(i); θ_E)
21:   end for
22: end for
23: Step 5: Fit Decoder
24: θ_D ← optimize( L_decoder ), as per Equation 5b

F Hyperparameter Optimization for R-MSN

For the R-MSN, 10,000 simulated paths were used for backpropagation of the network (training data),
and 1,000 simulated paths for hyperparameter optimization (validation data), with another 1,000 for
out-of-sample testing. Given the differences in the state initialization requirements and data batching
of the decoder, we report the hyperparameter optimization settings for the decoder separately. The
optimal parameters of all networks can be found in Table 5.

Settings for Propensity Networks and Encoder  Hyperparameter optimization was performed
using 50 iterations of random search over the hyperparameter ranges in Table 3, and networks
were trained using the ADAM optimizer [20]. For each set of sampled hyperparameters, simulation
trajectories were grouped into B minibatches and networks were trained for a maximum of 100
epochs. LSTM state sizes were also defined in relation to the number of network inputs C.

Table 3: Hyperparameter Search Range for Propensity Networks and Encoder

Hyperparameter                        Search Range
Iterations of Hyperparameter Search   50
Dropout Rate                          0.1, 0.2, 0.3, 0.4, 0.5
State Size                            0.5C, 1C, 2C, 3C, 4C
Minibatch Size                        64, 128, 256
Learning Rate                         0.01, 0.005, 0.001
Max Gradient Norm                     0.5, 1.0, 2.0
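The random-search procedure described above can be sketched as follows; `train_and_validate` is a hypothetical callback that would train a network for up to 100 epochs under a sampled configuration and return its validation loss:

```python
import random

# Search space mirroring the ranges listed in Table 3.
SEARCH_SPACE = {
    "dropout_rate": [0.1, 0.2, 0.3, 0.4, 0.5],
    "state_size_multiplier": [0.5, 1, 2, 3, 4],  # state size = multiplier * C inputs
    "minibatch_size": [64, 128, 256],
    "learning_rate": [0.01, 0.005, 0.001],
    "max_gradient_norm": [0.5, 1.0, 2.0],
}

def random_search(train_and_validate, n_iterations=50, seed=0):
    """Sample configurations uniformly and keep the best validation loss."""
    rng = random.Random(seed)
    best_config, best_loss = None, float("inf")
    for _ in range(n_iterations):
        config = {k: rng.choice(v) for k, v in SEARCH_SPACE.items()}
        loss = train_and_validate(config)
        if loss < best_loss:
            best_config, best_loss = config, loss
    return best_config, best_loss

# Stand-in objective for illustration only (real usage trains a network).
config, loss = random_search(lambda c: c["dropout_rate"] + c["learning_rate"])
```

The decoder search in Table 4 follows the same loop with its own ranges and 20 iterations.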

Settings for Decoder  To train the decoder, the data was reformatted into sequences of
(h_t, {L_{t+1}, . . . , L_{t+τ_max}}, {A_t, . . . , A_{t+τ_max}, X}), such that each patient i makes T_i contributions
to the training dataset. Given the T-fold increase in the number of rows in the overall dataset, we
made a few modifications to the range of the hyperparameter search, including increasing the size of
minibatches and reducing the learning rate and the number of iterations of hyperparameter search. The
full range of the hyperparameter search can be found in Table 4, and networks are trained for a maximum
of 100 epochs as well.

Table 4: Hyperparameter Search Range for Decoder

Hyperparameter                        Search Range
Iterations of Hyperparameter Search   20
Dropout Rate                          0.1, 0.2, 0.3, 0.4, 0.5
State Size                            1C, 2C, 4C, 8C, 16C
Minibatch Size                        256, 512, 1024
Learning Rate                         0.01, 0.005, 0.001, 0.0001
Max Gradient Norm                     0.5, 1.0, 2.0, 4.0

Table 5: Optimal Hyperparameters for R-MSN

                                           Dropout Rate  State Size  Minibatch Size  Learning Rate  Max Norm
Propensity Networks
f(A_n | Ā_{n−1})                           0.1           6 (3C)      128             0.01           2.0
f(A_n | H̄_n)                              0.1           16 (4C)     64              0.01           1.0
f(C_n = 0 | T > n, Ā_{n−1})                0.2           4 (2C)      128             0.01           0.5
f(C_n = 0 | T > n, L̄_{n−1}, Ā_{n−1}, X)   0.1           16 (4C)     64              0.01           2.0
Prediction Networks
Encoder                                    0.1           16 (4C)     64              0.01           0.5
Decoder + Memory Adapter                   0.1           16 (8C)     512             0.001          4.0

G Hyperparameter Optimization for BTRC
The parameters of the BTRC were optimized using maximum a posteriori (MAP) estimation,
using the same prior for global parameters and the same approach defined in [35]. While the model was
replicated as faithfully as possible to its specifications, two slight modifications were made to adapt
it to our problem. Firstly, the sparse GP approximations were avoided to ensure that we had as
much accuracy as possible, using Gaussian processes with full covariance matrices for the random-
effects components. Secondly, as our dataset was partitioned to ensure that patients observed in the
training set were not present in the test set, any patient-specific parameters learned would not be
usable on the test set itself. As such, to avoid optimizing on the test set, we adopt
the standard approach for prediction in generalized linear mixed models [25], using the average
population parameters, i.e. the global MAP estimate, for prediction.
Hyperparameter optimization was performed using grid search on the optimizer settings defined in Table 6,
and was performed for a maximum of 5000 epochs per configuration. As convergence was observed
to be slow for a number of settings, we also trained a reduced form of the full BTRC model without
the "shared" parameters (indicated by '-' in Table 7) to reduce the number of parameters of the model.
The optimal global hyperparameters and optimizer settings can be found in Table 7.

Table 6: Hyperparameter Grid for BTRC

Hyperparameter Search Range


Minibatch Size 2, 5, 10, 100, 500
Learning Rate 10−1 , 10−2 , 10−3 , 10−4 , 10−5

Table 7: MAP Estimates for BTRC

                            BTRC                  R-BTRC
χ̄_chemo, χ̄_radio           (-1.3729433, 0.065)   (-1.162, 0.007)
ᾱ_chemo, ᾱ_radio            (0.760, 0.367)        (0.547, 0.367)
ᾱ^(0)_chemo, ᾱ^(0)_radio    (0.207, 0.490)        (-, -)
β̄_chemo, β̄_radio            (0.595, 0.367)        (0.429, 0.368)
β̄^(0)_chemo, β̄^(0)_radio    (0.204, 0.368)        (-, -)
γ̄                           -0.27                 -0.262
ω̄                           -0.928                -
l̄_g                         1.223                 -
κ̄                           0.786                 0.867
l̄_v                         1.092                 1.151
σ̄²                          0.036                 0.042
Learning Rate               0.001                 0.001
Minibatch Size              100                   100

Chapter 6

Extracting Useful Features From High Frequency Data

Publications Included
• B. Lim, S. Zohren, S. Roberts. Detecting Changes in Asset Co-Movement
Using the Autoencoder Reconstruction Ratio. Risk, 2020. Weblink:
https://arxiv.org/abs/2002.02008.
Detecting Changes in Asset Co-Movement
Using the Autoencoder Reconstruction Ratio

Bryan Lim, Stefan Zohren, and Stephen Roberts

Abstract—Detecting changes in asset co-movements is of much importance to financial practitioners, with numerous risk management benefits arising from the timely detection of breakdowns in historical correlations. In this article, we propose a real-time indicator to detect temporary increases in asset co-movements, the Autoencoder Reconstruction Ratio, which measures how well a basket of asset returns can be modelled using a lower-dimensional set of latent variables. The ARR uses a deep sparse denoising autoencoder to perform the dimensionality reduction on the returns vector, which replaces the PCA approach of the standard Absorption Ratio (Kritzman et al., 2011), and provides a better model for non-Gaussian returns. Through a systemic risk application on forecasting the CRSP US Total Market Index, we show that lower ARR values coincide with higher volatility and larger drawdowns, indicating that increased asset co-movement does correspond with periods of market weakness. We also demonstrate that short-term (i.e. 5-min and 1-hour) predictors for realised volatility and market crashes can be improved by including additional ARR inputs.

B. Lim, S. Zohren and S. Roberts are with the Department of Engineering Science and the Oxford-Man Institute of Quantitative Finance, University of Oxford, Oxford, United Kingdom (email: {blim, zohren, sjrob}@robots.ox.ac.uk).

INTRODUCTION

The time-varying nature of asset correlations has long been of much interest to financial practitioners, with short-term increases in the co-movement of financial returns widely documented around periods of market stress (Preis et al., 2012; Packham & Woebbeking, 2019; Campbell et al., 2002; Cappiello et al., 2006; Billio et al., 2010). From a portfolio construction perspective, short-term increases in asset correlations have frequently been linked to spikes in market volatility (Campbell et al., 2002) – with explanatory factors ranging from impactful news announcements (Aït-Sahalia & Xiu, 2016) to "risk-on risk-off" effects (Dungey et al., 2018) – indicating that diversification can break down precisely when it is needed the most. In terms of systemic risk, increased co-movement between market returns is viewed as a sign of market fragility (Kritzman et al., 2011), as the tight coupling between markets can deepen drawdowns when they occur, resulting in financial contagion. The development of methods to detect real-time changes in returns co-movement can thus lead to improvements in multiple risk management applications.

Common approaches for co-movement detection can be broadly divided into two categories. Firstly, multivariate Gaussian models have been proposed to directly model time-varying correlation dynamics – with temporal effects either modelled explicitly (Aït-Sahalia & Xiu, 2016; Dungey et al., 2018), or incorporated simply by recalibrating covariance matrices daily using a sliding window of data (Preis et al., 2012). More adaptive correlation estimates can also be provided by calibrating models using high-frequency data – as seen in the multivariate HEAVY models of Noureldin et al. (2012). However, several challenges remain with the modelling approach. For instance, naïve model specification can lead to spurious statistical effects (Loretan & English, 2000) reflecting correlation increases even when underlying model parameters remain unchanged. Furthermore, given that correlations are modelled in a pair-wise fashion between assets, condensing a high-dimensional correlation matrix into a single number measuring co-movement changes can be challenging in practice.

An alternative approach adopts the use of statistical factor models to capture common movements across the entire portfolio, typically projecting returns onto a low-dimensional set of latent factors using principal component analysis (PCA) (Kritzman et al., 2011; Billio et al., 2010; Zheng et al., 2012). The popular Absorption Ratio (AR), for example, is defined by the fraction of the total variance of assets explained or absorbed by a finite set of eigenvectors (Kritzman et al., 2011) – which bears similarity to metrics to evaluate eigenvalue significance in other domains (e.g. the Fractional Spectral Radius of Rezek & Roberts (1998)). Despite their ability to quantify co-movement changes with a single metric, with a larger AR corresponding to increased co-movement, PCA-based indicators require a covariance matrix to be estimated over a historical lookback window. This approach can be data-intensive for applications with many assets, as a long estimation window is required to ensure non-singularity of the covariance matrix (Billio et al., 2010). Moreover, non-Gaussian financial returns (Jondeau et al., 2007) can violate the normality assumptions required by PCA, potentially making classical PCA unsuitable for high-frequency data. While alternatives such as Independent Component Analysis (ICA) have been considered for modelling non-Gaussian financial time series data (Shah & Roberts, 2013), they typically maintain the assumption that assets are linear combinations of driving factors – and potential improvements can be made using non-linear approaches. As such, more adaptive non-parametric methods are hence needed to develop real-time indicators for changes in co-movement.

Advances in deep learning have demonstrated the benefits of autoencoder architectures for dimensionality reduction in complex datasets (Goodfellow et al., 2016). In particular, autoencoders have had notable successes in feature extraction in images (Cho, 2013; Liu & Zhang, 2018; Freiman et al., 2019), vastly out-performing traditional methods for dimensionality reduction in non-linear datasets. More recently, autoencoders have been explored as a replacement for PCA in various financial applications, encoding latent factors which account for non-linearities in return dynamics and allow for conditioning on exogenous covariates. Gu et al. (2019), for instance, use a series of encoders to estimate both latent factors and factor loadings (i.e. betas) used in conditional asset pricing models, demonstrating better out-of-sample pricing performance compared to traditional linear methods. Kondratyev (2018) explores the use of autoencoders to capture a low-dimensional representation of the term structure of commodity futures, and also shows superior reconstruction performance versus PCA. Given the strong performance enhancements of autoencoder architectures on non-linear datasets, replacing traditional factor models with autoencoders could be promising for co-movement measurement applications as well.

In this article, we introduce the Autoencoder Reconstruction Ratio (ARR) – a novel indicator to detect co-movement changes in real-time – based on the average reconstruction error obtained from applying autoencoders to high-frequency returns data. Adopting the factor modelling approach, the ARR uses deep sparse denoising autoencoders to project second-level intraday returns onto a lower-dimension set of latent variables, and aggregates reconstruction errors up to the desired frequency (e.g. daily). Increased co-movement hence corresponds to periods where reconstruction error is low, i.e. when returns are largely accounted for by the autoencoder's latent factors. In line with the canonical Absorption Ratio, we evaluate the performance of the ARR by considering an application in systemic risk – using increased co-movement of sector index returns to improve volatility and drawdown predictions for the total market. The information content of the ARR is evaluated by measuring the performance improvements of machine learning predictors when the ARR is included as a covariate. Based on experiments performed for 4 prediction horizons (i.e. 5-min, 1-hour, 1-day, 1-week), the ARR was observed to significantly improve performance for volatility and market crash forecasts over short-term (5-min and 1-hour) horizons.

PROBLEM DEFINITION

For a portfolio of N assets, let r(t, n) be the returns of the n-th asset at time t defined as below:

r(t, n) = log p(t, n) − log p(t − ∆t, n),    (1)

where p(t, n) is the value or price of asset n at time t, and ∆t is a discrete interval corresponding to the desired sampling frequency.

In their most general form, statistical factor models map a common set of K latent variables to the returns of each asset as below:

r(t, n) = r̄(t, n) + ε(t, n)    (2)
        = f_n(Z(t)) + ε(t, n),    (3)

where the residual ε(t, n) ∼ N(0, ξ_n²) is the idiosyncratic risk of a given asset, f_n(.) is a generic function, and Z(t) = [z(t, 1), . . . , z(t, K)]^T is a vector of common latent factors.

Traditional factor models typically adopt a simple linear form, using asset-specific coefficients as below:

f_n(Z(t)) = β(n)^T Z(t),    (4)

where β(n) = [β(n, 1), . . . , β(n, K)]^T is a vector of factor loadings. While various approaches are available for latent variable estimation – such as independent component analysis (ICA) (Fabozzi et al., 2015) or PCA (Billio et al., 2010) – the number of latent variables is typically kept low to reduce the dimensionality of the dataset (i.e. K ≪ N).

Absorption Ratio

A popular approach to measuring asset co-movement is the Absorption Ratio (Kritzman et al., 2011), which performs dimensionality reduction on returns using PCA. Co-movement changes are then measured based on the total variance absorbed by a finite set of eigenvectors. With K set to be a fifth of the number N of available assets per Kritzman et al. (2011), the Absorption Ratio is defined as:

AR(t) = Σ_{k=1}^{K} σ²_{E_k} / Σ_{n=1}^{N} σ²_{A_n},    (5)

where AR(t) is the Absorption Ratio at time t, σ²_{E_k} is the variance of the k-th largest eigenvector of the PCA decomposition, and σ²_{A_n} is the variance of the n-th asset.

AUTOENCODER RECONSTRUCTION RATIO

To detect changes in asset co-movements, we propose the Autoencoder Reconstruction Ratio below, which can be interpreted as the normalised reconstruction mean squared error (MSE) of high-frequency returns within a given time interval:

ARR(t, ∆t) = Σ_{τ=t−∆t}^{t} Σ_{n=1}^{N} (r(τ, n) − r̄(τ, n))² / Σ_{τ=t−∆t}^{t} Σ_{n=1}^{N} r(τ, n)²,    (6)

where ∆t is the time interval matching the desired sampling frequency, and r̄(τ, n) is the return of asset n reconstructed by a deep sparse denoising autoencoder (see sections below). For our experiments, we train the autoencoder using one-second returns and compute ARRs

corresponds to the variance unexplained by the selected factors. The ARR can in this case be expressed as:

ARR(t, ∆t) = ( Σ_{n=1}^{N} σ²_{A_n} − Σ_{k=1}^{K} σ²_{E_k} ) / Σ_{n=1}^{N} σ²_{A_n}    (10)
           = 1 − AR(t)    (11)

Autoencoder Architecture

We adopt a deep sparse denoising autoencoder architecture (Goodfellow et al., 2016) to compute reconstruction errors for the ARR. The network is predominantly divided into two parts – 1) a decoder which reconstructs the returns vector from a reduced set of latent factors, and 2) an encoder which performs the low-dimensional projection – both of which are described below.

Decoder:

r̄(t) = W₁ h₁(t) + b₁    (12)
h₁(t) = ELU(W₂ Z(t) + b₂)    (13)

where r̄(t) = [r̄(t, 1), . . . , r̄(t, N)]^T is the vector of reconstructed returns, ELU(.) is the exponential
across four different sampling frequencies (i.e. 5-min, 1- linear unit activation function (Clevert et al., 2016),
hour, 1-day, 1-week). h1 (t) ∈ RH is the hidden state of the decoder network,
and W1 ∈ RN ×H , W2 ∈ RH×K , b1 ∈ RN , b2 ∈ RH
Relationship to the Absorption Ratio
are its weights and biases.
The ARR can also be interpreted as a slight reformu-
lation of the standard Absorption Ratio, and we examine Encoder:
the relationship between the two in this section.
Reintroducing the linear Gaussian assumptions used Z(t) = ELU (W3 h2 (t) + b3 ) (14)
by the Absorption Ratio, we note that the denominator h2 (t) = ELU(W4 x(t) + b4 ) (15)
of Equation (6) contains a simple estimator for the
realised variance (RV) of each asset (Barndorff-Nielsen where Z(t) ∈ R is a low-dimensional projection of
K

& Shephard, 2002), i.e.: inputs x(t) = [r(t, 1), . . . , r(t, N ), T ]T , h2 (t) ∈ RH is
t
the hidden state of the encoder network, and W3 ∈
X H are its
(7) R 4 ∈ R 3 ∈ R , b4 ∈ R
K×H , W H×N , b K
RV(t, ∆t, n) = r(t, n)2
τ =t−∆t
weights and biases. We note that the time-of-day T –
2 recorded as the number of seconds from midnight – is also
≈ σA . (8)
n
included in the input along with sector returns, allowing
Moreover, by comparing to Equation (2), we can see the autoencoder to account for any intraday seasonality
that the numerator can be interpreted as an estimate of present in the dataset.
the sum of residual variances, leading to the form of the To ensure that dimensionality is gradually reduced, we
ARR below: set K = bN/5c as per the original Absorption Ratio
PN 2 paper, and fix H = b(N + K)/2c.
n=1 ξn
ARR(t, ∆t) = PN . (9)
2
n=1 σAn Network Training
Assuming that residuals are uncorrelated between assets, To improve generalisation on test data, sparse denoising
we note that the sum of residual variances essentially autoencoders introduce varying degrees of regularisation.
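The encoder-decoder stack of Equations (12)-(15) and the ARR of Equation (6) can be sketched as below. This is a minimal NumPy illustration rather than the trained TensorFlow network used in our experiments: the weights are randomly initialised, the returns matrix is synthetic, and the class and function names are our own.

```python
import numpy as np

def elu(x, alpha=1.0):
    # Exponential linear unit (Clevert et al., 2016).
    return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))

class SparseAutoencoder:
    """Minimal encoder/decoder stack mirroring Equations (12)-(15)."""

    def __init__(self, n_assets, rng):
        K = n_assets // 5                 # latent dimension, K = floor(N/5)
        H = (n_assets + K) // 2           # hidden dimension, H = floor((N+K)/2)
        n_inputs = n_assets + 1           # sector returns plus time-of-day T
        s = 0.1
        self.W4 = rng.normal(0, s, (H, n_inputs)); self.b4 = np.zeros(H)
        self.W3 = rng.normal(0, s, (K, H));        self.b3 = np.zeros(K)
        self.W2 = rng.normal(0, s, (H, K));        self.b2 = np.zeros(H)
        self.W1 = rng.normal(0, s, (n_assets, H)); self.b1 = np.zeros(n_assets)

    def reconstruct(self, x):
        h2 = elu(self.W4 @ x + self.b4)   # Eq. (15)
        z = elu(self.W3 @ h2 + self.b3)   # Eq. (14): latent factors Z(t)
        h1 = elu(self.W2 @ z + self.b2)   # Eq. (13)
        return self.W1 @ h1 + self.b1     # Eq. (12): reconstructed returns

def arr(returns, recons):
    # Equation (6): normalised reconstruction MSE over the interval.
    return ((returns - recons) ** 2).sum() / (returns ** 2).sum()

# Example: ARR over a synthetic interval of one-second sector returns.
rng = np.random.default_rng(0)
N, T = 11, 300                            # 11 sector indices, 300 one-second steps
rets = rng.normal(0.0, 1e-4, (T, N))
tod = np.linspace(0.0, 1.0, T)            # scaled time-of-day feature
ae = SparseAutoencoder(N, rng)
recons = np.stack([ae.reconstruct(np.append(rets[t], tod[t])) for t in range(T)])
print(f"ARR = {arr(rets, recons):.3f}")   # low ARR = returns well explained by Z(t)
```

With N = 11 sector indices, this gives K = 2 latent variables and H = 6 hidden units, matching the dimensions used in the experiments.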
Firstly, a sparsity penalty in the form of an L1 regularisation term is added to the reconstruction loss as below:

    L(Θ) = Σ_{t=1}^{T} Σ_{n=1}^{N} (r(t, n) − r̄(t, n))² / (TN) + α ‖Z(t)‖₁,    (16)

where Θ represents all network weights, α is a penalty weight which we treat as a hyperparameter, and ‖·‖₁ is the L1 norm. Secondly, inputs are corrupted with noise during training – forcing the autoencoder to learn more general relationships from the data. Specifically, we adopt masking noise for our network, which simply corresponds to the application of dropout (Srivastava et al., 2014) to encoder inputs.
Hyperparameter optimisation is performed using 20 iterations of random search, with networks trained up to a maximum of 100 epochs per search iteration. Full hyperparameter details are listed in the appendix for reference.

FORECASTING SYSTEMIC RISK WITH THE ARR
To demonstrate the utility of the ARR, we consider applications in systemic risk forecasting – using ARRs computed from subsector indices to predict risk metrics associated with the overall market. Specifically, we consider two key use-cases for the ARR:
1) Volatility Forecasting – i.e. predicting market turbulence as measured by spikes in realised volatility;
2) Predicting Market Crashes – i.e. providing an early warning signal for sudden market declines, allowing for timely risk management.

Description of Dataset
We focus on the total US equity market and 11 constituent sector indices for our investigation, using high-frequency total returns sampled every second to compute the ARR. Intraday index data from the Center for Research in Security Prices (CRSP) was downloaded via Wharton Research Data Services (WRDS (2019)) from 2012-12-07 to 2019-03-29, with the full list of indices provided in Exhibit 1.

    Ticker   | Index Description
    CRSPTMT  | CRSP US Total Market Total-Return Index
    CRSPRET  | CRSP US REIT Total-Return Index
    CRSPENT  | CRSP US Oil and Gas Total-Return Index
    CRSPMTT  | CRSP US Materials Total-Return Index
    CRSPIDT  | CRSP US Industrials Total-Return Index
    CRSPCGT  | CRSP US Consumer Goods Total-Return Index
    CRSPHCT  | CRSP US Health Care Total-Return Index
    CRSPCST  | CRSP US Consumer Services Total-Return Index
    CRSPTET  | CRSP US Telecom Total-Return Index
    CRSPUTT  | CRSP US Utilities Total-Return Index
    CRSPFNT  | CRSP US Financials Total-Return Index
    CRSPITT  | CRSP US Technology Total-Return Index

Exhibit 1: US Market & Sector Indices Used in Tests

Reconstruction Performance
We compare the reconstruction accuracy of the deep sparse denoising autoencoder and standard PCA to evaluate the benefits of a non-linear approach to dimensionality reduction. To train the autoencoder, we divide the data into a training set used for network backpropagation (2012-2014), a validation set used for hyperparameter optimisation (2015 only), and an out-of-sample test set (2016-2019) used to quantify the information content of the ARR. PCA factors were calibrated using a covariance matrix estimated with data from 2012-2016, keeping the same out-of-sample data as the autoencoder. The returns vector is projected onto 2 latent variables, with the dimensionality of the latent space taken to be 1/5 the number of indices used – as per the original Absorption Ratio paper (Kritzman et al., 2011).
We evaluate the reconstruction accuracy using the test dataset, based on the R-squared (R²) of reconstructed returns for each prediction model. Given that the ARR is computed based on sector data alone, we only focus on the reconstruction accuracy of the 11 sector indices. We also test for the significance of the results with a bootstrap hypothesis test – using 500 bootstrap samples and adopting the null hypothesis that there is no difference in R² between the autoencoder and PCA reconstructions.

       | PCA   | Autoencoder | P-Value
    R² | 0.340 | 0.461*      | < 0.01

Exhibit 2: Out-of-Sample Reconstruction R² (2016-2019)

From the results in Exhibit 2, we can see that the autoencoder greatly enhances reconstruction accuracy – increasing the out-of-sample R² by more than 35%, with the improvement statistically significant at the 99% confidence level. These improvements highlight the benefits of adopting a non-linear approach to dimensionality reduction, and the suitability of the autoencoder for modelling intraday returns. Such an approach can have wider applications in various areas of quantitative finance.

Empirical Analysis
Next, we perform an exploratory investigation into the relationships between the total market index and the
ARR of its constituent sectors. Specifically, we examine the following metrics aggregated over different sampling frequencies (Δt ∈ {5-min, 1-hour, 1-day, 1-week}):
• Returns – corresponding to the difference in log prices based on Equation (1).
• Log Realised Volatility (Log RV) – computed based on the logarithm of the simple realised volatility estimator of Equation (7).
• Drawdowns (DD) – determined by the fractional decrease from the maximum index value from the start of our evaluation period.
ARRs were similarly computed for the 4 sampling frequencies, and aggregated based on Equation (6). Using the KDE plots of Exhibits 3 to 5, we perform an in-sample (2012-2015) empirical analysis of the coincident relationships between the ARR and the metrics above. To avoid spurious effects, we remove the eves of public holidays from our analysis – where markets are open only in the morning and trading volumes are abnormally low. In addition, we winsorise the data at the 1st and 99th percentiles, reducing the impact of outliers to improve data visualisation.
Looking at the KDE plots for returns against the ARR in Exhibit 3, we observe a noticeable increase in the dispersion of returns at low ARR values. This is echoed by the KDE plots for Log RV in Exhibit 4 – which appears to increase linearly with decreasing ARR. Drawdowns behave in a similar fashion as well, with larger drawdowns appearing to occur at low values of the ARR. Effects are also consistent across different horizons, with similar patterns observed for all sampling frequencies. On the whole, the results validate the findings observed in previous works – indicating that increased asset co-movements do indeed coincide with periods of market weakness – with lower ARR values observed around periods of high volatility or high drawdowns.
To visualise how the ARR changes over time, we plot the 5-min ARR (i.e. ARR(t, 5-min)) against the prices and drawdowns of the CRSP Total Market Index in Exhibit 6. Given the noisiness of the raw ARR values, we also include a smoothed version of the 5-min ARR in the bottom panel – computed based on an exponentially weighted moving average with a 1-day half-life. We can see from the results that dips in the ARR occur slightly before large drawdown periods, particularly around the sell-offs of August 2015 and early 2018. As such, the ARR could potentially be used to improve predictions of volatility spikes or sudden market crashes in the near future – which we further investigate in the next section.

Exhibit 3: In-Sample KDE Plots for Returns vs. ARR (panels: Δt = 5-min, 1-hour, 1-day, 1-week).
Exhibit 4: In-Sample KDE Plots for Log RV vs. ARR (panels: Δt = 5-min, 1-hour, 1-day, 1-week).
Exhibit 5: In-Sample KDE Plots for Drawdowns vs. ARR (panels: Δt = 5-min, 1-hour, 1-day, 1-week).
Exhibit 6: 5-min ARR from 2012 to 2019.

QUANTIFYING THE INFORMATION CONTENT OF THE ARR
We attempt to quantify the information content of the ARR by observing how much it improves forecasting models for the risk metrics above. To do so, we adopt several machine learning benchmarks for risk prediction, and evaluate the models' forecasting performance both with and without the ARR included in their inputs. This allows us to evaluate how much additional information is provided by the ARR, above that provided by the temporal evolution of realised volatility or drawdowns alone.

Benchmark Models
Given the non-linear relationships observed between returns and drawdowns versus the ARR, we adopt a series of machine learning benchmarks on top of linear models. Specifically, we consider 1) linear/logistic regression, 2) gradient boosted decision trees (GBDTs), and 3) simple multi-layered perceptrons (MLPs) for our forecasting applications. Hyperparameter optimisation is performed using up to 200 iterations of random search, with optimal hyperparameters selected using 3-fold cross-validation. Additional training details can also be found in the appendix.

Volatility Forecasting Methodology
We treat volatility forecasting as a regression problem, focusing on predicting Log RV over different horizons – using a variant of the HAR-RV model of Corsi (2009). For 5-min Log RV, a simple HAR-RV model can take the form below, incorporating log RV terms for longer horizons into the forecast:

    ν̄(t + 5-min, 5-min) = β₁ ν(t, 5-min) + β₂ ν(t, 1-hour) + β₃ ν(t, 1-day) + β₄ ν(t, 1-week),    (17)

where ν̄(t + 5-min, 5-min) is the 5-min log RV predicted over the next time-step, ν(t, Δt) is the log RV from t − Δt to t, and β(·) are linear coefficients.
We adopt a similar non-linear variant for our forecast at each prediction horizon δ. For tests with the ARR, this takes the form:

    ν̄(t + δ, δ) = g₁(ψ(t), ω(t)),    (18)

where δ ∈ Ψ, Ψ = {5-min, 1-hour, 1-day, 1-week}, ψ(t) = {ν(t, d) : d ∈ Ψ and d ≥ δ}, ω(t) = {ARR(t, d) : d ∈ Ψ and d ≥ δ}, and g₁(·) is a prediction model mapping inputs to log RV forecasts. The combination of both realised volatility and ARR values hence allows us to determine if the ARR supplies any information above that provided by the volatility time series alone.
For tests without the ARR, we continue to use Equation (18) but omit ARR values, i.e. ω(t) = ∅. We evaluate regression performance using the R² of forecasted log RV.

Crash Prediction Methodology
We take a binary classification approach to forecasting market crashes, using our benchmark models to predict the onset of a sharp drawdown. First, we define a z-score metric for returns γ(t, λ) as below:

    γ(t, λ) = (r(t) − m(t, λ)) / s(t, λ),    (19)

where r(t) is the return for the total market index, m(t, λ) is the exponentially weighted moving average of returns with a half-life of λ, and s(t, λ) its exponentially weighted moving standard deviation.
Next, based on our z-score metric, we define a market crash to be a sharp decline in market returns, i.e.:

    c(t) = I(γ(t, λ) < C),    (20)

where c(t) ∈ {0, 1} is a crash indicator, I(·) is an indicator function, and C is the z-score threshold for returns. For our experiments, we set λ to be 10 discrete time steps and C = −1.5.
We then model crash probabilities using a similar form to Equation (18):

    p(c(t) = 1) = g₂(ψ(t), ω(t)),    (21)

where g₂(·) is a function mapping inputs to crash probabilities p(c(t) = 1).
Given that crashes are rare by definition – with c(t) = 1 for less than 10% of time steps at daily frequencies – we also oversample the minority class to address the class imbalance problem. Classification performance is evaluated using the area under the receiver operating characteristic curve (AUROC).

Results and Discussion
The results for both volatility forecasting and market crash prediction can be found in Exhibit 7, both including and excluding ARR values. To determine the statistical significance of improvements, we conduct a bootstrap hypothesis test under the null hypothesis that performance results are better when the ARR is included, using a non-parametric bootstrap with 500 samples.
From the volatility forecasting R² values in Exhibit 7a, we can see that the ARR consistently improves 5-min and 1-hour forecasts for all non-linear models, and significant improvements are observed for linear models at the 5-min sampling frequency. This indicates that the ARR is informative for short-term forecasts, and can help enhance risk predictions in the near term. For longer horizons, however (i.e. 1-day and 1-week), we observe that the inclusion of the ARR reduces prediction accuracy – potentially indicating the presence of overfitting on the training set when ARRs are introduced.
From the crash prediction AUROC results in Exhibit 7b, we note that ARRs are observed to improve forecasts for all models and sampling frequencies – with statistical significance at the 99% level observed for both linear and GBDT forecasts over shorter horizons. This echoes the volatility forecasting results – indicating that ARRs can be useful to inform short-term risk predictions.

CONCLUSIONS
We introduce the Autoencoder Reconstruction Ratio (ARR) in this paper, using it as a real-time measure of asset co-movement. The ARR is based on the normalised reconstruction error of a deep sparse denoising autoencoder applied to a basket of asset returns – which condenses the returns vector onto a lower-dimensional set of latent variables. This replaces the PCA modelling approach used by the Absorption Ratio of Kritzman et al. (2011), allowing the ARR to better model returns that violate basic PCA assumptions (e.g. non-Gaussian returns). Through experiments on a basket of 11 CRSP US sector indices, we demonstrate that the autoencoder significantly improves out-of-sample reconstruction performance when compared to PCA, increasing the combined R² by more than 35%.
Given the links identified between increased asset co-movements and the fragility of the overall market in previous works (Kritzman et al., 2011; Campbell et al., 2002), we also evaluate the use of the ARR in a systemic risk application. First, we conduct an empirical analysis of the relationship between risk metrics of the combined market (using the CRSP US Total Market Index as a proxy) and the ARR computed from its sub-sectors. Based on an analysis of the KDE plots of risk metrics vs. ARRs, we show that low values of the ARR coincide with high-volatility and high-drawdown periods, in line with previous findings. Next, we evaluate the information content of the ARR by testing how much the ARR improves risk predictions for various benchmark models. We find that the ARR is informative for both volatility and market crash predictions over short horizons, and significantly increases 5-min and 1-hour forecasting performance across most model benchmarks.
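The z-score crash labelling of Equations (19) and (20) can be sketched with pandas exponentially weighted moving statistics. The function name and the synthetic return series are our own; the half-life λ = 10 and threshold C = −1.5 follow the values stated in the text.

```python
import numpy as np
import pandas as pd

def crash_labels(returns: pd.Series, halflife: int = 10, C: float = -1.5) -> pd.Series:
    """Label crashes per Equations (19)-(20): c(t) = 1 when the
    exponentially weighted z-score of returns falls below C."""
    m = returns.ewm(halflife=halflife).mean()   # m(t, lambda)
    s = returns.ewm(halflife=halflife).std()    # s(t, lambda)
    gamma = (returns - m) / s                   # Eq. (19): z-score of returns
    return (gamma < C).astype(int)              # Eq. (20): crash indicator

# Example on a synthetic return series with one sharp decline injected.
rng = np.random.default_rng(0)
returns = pd.Series(rng.normal(0.0, 0.01, 500))
returns.iloc[400] = -0.10                       # sharp one-step sell-off
labels = crash_labels(returns)
print(labels.iloc[400], round(labels.mean(), 3))
```

The resulting binary series is what the classifiers of Equation (21) are trained on, with the rare positive class oversampled as described above.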
(a) R² for Log RV Predictions

                      5-min   1-hour  1-day   1-week
    Linear  With ARR  0.635*  0.496   0.389   0.579
            No ARR    0.626   0.536*  0.431*  0.611*
            P-Values  <0.01   >0.99   >0.99   >0.99
    GBDT    With ARR  0.635*  0.554*  0.425   0.399
            No ARR    0.627   0.536   0.434   0.593*
            P-Values  <0.01   <0.01   0.821   >0.99
    MLP     With ARR  0.641*  0.571*  0.387   0.527
            No ARR    0.631   0.549   0.426*  0.636*
            P-Values  <0.01   <0.01   >0.99   >0.99

(b) AUROC for Crash Predictions

                      5-min   1-hour  1-day   1-week
    Linear  With ARR  0.598*  0.629*  0.585   0.372
            No ARR    0.587   0.570   0.582   0.333
            P-Values  <0.01   <0.01   0.464   0.284
    GBDT    With ARR  0.590*  0.562*  0.529   0.605
            No ARR    0.567   0.516   0.465   0.575
            P-Values  <0.01   0.01    0.138   0.359
    MLP     With ARR  0.589   0.564   0.586   0.423
            No ARR    0.588   0.558   0.540   0.210
            P-Values  0.374   0.392   0.234   <0.01

Exhibit 7: Out-of-Sample Forecasting Results With and Without ARR Inputs.
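The with-/without-ARR comparison underlying Exhibit 7a can be sketched as below. The data here are synthetic stand-ins for the ψ(t) and ω(t) inputs of Equation (18), constructed so that the ARR features genuinely carry information; the helper name is our own and a plain least-squares fit stands in for the tuned benchmark models.

```python
import numpy as np

def fit_and_score(X_tr, y_tr, X_te, y_te):
    """Least-squares fit of a linear g(.) from Equation (18); returns test R^2."""
    A_tr = np.column_stack([X_tr, np.ones(len(X_tr))])   # add intercept column
    coef, *_ = np.linalg.lstsq(A_tr, y_tr, rcond=None)
    pred = np.column_stack([X_te, np.ones(len(X_te))]) @ coef
    return 1.0 - ((y_te - pred) ** 2).sum() / ((y_te - y_te.mean()) ** 2).sum()

# Synthetic psi(t) (lagged log RV per horizon in Psi) and omega(t) (matching ARRs).
rng = np.random.default_rng(0)
T = 2000
psi = rng.normal(size=(T, 4))
omega = rng.normal(size=(T, 4))
y = psi @ np.array([0.5, 0.2, 0.1, 0.05]) + 0.3 * omega[:, 0] + 0.1 * rng.normal(size=T)

split = T // 2
X_full = np.hstack([psi, omega])                         # inputs with ARR
r2_with = fit_and_score(X_full[:split], y[:split], X_full[split:], y[split:])
r2_without = fit_and_score(psi[:split], y[:split], psi[split:], y[split:])
print(f"R^2 with ARR: {r2_with:.3f}, without ARR: {r2_without:.3f}")
```

The gap between the two scores plays the role of the ARR's incremental information content; on real data its significance would be assessed with the 500-sample bootstrap described in the text.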

REFERENCES

Abadi, Martín, et al. 2015. TensorFlow: Large-Scale Machine Learning on Heterogeneous Systems. Software available from tensorflow.org.

Aït-Sahalia, Yacine, & Xiu, Dacheng. 2016. "Increased correlation among asset classes: Are volatility or jumps to blame, or both?". Journal of Econometrics, 194(2), 205–219.

Barndorff-Nielsen, Ole E., & Shephard, Neil. 2002. "Estimating quadratic variation using realized variance". Journal of Applied Econometrics, 17(5), 457–477.

Billio, Monica, Getmansky, Mila, Lo, Andrew W., & Pelizzon, Loriana. 2010. "Measuring systemic risk in the finance and insurance sectors". MIT Sloan Research Paper No. 4774-10. http://ssrn.com/abstract=1571277.

Campbell, Rachel, Koedijk, Kees, & Kofman, Paul. 2002. "Increased correlation in bear markets". Financial Analysts Journal, 58(1), 87–94.

Cappiello, Lorenzo, Engle, Robert F., & Sheppard, Kevin. 2006. "Asymmetric Dynamics in the Correlations of Global Equity and Bond Returns". Journal of Financial Econometrics, 4(4), 537–572.

Cho, Kyung Hyun. 2013. "Simple sparsification improves sparse denoising autoencoders in denoising highly noisy images". In: Proceedings of the 30th International Conference on Machine Learning. ICML 2013.

Clevert, Djork-Arné, Unterthiner, Thomas, & Hochreiter, Sepp. 2016. "Fast and accurate deep network learning by exponential linear units (ELUs)". In: International Conference on Learning Representations. ICLR 2016.

Corsi, Fulvio. 2009. "A Simple Approximate Long-Memory Model of Realized Volatility". Journal of Financial Econometrics, 7(2), 174–196.

Dungey, Mardi, Erdemlioglu, Deniz, Matei, Marius, & Yang, Xiye. 2018. "Testing for mutually exciting jumps and financial flights in high frequency data". Journal of Econometrics, 202(1), 18–44.

Fabozzi, Frank, Giacometti, Rosella, & Tsuchida, Naoshi. 2015. The ICA-based Factor Decomposition of the Eurozone Sovereign CDS Spreads. IMES Discussion Paper Series 15-E-04. Institute for Monetary and Economic Studies, Bank of Japan.

Freiman, Moti, Manjeshwar, Ravindra, & Goshen, Liran. 2019. "Unsupervised abnormality detection through mixed structure regularization (MSR) in deep sparse autoencoders". Medical Physics, 46(5), 2223–2231.

Goodfellow, Ian, Bengio, Yoshua, & Courville, Aaron. 2016. "Autoencoders". Chap. 14 of: Deep Learning. MIT Press. http://www.deeplearningbook.org.

Gu, Shihao, Kelly, Bryan T., & Xiu, Dacheng. 2019. "Autoencoder asset pricing models". Yale ICF Working Paper No. 2019-04; Chicago Booth Research Paper No. 19-24. https://ssrn.com/abstract=3335536.

Jondeau, Eric, Poon, Ser-Huang, & Rockinger, Michael. 2007. Financial Modeling Under Non-Gaussian Distributions. Springer Finance. Springer.

Ke, Guolin, et al. 2017. "LightGBM: A highly efficient gradient boosting decision tree". Pages 3149–3157 of: Proceedings of the 31st International Conference on Neural Information Processing Systems. NIPS'17. Red Hook, NY, USA: Curran Associates Inc.

Kondratyev, Alexei. 2018. "Learning curve dynamics with artificial neural networks". SSRN. https://ssrn.com/abstract=3041232.

Kritzman, Mark, Li, Yuanzhen, Page, Sébastien, & Rigobon, Roberto. 2011. "Principal components as a measure of systemic risk". The Journal of Portfolio Management, 37(4), 112–126.

Liu, Yan, & Zhang, Yi. 2018. "Low-dose CT restoration via stacked sparse denoising autoencoders". Neurocomputing, 284, 80–89.

Loretan, Mico, & English, William B. 2000. "Evaluating correlation breakdowns during periods of market volatility". Board of Governors of the Federal Reserve System International Finance Working Paper.

Noureldin, Diaa, Shephard, Neil, & Sheppard, Kevin. 2012. "Multivariate high-frequency-based volatility (HEAVY) models". Journal of Applied Econometrics, 27(6), 907–933.

Packham, N., & Woebbeking, C. F. 2019. "A factor-model approach for correlation scenarios and correlation stress testing". Journal of Banking and Finance, 101, 92–103.

Pedregosa, F., et al. 2011. "Scikit-learn: Machine Learning in Python". Journal of Machine Learning Research, 12, 2825–2830.

Preis, T., Kenett, D. Y., Stanley, H. E., Helbing, D., & Ben-Jacob, E. 2012. "Quantifying the behavior of stock correlations under market stress". Scientific Reports, 2, 752.

Rezek, I. A., & Roberts, S. J. 1998. "Stochastic complexity measures for physiological signal analysis". IEEE Transactions on Biomedical Engineering, 45(9), 1186–1191.

Shah, Nauman, & Roberts, Stephen. 2013. "Dynamically measuring statistical dependencies in multivariate financial time series using independent component analysis". International Scholarly Research Notices.

Srivastava, Nitish, Hinton, Geoffrey, Krizhevsky, Alex, Sutskever, Ilya, & Salakhutdinov, Ruslan. 2014. "Dropout: A simple way to prevent neural networks from overfitting". Journal of Machine Learning Research, 15, 1929–1958.

WRDS. 2019. The Center for Research in Security Prices (CRSP) Index History – Intraday. https://wrds-www.wharton.upenn.edu/.

Zheng, Zeyu, Podobnik, Boris, Feng, Ling, & Li, Baowen. 2012. "Changes in cross-correlations as an indicator for systemic risk". Scientific Reports, 2, 888.
APPENDIX

A. Additional Training Details

Python Libraries: Deep sparse denoising autoencoders are defined and trained using TensorFlow (Abadi et al., 2015). For gradient boosted decision trees, we use the LightGBM library (Ke et al., 2017) – using the standard LGBMRegressor and LGBMClassifier depending on the forecasting problem. The remainder of the models are implemented using standard scikit-learn classes (Pedregosa et al., 2011) – with classes described in the hyperparameter optimisation section.

Hyperparameter Optimisation Details: Random search is conducted by sampling over a discrete set of values for each hyperparameter, which are listed below. For ease of reference, hyperparameters for all scikit-learn and LightGBM classes are referred to by the default argument names used in their respective libraries.

Deep Sparse Denoising Autoencoder
• Dropout Rate – [0.0, 0.2, 0.4, 0.6, 0.8]
• Regularisation Weight α – [0.0, 0.01, 0.1, 1.0, 10]
• Minibatch Size – [256, 512, 1024, 2048]
• Learning Rate – [10⁻⁵, 10⁻⁴, 10⁻³, 10⁻², 10⁻¹, 1.0]
• Max. Gradient Norm – [10⁻⁴, 10⁻³, 10⁻², 10⁻¹, 1.0, 10.0]

Linear Regression
• Package Name – sklearn.linear_model
• Class Name – Ridge
• ’alpha’ – [10⁻⁵, 10⁻⁴, 10⁻³, 10⁻², 10⁻¹, 1, 10, 10²]
• ’fit intercept’ – [False, True]

Logistic Regression
• Package Name – sklearn.linear_model
• Class Name – LogisticRegression
• ’penalty’ – [’l1’]
• ’C’ – [0.01, 0.1, 1.0, 10, 100]
• ’fit intercept’ – [False]
• ’solver’ – [’liblinear’]

Gradient Boosted Decision Tree
• Package Name – lightgbm
• Class Name – LGBMRegressor or LGBMClassifier
• ’learning rate’ – [10⁻⁴, 10⁻³, 10⁻², 10⁻¹]
• ’n estimators’ – [5, 10, 20, 40, 80, 160, 320]
• ’num leaves’ – [5, 10, 20, 40, 80]
• ’n jobs’ – [5]
• ’reg alpha’ – [0, 10⁻⁴, 10⁻³, 10⁻², 10⁻¹]
• ’reg lambda’ – [0, 10⁻⁴, 10⁻³, 10⁻², 10⁻¹]
• ’boosting type’ – [’gbdt’]

Multi-layer Perceptron
• Package Name – sklearn.neural_network
• Class Name – MLPRegressor or MLPClassifier
• ’hidden layer sizes’ – [5, 10, 20, 40, 80, 160]
• ’activation’ – [’relu’]
• ’alpha’ – [0, 10⁻⁴, 10⁻³, 10⁻², 10⁻¹, 1, 10, 10²]
• ’learning rate init’ – [10⁻⁴, 10⁻³, 10⁻², 10⁻¹]
• ’early stopping’ – [True]
• ’max iter’ – [500]
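The random-search procedure described in this appendix can be sketched as follows. The grid shown is the autoencoder grid listed above; the evaluate callback, which would train a network and return its validation loss, is left as a user-supplied stand-in, and all names here are our own.

```python
import random

# Discrete grids from the appendix (deep sparse denoising autoencoder).
GRID = {
    "dropout_rate": [0.0, 0.2, 0.4, 0.6, 0.8],
    "l1_weight": [0.0, 0.01, 0.1, 1.0, 10],
    "minibatch_size": [256, 512, 1024, 2048],
    "learning_rate": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0],
    "max_gradient_norm": [1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0],
}

def random_search(evaluate, grid, n_iter=20, seed=0):
    """Sample each hyperparameter uniformly from its discrete grid and
    keep the configuration with the lowest validation loss."""
    rng = random.Random(seed)
    best_cfg, best_loss = None, float("inf")
    for _ in range(n_iter):
        cfg = {name: rng.choice(values) for name, values in grid.items()}
        loss = evaluate(cfg)
        if loss < best_loss:
            best_cfg, best_loss = cfg, loss
    return best_cfg, best_loss

# Example with a stand-in objective (a real run would train the autoencoder
# for up to 100 epochs and return its validation loss).
def dummy_loss(cfg):
    return cfg["dropout_rate"] + cfg["learning_rate"]

best_cfg, best_loss = random_search(dummy_loss, GRID, n_iter=20)
print(best_cfg, best_loss)
```

The same loop applies to the scikit-learn and LightGBM grids, with the callback wrapping 3-fold cross-validation as described in the main text.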
Chapter 7

Discussion & Conclusions

In this thesis, we introduce a collection of novel methods to improve the performance


of deep learning models for time series forecasting, and propose extensions to facil-
itate decision making over time. While model design can vary from application to
application – due to the heterogeneity of forecasting problems in different domains
– our central theme is in the closer alignment of models with traditional time se-
ries models. This allows us to customise existing building blocks in deep learning to
better capture the nuances of complex temporal processes, leading to demonstrable
improvements in tests on a wide variety of datasets.

Contributions in Time Series Forecasting: With the growing size and complex-
ity of time series datasets, data-driven deep learning approaches have increasingly
outperformed parametric modelling techniques. In many forecasting applications,
however, neural network architectures are often taken directly from other domains
(e.g. LSTMs and CNNs) and applied without modification, neglecting to account for the
unique characteristics of time series datasets. We hence aim to improve deep learn-
ing models by drawing inspiration from traditional time series modelling – designing
networks with the appropriate inductive biases and outputs for time series data.
For general time series problems, we propose two new architectures for one-step-ahead and multi-horizon forecasting in Chapter 3. With the Recurrent Neural Filter
(RNF), we propose a new architecture for one-step-ahead prediction, comprising a
series of encoders and decoders that are aligned with the Bayesian filtering steps.
We also show that each stage of the RNF can be separated at run-time, due to
the skip training approach and multi-task loss function used to encourage decoupled
representations. This allows the RNF to be applied using data only when available –
helping the RNF to handle datasets with missing data, and to improve performance
when applied in an autoregressive fashion for multi-horizon prediction. To better

accommodate the full range of inputs present in complex multi-horizon forecasting
scenarios, such as static metadata, we also develop a new attention-based model
in Temporal Fusion Transformers (TFT). From a forecasting perspective, the TFT
uses both recurrent and attention layers for local and long-term temporal processing,
and contains a series of gated residual networks to allow the network to suppress
unnecessary components for a given dataset. This allows the TFT to perform well
on a wide range of data regimes, outperforming state-of-the-art methods for multi-
horizon prediction.
To further incorporate domain expertise for specific problems, we examine a range
of hybrid deep learning models in Chapter 4 – focusing on an application in finance and
another in medicine. Hybrid models enhance traditional methods using deep learning
by parameterising well-studied quantitative models for a given domain with neural
network outputs. With Deep Momentum Networks (DMNs), we enhance standard
time series momentum signals using deep learning-based trading rules, demonstrat-
ing improved performance over traditional and machine learning-based systematic
signals. In addition, we augment joint models for longitudinal and time-to-event data
in biostatistics with the Disease-Atlas, demonstrating improvements in predicting
multiple clinical outcomes in Cystic Fibrosis patients.
Finally, we analyse the use of deep neural networks as feature extraction mech-
anisms in Chapter 6, feeding features produced by autoencoders to improve the ac-
curacy of generic forecasting models. Considering a financial use-case in systemic
risk prediction, we develop the Autoencoder Reconstruction Ratio (ARR) – which is
based on a measure of the reconstruction error for high-frequency returns. The ARR
computes the amount of information captured by a low dimensional set of non-linear
latent variables, allowing us to quantify any changes in asset co-movement over time.
Through tests on index data, we demonstrate that the autoencoder approach provides
a better model for high-frequency returns, and that the ARR can be used to enhance
short-term risk predictions.

Contributions in Decision Support Over Time: While model users do rely


on forecasts as one source of information, they may desire a deeper understanding
of their dataset to guide their actions, e.g. through counterfactual simulations or
the identification of key input features driving predictions. However, deep learning
models can suffer several limitations in this respect. Firstly, deep neural networks
are designed to capture complex correlations, without distinguishing between cause
and effect or accounting for confounding factors. As a result, slight shifts in the

input distribution can severely affect model performance even if causal mechanisms
remain unchanged – limiting the use of standard models for counterfactual predictions
[13]. Furthermore, the black-box nature of standard deep neural networks can make
it difficult to interpret the relationships learnt by the model, or to identify which
features are important for forecasts.
As such, on top of innovations in time series forecasting, we also explore two
extensions to deep neural networks to better facilitate decision making over time.
We start by presenting Recurrent Marginal Structural Networks (RMSNs) in Chapter 5, which
provide a framework for training deep neural networks to learn causal effects from
observational data. Based on the inverse probability of treatment weighting approach
of marginal structural models in epidemiology, RMSNs use deep neural networks to
learn probabilities of treatment assignment and censoring. The probabilities are then
used to generate weights to loss functions for a prediction network, adjusting neural
network training to account for time-dependent confounding effects. Through tests
with a clinically realistic simulation model, we show that RMSNs outperform state-of-
the-art benchmarks in learning unbiased treatment effects. Moreover, with the TFT,
we demonstrate how attention weights can be analysed to provide general insights into
the temporal relationships present in the dataset. This is showcased through three
interpretability use-cases – 1) analysing feature importance, 2) visualising persistent
temporal patterns, and 3) identifying significant regimes and events.
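The inverse-probability-of-treatment re-weighting step at the heart of RMSNs can be sketched as follows. This is a simplified single-binary-treatment version: the function name is illustrative, the propensities would come from the trained treatment-assignment network, and the stabilized form shown here omits the censoring weights used in the full model.

```python
import numpy as np

def iptw_weights(treatment, propensity, eps=1e-6):
    """Sketch of stabilized inverse-probability-of-treatment weights.

    treatment:  binary treatment assignments over time (one trajectory).
    propensity: model-estimated P(treatment_t = 1 | history), e.g. from
                a recurrent network trained to predict assignment.
    """
    p = np.clip(propensity, eps, 1 - eps)
    # Probability the model assigns to the treatment actually received.
    prob_assigned = np.where(treatment == 1, p, 1 - p)
    # Stabilizing numerator: marginal probability of the received treatment.
    marginal = np.where(treatment == 1, treatment.mean(), 1 - treatment.mean())
    # Cumulative product over time yields per-step loss weights.
    return np.cumprod(marginal / prob_assigned)
```

These weights would multiply the per-step prediction loss, up-weighting trajectories that received treatments deemed unlikely given their history, so that the weighted population behaves more like a randomized one and time-dependent confounding is adjusted for.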

7.1 Extensions and Future Work


While large strides have been made in the field of deep learning for time series
prediction, the greatest impact has been on univariate time series sampled at discrete
intervals – which are most suited for current neural network designs. In this section,
we propose two directions to extend our research to handle more complex time series
forecasting problems.

Continuous-Time Models: As universal function approximators [9], neural networks
have traditionally been used to model fixed input-output relationships – making
them inherently suited to handle discrete data. However, forecasting in continuous-
time can be a more suitable approach in many applications, particularly in the case
of datasets with irregularly sampled observations and high-frequency streaming data.
A recent promising direction in deep learning research has been in Neural Ordinary
Differential Equation (ODE) and Stochastic Differential Equation (SDE) models
[3, 20, 12, 10, 5] – which model the deterministic or stochastic evolution of hidden
states, and naturally handle continuous time forecasts. However, improvements have
mainly been demonstrated using simulated data, which make it difficult to assess the
performance of Neural ODEs/SDEs on real-world time series datasets. In addition
to benchmarking and extensions to state-of-the-art time series architectures, we also
recommend investigations into hybrid continuous-time models – which would allow
for direct comparisons to traditional methods in applications where they are most
beneficial.
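The core mechanism of Neural ODEs can be illustrated with a minimal sketch: a learned function defines the hidden-state dynamics, and a numerical solver evaluates the state at arbitrary – possibly irregular – times. The fixed-step Euler solver and the single tanh layer below are simplifications of the adaptive solvers and trained networks used in practice, with illustrative names throughout.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(4, 4))  # toy stand-in for learned weights

def f(h, t):
    """Hidden-state dynamics dh/dt = f(h, t); here a single tanh layer."""
    return np.tanh(h @ W)

def odeint_euler(h0, t_grid):
    """Fixed-step Euler integration standing in for an adaptive ODE solver.
    Irregularly sampled observation times simply become irregular steps."""
    h, traj = h0, [h0]
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        h = h + (t1 - t0) * f(h, t0)
        traj.append(h)
    return np.stack(traj)
```

Because the state is defined at every point in continuous time, forecasts at arbitrary horizons follow by integrating to the desired timestamp rather than stepping through a fixed discrete grid.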

Multivariate Time Series: The majority of previous work (see Section 2) has
focused on the development of univariate forecasting models – i.e. assuming that
targets are driven only by inputs for a given entity, and are independent of each other.
In this manner, deep neural networks are trained to capture temporal relationships
that are generally applicable across all entities, without taking into account cross-
sectional relationships between them. While this allows networks to be trained with
a larger pool of data – as mini-batches are sampled across entities and time – there are
instances where multivariate forecasts can be beneficial, particularly in non-stationary
datasets. For instance, in portfolio management, market breakdowns can cause a
general decline of all stocks and result in short-term increases in asset correlations. In
retail forecasting, unforeseen events, such as natural disasters, can also lead to spikes
in the joint demand for specific categories of goods. As such, multivariate models
can help improve performance in datasets where common driving factors exist across
entities – motivating the need for extensions to existing univariate architectures.

Bibliography

[1] Yoshua Bengio, Aaron Courville, and Pascal Vincent. Representation Learning:
A Review and New Perspectives. arXiv e-prints, page arXiv:1206.5538, Jun 2012.

[2] James V. Candy. Bayesian Signal Processing: Classical, Modern and Particle
Filtering Methods. Wiley-Interscience, New York, NY, USA, 2009.

[3] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, and David K Duvenaud.
Neural ordinary differential equations. In Advances in Neural Information Pro-
cessing Systems (NeurIPS). 2018.

[4] F.X. Diebold. Time-Series Econometrics. Department of Economics, University
of Pennsylvania, 2018.

[5] Lea Duncker, Gergo Bohner, Julien Boussard, and Maneesh Sahani. Learning
interpretable continuous-time models of latent stochastic dynamical systems. In
International Conference on Machine Learning (ICML), 2019.

[6] Chenyou Fan, Yuze Zhang, Yi Pan, Xiaoyue Li, Chi Zhang, Rong Yuan, Di Wu,
Wensheng Wang, Jian Pei, and Heng Huang. Multi-horizon time series forecast-
ing with temporal attention learning. In Proceedings of the 25th ACM SIGKDD
International Conference on Knowledge Discovery & Data Mining, KDD ’19,
2019.

[7] Gonzalo Farias, Sebastián Dormido-Canto, Jesús Vega, Giuseppe Rattá, Héctor
Vargas, Gabriel Hermosilla, Luis Alfaro, and Agustı́n Valencia. Automatic fea-
ture extraction in large fusion databases by using deep learning approach. Fusion
Engineering and Design, 112:979 – 983, 2016.

[8] Shihao Gu, Bryan T. Kelly, and Dacheng Xiu. Empirical asset pricing via ma-
chine learning. Chicago Booth Research Paper No. 18-04; 31st Australasian Fi-
nance and Banking Conference 2018, 2017.

[9] Kurt Hornik, Maxwell Stinchcombe, and Halbert White. Multilayer feedforward
networks are universal approximators. Neural Networks, 2(5):359 – 366, 1989.

[10] Junteng Jia and Austin R Benson. Neural jump stochastic differential equations.
In Advances in Neural Information Processing Systems (NeurIPS). 2019.

[11] J. F. Kolen and S. C. Kremer. Gradient Flow in Recurrent Nets: The Difficulty
of Learning Long Term Dependencies. IEEE, 2001.

[12] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, and David K. Du-
venaud. Scalable gradients and variational inference for stochastic differential
equations. In Proceedings of The 2nd Symposium on Advances in Approximate
Bayesian Inference, 2020.

[13] Bryan Lim, Ahmed Alaa, and Mihaela van der Schaar. Forecasting treatment
responses over time using Recurrent Marginal Structural Networks. In NeurIPS,
2018.

[14] Bryan Lim, Sercan O. Arik, Nicolas Loeff, and Tomas Pfister. Temporal Fusion
Transformers for Interpretable Multi-horizon Time Series Forecasting. arXiv
e-prints, page arXiv:1912.09363, 2019.

[15] Bryan Lim and Mihaela van der Schaar. Disease-atlas: Navigating disease trajec-
tories using deep learning. In Proceedings of the Machine Learning for Healthcare
Conference (MLHC), 2018.

[16] Bryan Lim, Stefan Zohren, and Stephen Roberts. Enhancing time-series mo-
mentum strategies using deep neural networks. The Journal of Financial Data
Science, 2019.

[17] Bryan Lim, Stefan Zohren, and Stephen Roberts. Detecting Changes in Asset
Co-Movement Using the Autoencoder Reconstruction Ratio. arXiv e-prints, page
arXiv:2002.02008, January 2020.

[18] Bryan Lim, Stefan Zohren, and Stephen Roberts. Recurrent Neural Filters:
Learning independent bayesian filtering steps for time series prediction. In In-
ternational Joint Conference on Neural Networks (IJCNN). 2020.

[19] Spyros Makridakis, Evangelos Spiliotis, and Vassilios Assimakopoulos. Statistical
and machine learning forecasting methods: Concerns and ways forward. PLOS
ONE, 13(3):1–26, 03 2018.

[20] Stefano Massaroli, Michael Poli, Jinkyoo Park, Atsushi Yamashita, and Hajime
Asama. Dissecting Neural ODEs. arXiv e-prints, page arXiv:2002.08071, 2020.

[21] Diaa Noureldin, Neil Shephard, and Kevin Sheppard. Multivariate high-
frequency-based volatility (heavy) models. Journal of Applied Econometrics,
27(6):907–933, 2012.

[22] Syama Sundar Rangapuram, Matthias W Seeger, Jan Gasthaus, Lorenzo Stella,
Yuyang Wang, and Tim Januschowski. Deep state space models for time series
forecasting. In Advances in Neural Information Processing Systems (NIPS), 2018.

[23] Ali Sharif Razavian, Hossein Azizpour, Josephine Sullivan, and Stefan Carlsson.
Cnn features off-the-shelf: An astounding baseline for recognition. In Proceed-
ings of the 2014 IEEE Conference on Computer Vision and Pattern Recognition
Workshops, CVPRW ’14, 2014.

[24] David Salinas, Valentin Flunkert, and Jan Gasthaus. DeepAR: Probabilis-
tic Forecasting with Autoregressive Recurrent Networks. arXiv e-prints, page
arXiv:1704.04110, 2017.

[25] Neil Shephard and Kevin Sheppard. Realising the future: forecasting with high-
frequency-based volatility (heavy) models. Journal of Applied Econometrics,
25(2):197–231, 2010.

[26] Kevin Sheppard and Wen Xu. Factor high-frequency based volatility (heavy)
models. SSRN, 2014.

[27] David Silver, Julian Schrittwieser, Karen Simonyan, Ioannis Antonoglou, Aja
Huang, Arthur Guez, Thomas Hubert, Lucas Baker, Matthew Lai, Adrian
Bolton, Yutian Chen, Timothy Lillicrap, Fan Hui, Laurent Sifre, George van den
Driessche, Thore Graepel, and Demis Hassabis. Mastering the game of Go with-
out human knowledge. Nature, 550:354–359, 2017.

[28] Aäron van den Oord, Sander Dieleman, Heiga Zen, Karen Simonyan, Oriol
Vinyals, Alex Graves, Nal Kalchbrenner, Andrew W. Senior, and Koray
Kavukcuoglu. WaveNet: A generative model for raw audio. CoRR,
abs/1609.03499, 2016.

[29] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need.
In Advances in Neural Information Processing Systems 30. 2017.

[30] Z. Lai and H. Deng. Medical image classification based on deep features extracted
by deep model and statistic feature fusion with multilayer perceptron. Computational
Intelligence and Neuroscience, Sep 2018.
