RNN SageMaker Algorithm

The Time Series Forecasting (RNN) Algorithm from AWS Marketplace performs time series forecasting with a multi-layer Recurrent Neural Network (RNN). It implements both training and inference from CSV data and supports both CPU and GPU instances. The training and inference Docker images were built by extending the PyTorch 2.1.0 Python 3.10 SageMaker containers. The algorithm can be used for both univariate and multivariate time series and supports the inclusion of external features.

Model Description

The model is a stack of RNN layers with either LSTM or GRU cells. Each RNN layer is followed by an activation layer and a dropout layer. The model is trained by minimizing the negative Gaussian log-likelihood and outputs the predicted mean and standard deviation at each future time step.

LSTM and GRU cells (source: doi:10.48550/arXiv.1412.3555)
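
The snippet below is an illustrative sketch of this architecture only, not the algorithm's actual source code: it stacks LSTM or GRU layers, applies an activation and dropout after each layer, and ends with a linear head producing a mean and standard deviation for each target at every time step. All names, defaults and layer choices are assumptions.

```python
import torch.nn as nn

class RNNForecaster(nn.Module):
    """Illustrative stack of RNN layers with activation, dropout and a Gaussian head."""

    def __init__(self, n_inputs, n_targets, hidden_sizes=(128, 64, 32),
                 cell_type="LSTM", dropout=0.1):
        super().__init__()
        cell = nn.LSTM if cell_type == "LSTM" else nn.GRU
        layers, in_size = [], n_inputs
        for h in hidden_sizes:
            layers.append(cell(in_size, h, batch_first=True))
            in_size = h
        self.rnns = nn.ModuleList(layers)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.head = nn.Linear(in_size, 2 * n_targets)  # mean and log-std per target

    def forward(self, x):
        # x: (batch, time, n_inputs)
        for rnn in self.rnns:
            x, _ = rnn(x)
            x = self.dropout(self.activation(x))
        mean, log_std = self.head(x).chunk(2, dim=-1)
        return mean, log_std.exp()

# Training would minimize the negative Gaussian log-likelihood, e.g. with
# torch.nn.GaussianNLLLoss(), passing the predicted variance (std ** 2).
```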

SageMaker Algorithm Description

Training

The training algorithm has two input data channels: training and validation. The training channel is mandatory, while the validation channel is optional.

The training and validation datasets should be provided as CSV files. Each column represents a time series, while each row represents a time step. All the time series should have the same length and should not contain missing values. The column headers should be formatted as follows:

  • The column names of the (mandatory) target time series should start with "y" (e.g. "y1", "y2", ...).
  • The column names of the (optional) feature time series should start with "x" (e.g. "x1", "x2", ...).

If the features are not provided, the algorithm uses only the past values of the target time series as input. The time series are scaled internally by the algorithm; there is no need to scale them beforehand. A sketch of a CSV file in this format is shown below.
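
For illustration, the following sketch (file name, series length and values are arbitrary) builds a CSV file in the expected format, with two target series and one external feature:

```python
import numpy as np
import pandas as pd

n_steps = 500
rng = np.random.default_rng(42)

df = pd.DataFrame({
    "y1": rng.normal(size=n_steps).cumsum(),   # first target series
    "y2": rng.normal(size=n_steps).cumsum(),   # second target series
    "x1": np.sin(np.arange(n_steps) / 24.0),   # optional external feature
})

# All series have the same length and contain no missing values.
df.to_csv("train.csv", index=False)
```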

See the sample input files train.csv and valid.csv. See notebook.ipynb for an example of how to launch a training job.
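
For reference, a minimal sketch of a training job launched with the SageMaker Python SDK is shown below; the algorithm ARN, IAM role, S3 paths, instance type and hyperparameter values are placeholders, and notebook.ipynb remains the authoritative example.

```python
import sagemaker
from sagemaker.algorithm import AlgorithmEstimator

estimator = AlgorithmEstimator(
    algorithm_arn="arn:aws:sagemaker:<region>:<account>:algorithm/<algorithm-name>",
    role="<sagemaker-execution-role>",
    instance_count=1,
    instance_type="ml.m5.2xlarge",
    sagemaker_session=sagemaker.Session(),
    hyperparameters={
        "context-length": 200,
        "prediction-length": 100,
        "cell-type": "LSTM",
    },
)

# The training channel is mandatory, the validation channel is optional.
estimator.fit({
    "training": "s3://<bucket>/train.csv",
    "validation": "s3://<bucket>/valid.csv",
})
```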

Distributed Training

The algorithm supports multi-GPU training on a single instance, which is implemented through torch.nn.DataParallel. The algorithm does not support multi-node (or distributed) training across multiple instances.
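
For context only, this is the usual torch.nn.DataParallel pattern (not the algorithm's actual training code): the model is replicated across the GPUs visible on a single instance and each batch is split among them.

```python
import torch

# Placeholder model; the algorithm's real model is built from its hyperparameters.
model = torch.nn.GRU(input_size=3, hidden_size=64, num_layers=2, batch_first=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)  # split each batch across the visible GPUs
model = model.to(device)
```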

Incremental Training

The algorithm supports incremental training. The model artifacts generated by a previous training job can be used to continue training the model on the same dataset or to fine-tune the model on a different dataset.
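
One way to set this up with the SageMaker Python SDK is sketched below; the use of the model_uri argument (which mounts the previous artifacts as a model channel) and all paths are assumptions, so check notebook.ipynb for the exact incremental-training setup.

```python
from sagemaker.algorithm import AlgorithmEstimator

estimator = AlgorithmEstimator(
    algorithm_arn="arn:aws:sagemaker:<region>:<account>:algorithm/<algorithm-name>",
    role="<sagemaker-execution-role>",
    instance_count=1,
    instance_type="ml.m5.2xlarge",
    # Artifacts from a previous training job (assumed to be consumed as a model channel).
    model_uri="s3://<bucket>/<previous-training-job>/output/model.tar.gz",
)

# Continue training on the same dataset, or point to a different one to fine-tune.
estimator.fit({"training": "s3://<bucket>/train.csv"})
```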

Hyperparameters

The training algorithm takes as input the following hyperparameters (an illustrative configuration is sketched after the list):

  • context-length: int. The length of the input sequences.
  • prediction-length: int. The length of the output sequences.
  • sequence-stride: int. The period between consecutive output sequences.
  • cell-type: str. The type of RNN cell used by each layer.
  • hidden-size-1: int. The number of hidden units of the first layer.
  • hidden-size-2: int. The number of hidden units of the second layer.
  • hidden-size-3: int. The number of hidden units of the third layer.
  • activation: str. The activation function applied after each layer, either "silu", "relu", "tanh", "gelu", or "lecun".
  • dropout: float. The dropout rate applied after each layer.
  • lr: float. The learning rate used for training.
  • lr-decay: float. The decay factor applied to the learning rate.
  • batch-size: int. The batch size used for training.
  • epochs: int. The number of training epochs.
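
The configuration below is purely illustrative (the values are not tuned recommendations, and the exact spelling of the cell-type values is an assumption); it is the dictionary that would be passed to the estimator's hyperparameters argument.

```python
hyperparameters = {
    "context-length": 200,      # length of the input sequences
    "prediction-length": 100,   # length of the output sequences
    "sequence-stride": 1,       # period between consecutive output sequences
    "cell-type": "LSTM",        # or "GRU"
    "hidden-size-1": 128,
    "hidden-size-2": 64,
    "hidden-size-3": 32,
    "activation": "relu",
    "dropout": 0.1,
    "lr": 0.001,
    "lr-decay": 0.99,
    "batch-size": 32,
    "epochs": 200,
}
```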

Metrics

The training algorithm logs the following metrics:

  • train_mse: float. Training mean squared error.
  • train_mae: float. Training mean absolute error.

If the validation channel is provided, the training algorithm also logs the following additional metrics:

  • valid_mse: float. Validation mean squared error.
  • valid_mae: float. Validation mean absolute error.

See notebook.ipynb for an example of how to launch a hyperparameter tuning job.
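
A minimal sketch of a tuning job is shown below, assuming the estimator defined earlier; the objective metric name, parameter ranges and job counts are placeholders and must match the algorithm's tuning specification.

```python
from sagemaker.tuner import (
    CategoricalParameter,
    ContinuousParameter,
    HyperparameterTuner,
    IntegerParameter,
)

tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name="valid_mse",   # assumed to be exposed as a tuning objective
    objective_type="Minimize",
    hyperparameter_ranges={
        "hidden-size-1": IntegerParameter(32, 256),
        "dropout": ContinuousParameter(0.0, 0.5),
        "cell-type": CategoricalParameter(["LSTM", "GRU"]),
    },
    max_jobs=10,
    max_parallel_jobs=2,
)

tuner.fit({
    "training": "s3://<bucket>/train.csv",
    "validation": "s3://<bucket>/valid.csv",
})
```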

Inference

The inference algorithm takes as input a CSV file containing the time series. The CSV file should have the same format and columns as the one used for training. See the sample input file test.csv.

The inference algorithm outputs the predicted values of the time series and the standard deviation of the predictions.

Notes:

a) The model predicts the time series sequence by sequence. For instance, if the context-length is set equal to 200, and the prediction-length is set equal to 100, then the first 200 data points (from 1 to 200) are used as input to predict the next 100 data points (from 201 to 300). As a result, the algorithm does not return the predicted values of the first 200 data points, which are set to missing in the output CSV file.

b) The outputs include the out-of-sample forecasts beyond the last time step of the inputs. For instance, if the number of input samples is 500, and the prediction-length is 100, then the output CSV file will contain 600 samples, where the last 100 samples are the out-of-sample forecasts.

See the sample output files batch_predictions.csv and real_time_predictions.csv. See notebook.ipynb for an example of how to launch a batch transform job.
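
A minimal sketch of a batch transform job with the trained estimator is shown below; the instance type and S3 paths are placeholders.

```python
transformer = estimator.transformer(
    instance_count=1,
    instance_type="ml.m5.2xlarge",
    output_path="s3://<bucket>/predictions/",
)

transformer.transform(
    data="s3://<bucket>/test.csv",
    content_type="text/csv",
)
transformer.wait()  # the predictions CSV is written to the output path
```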

Endpoints

The algorithm supports only real-time inference endpoints; the inference image exceeds the size limit for serverless inference endpoints.

See notebook.ipynb for an example of how to deploy the model to an endpoint, invoke the endpoint and process the response.
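
A minimal sketch of deploying and invoking an endpoint is shown below; the instance type, serializer and deserializer choices, and the local file name are assumptions.

```python
from sagemaker.deserializers import CSVDeserializer
from sagemaker.serializers import CSVSerializer

predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.2xlarge",
    serializer=CSVSerializer(),
    deserializer=CSVDeserializer(),
)

# Send the test CSV (same format and columns as the training data) as the payload.
with open("test.csv") as f:
    predictions = predictor.predict(f.read())

predictor.delete_endpoint()  # clean up the endpoint when done
```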

Additional Resources: [Sample Notebook] [Blog Post]

References

  • Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735-1780.
  • Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint. arXiv:1406.1078.
  • Chung, J., Gulcehre, C., Cho, K., & Bengio, Y. (2014). Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint. arXiv:1412.3555.