This repository features a PyTorch-based implementation of PPO using a recurrent policy that supports truncated backpropagation through time. It is intended as a clean baseline/reference implementation of how to successfully employ recurrent neural networks alongside PPO and similar policy gradient algorithms.
We also offer a clean TransformerXL + PPO baseline repository.
- Added support for Memory Gym
- Added yaml configs
- Added max grad norm hyperparameter to the config
- Gymnasium is used instead of gym
- Only model inputs are padded now
- Buffer tensors are freed from memory after optimization
- Fixed dynamic sequence length
- Recurrent Policy
  - GRU
  - LSTM
  - Truncated BPTT
- Environments
  - Proof-of-concept Memory Task (PocMemoryEnv)
  - CartPole
    - Masked velocity
  - Minigrid Memory
    - Visual Observation Space: 3x84x84
    - Egocentric Agent View Size: 3x3 (default 7x7)
    - Action Space: forward, rotate left, rotate right
  - MemoryGym
    - Mortar Mayhem
    - Mystery Path
    - Searing Spotlights (WIP)
- Tensorboard
- Enjoy (watch a trained agent play)
```bibtex
@inproceedings{pleines2023memory,
  title={Memory Gym: Partially Observable Challenges to Memory-Based Agents},
  author={Marco Pleines and Matthias Pallasch and Frank Zimmer and Mike Preuss},
  booktitle={International Conference on Learning Representations},
  year={2023},
  url={https://openreview.net/forum?id=jHc8dCx6DDr}
}
```
- Installation
- Train a model
- Enjoy a model
- Recurrent Policy
- Hyperparameters (configs.py)
  - Recurrence
  - General
  - Schedules
- Model Architecture
- Add environment
- Tensorboard
- Results
Install PyTorch depending on your platform. We recommend using Anaconda.

Create an Anaconda environment:
```shell
conda create -n recurrent-ppo python=3.11 --yes
conda activate recurrent-ppo
```
CPU:
```shell
conda install pytorch torchvision torchaudio cpuonly -c pytorch
```
CUDA:
```shell
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
```
Install the remaining requirements and you are good to go:
```shell
pip install -r requirements.txt
```
Training is launched via the command `python train.py`.
```
Usage:
    train.py [options]
    train.py --help

Options:
    --run-id=<path>    Specifies the tag of the tensorboard summaries and the model's filename [default: run].
    --cpu              Whether to enforce training on the CPU, otherwise an available GPU will be used. [default: False].
```
Hyperparameters are configured inside `configs.py`. The config to be used has to be specified inside `train.py`. Once training is done, the final model is saved to `./models/run-id.nn`. Training statistics are stored inside the `./summaries` directory.
```shell
python train.py --run-id=my-training-run
```
To watch an agent exploit its trained model, execute the `python enjoy.py` command. Some already trained models can be found inside the `models` directory!
```
Usage:
    enjoy.py [options]
    enjoy.py --help

Options:
    --model=<path>    Specifies the path to the trained model [default: ./models/minigrid.nn].
```
The path to the desired model has to be specified using the `--model` flag:
```shell
python enjoy.py --model=./models/minigrid.nn
```
- Training data
  - Training data is sampled from the current policy
  - Sampled data is split into episodes
  - Episodes are split into sequences (based on the `sequence_length` hyperparameter)
  - Zero padding is applied to retrieve sequences of fixed length
  - Recurrent cell states are collected from the beginning of the sequences (truncated BPTT)
- Forward pass of the model
  - While feeding the model for optimization, the data is flattened to feed an entire batch (faster)
  - Before feeding it to the recurrent layer, the data is reshaped to `(num_sequences, sequence_length, data)`
- Loss computation
  - Zero padded values are masked during the computation of the losses
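The data preparation steps above can be sketched in plain Python. This is a minimal illustration under our own assumptions, not the repository's code: `split_into_sequences` and `masked_mean` are hypothetical helpers, episodes are lists of scalars instead of tensors, and padding uses `0.0`. The helper also records where each sequence begins, i.e. the step whose recurrent cell state would be stored for truncated BPTT, and it falls back to the longest episode when `sequence_length <= 0`, mirroring the dynamic sequence length described for that hyperparameter.

```python
from typing import List, Tuple


def split_into_sequences(
    episodes: List[List[float]], sequence_length: int
) -> Tuple[List[List[float]], List[List[bool]], List[Tuple[int, int]]]:
    """Hypothetical helper: chunk episodes into fixed-length, zero-padded sequences."""
    if sequence_length <= 0:
        # Dynamic sequence length: fit it to the longest episode
        sequence_length = max(len(episode) for episode in episodes)
    sequences, masks, starts = [], [], []
    for episode_idx, episode in enumerate(episodes):
        for start in range(0, len(episode), sequence_length):
            chunk = episode[start:start + sequence_length]
            pad = sequence_length - len(chunk)
            sequences.append(chunk + [0.0] * pad)              # zero padding
            masks.append([True] * len(chunk) + [False] * pad)  # False marks padding
            starts.append((episode_idx, start))  # step whose cell state seeds this sequence
    return sequences, masks, starts


def masked_mean(values: List[List[float]], masks: List[List[bool]]) -> float:
    """Average only over real steps so padded zeros do not dilute the loss."""
    total = sum(v for seq, mask in zip(values, masks) for v, keep in zip(seq, mask) if keep)
    count = sum(keep for mask in masks for keep in mask)
    return total / count


episodes = [[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0]]
sequences, masks, starts = split_into_sequences(episodes, sequence_length=2)
# sequences -> [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0], [6.0, 7.0]]
# starts    -> [(0, 0), (0, 2), (0, 4), (1, 0)]
print(masked_mean(sequences, masks))  # 4.0 (an unmasked mean over all 8 slots would be 3.5)
```

Here the sequences themselves stand in for per-step losses; in practice the same mask would be applied to the policy, value and entropy terms.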
As a reinforcement learning engineer, one needs high endurance. Therefore, we provide some information on the bugs that slowed us down for months.

We observed an exploding value function. This was due to unintentionally feeding `None` to the recurrent layer. In this case, PyTorch initializes the hidden states with zeros, as shown by its source code:
```python
# Excerpt from torch.nn.LSTM.forward (torch/nn/modules/rnn.py): hx=None falls back to zero states
if hx is None:
    num_directions = 2 if self.bidirectional else 1
    real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
    h_zeros = torch.zeros(self.num_layers * num_directions,
                          max_batch_size, real_hidden_size,
                          dtype=input.dtype, device=input.device)
    c_zeros = torch.zeros(self.num_layers * num_directions,
                          max_batch_size, self.hidden_size,
                          dtype=input.dtype, device=input.device)
    hx = (h_zeros, c_zeros)
```
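A minimal sketch of the consequence, assuming a single-layer `nn.LSTM` with `batch_first=True` and random stand-in tensors: passing `None` yields exactly the same output as passing explicit zero states, so any stored recurrent cell states are silently ignored. The `forward_recurrent` guard is a hypothetical helper, not part of PyTorch or this repository.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_layers, batch, seq_len, input_size, hidden_size = 1, 4, 8, 5, 16
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
x = torch.randn(batch, seq_len, input_size)


def forward_recurrent(layer, inputs, hx):
    """Hypothetical guard: refuse to let PyTorch silently fall back to zero states."""
    if hx is None:
        raise ValueError("the recurrent cell state must be passed explicitly")
    return layer(inputs, hx)


# Even with batch_first=True, h and c are shaped (num_layers, batch, hidden_size)
h0 = torch.randn(num_layers, batch, hidden_size)  # stand-in for a stored cell state
c0 = torch.randn(num_layers, batch, hidden_size)
out_stored, (hn, cn) = forward_recurrent(lstm, x, (h0, c0))

# Passing None behaves exactly like passing explicit zeros...
out_none, _ = lstm(x, None)
zeros = torch.zeros(num_layers, batch, hidden_size)
out_zeros, _ = lstm(x, (zeros, zeros))
# ...so the stored states above would have been ignored
print(torch.allclose(out_none, out_zeros))   # True
print(torch.allclose(out_none, out_stored))  # False
```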
Training an agent with a sequence length greater than 1 caused the agent to only achieve the performance of a random agent. The cause was the reshaping of the data right before feeding it to the recurrent layer. In general, the goal is to feed the entire training batch, rather than individual sequences, to the encoder (e.g. convolutional layers). Before the processed batch is fed to the recurrent layer, it has to be rearranged into sequences. At the time of this bug, the recurrent layer was initialized with `batch_first=False`. Hence, the data was reshaped using `h.reshape(sequence_length, num_sequences, data)`. Because the flattened batch is ordered sequence by sequence, this reshape interleaved steps of different sequences, which messed up the structure of the sequences and ultimately caused the bug. We fixed this by setting `batch_first` to `True` and therefore reshaping the data by `h.reshape(num_sequences, sequence_length, data)`.
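The following sketch, using arbitrary toy sizes, shows why the direct reshape breaks the sequences. The flattened batch is ordered sequence by sequence, so reshaping it straight to `(sequence_length, num_sequences, data)` gathers rows from different sequences, while reshaping to `(num_sequences, sequence_length, data)` keeps each sequence intact. If a time-major layout were ever required, one option (our suggestion, not the fix used in this repository) is to reshape batch-first and then transpose.

```python
import torch

num_sequences, sequence_length, features = 3, 4, 2
# Flattened encoder output: row i * sequence_length + t holds step t of sequence i
flat = torch.arange(num_sequences * sequence_length * features, dtype=torch.float32).reshape(-1, features)

# Fix: with batch_first=True, split the rows back into sequences
correct = flat.reshape(num_sequences, sequence_length, features)

# Bug: reshaping straight to time-major interleaves steps of different sequences
buggy = flat.reshape(sequence_length, num_sequences, features)

# Time-major alternative: reshape batch-first first, then swap the first two axes
time_major = correct.transpose(0, 1)

print(torch.equal(time_major[:, 0], correct[0]))  # True: sequence 0 intact
print(torch.equal(buggy[:, 0], correct[0]))       # False: rows 0, 3, 6, 9 instead of 0, 1, 2, 3
```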
Hidden states were not reset
This is considered a feature rather than a bug. Agents in environments that produce rather short episodes are likely to benefit from not resetting the hidden states upon starting a new episode. This is the case for MinigridMemory-S9. Resetting hidden states is now controlled by the hyperparameter `reset_hidden_state` inside configs.py. The actual mistake was the mixed-up order of saving the recurrent cell to its respective placeholder and resetting it.
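A toy sketch of the intended order, under our own simplifying assumptions: the "recurrent cell" is a single float updated as `h = 0.5 * h + obs`, and `rollout` is a hypothetical function rather than the repository's sampling code. At every step the state that is fed to the model is saved first, and only after the forward pass is it reset, if the episode has ended and `reset_hidden_state` is enabled.

```python
from typing import List


def rollout(observations: List[float], dones: List[bool], reset_hidden_state: bool = True) -> List[float]:
    """Return the hidden state stored for each step, i.e. the one fed to the model."""
    h = 0.0
    stored = []
    for obs, done in zip(observations, dones):
        # 1) Save the state that is about to be fed to the model at this step
        stored.append(h)
        # 2) Forward pass of the toy recurrent model
        h = 0.5 * h + obs
        # 3) Only afterwards reset it if the episode just ended
        if done and reset_hidden_state:
            h = 0.0
    return stored


print(rollout([1.0, 1.0, 1.0, 1.0], [False, True, False, False]))
print(rollout([1.0, 1.0, 1.0, 1.0], [False, True, False, False], reset_hidden_state=False))
```

With an episode ending at the second step, the stored states are `[0.0, 1.0, 0.0, 1.0]` with resetting and `[0.0, 1.0, 1.5, 1.75]` without it, so the new episode either starts from a clean state or carries memory over.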
Hyperparameter | Description |
---|---|
sequence_length | Length of the trained sequences; if set to 0 or smaller, the sequence length is dynamically fit to the episode lengths |
hidden_state_size | Size of the recurrent layer's hidden state |
layer_type | Supported recurrent layers: gru, lstm |
reset_hidden_state | Whether to reset the hidden state upon starting a new episode. Not resetting it can be beneficial for environments that produce short episodes like MinigridMemory-S9. |
gamma | Discount factor |
lamda | Lambda parameter of the Generalized Advantage Estimation (GAE), trading off bias and variance of the advantage estimates |
updates | Number of cycles (sampling followed by optimization) that the entire PPO algorithm is executed |
n_workers | Number of environments that are used to sample training data |
worker_steps | Number of steps an agent samples data in each environment (batch_size = n_workers * worker_steps) |
epochs | Number of times that the whole batch of data is used for optimization using PPO |
n_mini_batch | Number of mini batches that are trained throughout one epoch |
value_loss_coefficient | Multiplier of the value function loss to constrain it |
hidden_layer_size | Number of hidden units in each linear hidden layer |
max_grad_norm | Gradients are clipped by the specified max norm |
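To illustrate how `gamma` and `lamda` enter the advantage computation, here is a plain-Python sketch of the standard GAE recursion. It is our illustration rather than the repository's implementation; the function name, the argument layout and the convention that `dones[t]` marks the last step of an episode are assumptions.

```python
from typing import List, Tuple


def compute_gae(
    rewards: List[float],
    values: List[float],
    dones: List[bool],
    last_value: float,
    gamma: float = 0.99,
    lamda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Generalized Advantage Estimation for a single worker's trajectory."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value  # bootstrap value of the state after the last step
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0  # cut bootstrapping at episode ends
        delta = rewards[t] + gamma * next_value * not_done - values[t]  # TD error
        gae = delta + gamma * lamda * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    returns = [adv + val for adv, val in zip(advantages, values)]  # value targets
    return advantages, returns


advantages, returns = compute_gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [False, False, True],
                                  last_value=5.0, gamma=0.5, lamda=1.0)
print(advantages)  # [1.75, 1.5, 1.0]
```

Setting `lamda = 1` recovers discounted returns minus the value baseline (the bootstrap value 5.0 is ignored because the episode terminates), while `lamda = 0` reduces each advantage to the one-step TD error. As for the batch size in the table, hypothetical values of `n_workers = 16` and `worker_steps = 512` would yield 16 * 512 = 8192 steps per update.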
These schedules can be used to polynomially decay the learning rate, the entropy bonus coefficient and the clip range.

Hyperparameter | Description |
---|---|
learning_rate_schedule | The learning rate used by the AdamW optimizer |
beta_schedule | Beta is the entropy bonus coefficient that is used to encourage exploration |
clip_range_schedule | Clip range (epsilon) of PPO's clipped surrogate objective |
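A sketch of polynomial decay as it could be applied to each schedule. The function name, its signature and the convention of returning the final value once `max_decay_steps` is reached are assumptions on our part; `power = 1.0` gives linear decay, larger powers decay faster early on.

```python
def polynomial_decay(initial: float, final: float, max_decay_steps: int,
                     power: float, current_step: int) -> float:
    """Decay from `initial` to `final` over `max_decay_steps` updates."""
    if current_step >= max_decay_steps or initial == final:
        return final
    fraction_left = 1.0 - current_step / max_decay_steps
    return (initial - final) * fraction_left ** power + final


# Hypothetical learning rate schedule: 3e-4 -> 3e-6 over 100 updates, linear
lr = polynomial_decay(3e-4, 3e-6, max_decay_steps=100, power=1.0, current_step=50)
print(lr)  # roughly 1.515e-4, halfway between initial and final
```

In PyTorch, the resulting value could then be written into the optimizer with the usual `for group in optimizer.param_groups: group["lr"] = lr` pattern before each update.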
The figure above illustrates the model architecture in the case of training Minigrid. The visual observation is processed by 3 convolutional layers. The flattened result is then divided into sequences before feeding it to the recurrent layer. After passing the recurrent layer's result to one hidden layer, the network is split into two streams. One computes the value function and the other one the policy. All layers use the ReLU activation.
In the case of training an environment that utilizes vector observations only, the visual encoder is omitted and the observation is fed directly to the recurrent layer.
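A compact PyTorch sketch of the architecture described above. The convolution kernel sizes and strides (borrowed from the common Nature-DQN encoder), the default sizes `hidden_state_size=256` and `hidden_layer_size=512`, and the class name `RecurrentActorCritic` are our assumptions; the repository's exact layer configuration may differ. It shows the flatten, reshape to `(num_sequences, sequence_length, data)`, and flatten-again pattern around a `batch_first=True` recurrent layer.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class RecurrentActorCritic(nn.Module):
    def __init__(self, obs_shape=(3, 84, 84), num_actions=3, hidden_state_size=256,
                 hidden_layer_size=512, layer_type="gru"):
        super().__init__()
        # Visual encoder: 3 conv layers with ReLU (kernel sizes/strides are assumed)
        self.encoder = nn.Sequential(
            nn.Conv2d(obs_shape[0], 32, 8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
            nn.Flatten(),
        )
        with torch.no_grad():  # infer the flattened encoder size from a dummy input
            encoded_size = self.encoder(torch.zeros(1, *obs_shape)).shape[1]
        rnn_cls = nn.GRU if layer_type == "gru" else nn.LSTM
        self.recurrent = rnn_cls(encoded_size, hidden_state_size, batch_first=True)
        # One shared hidden layer, then separate policy and value streams
        self.hidden = nn.Sequential(nn.Linear(hidden_state_size, hidden_layer_size), nn.ReLU())
        self.policy = nn.Linear(hidden_layer_size, num_actions)
        self.value = nn.Linear(hidden_layer_size, 1)

    def forward(self, obs, recurrent_cell, sequence_length=1):
        # obs: flattened batch of shape (num_sequences * sequence_length, C, H, W)
        h = self.encoder(obs)
        # Rearrange into (num_sequences, sequence_length, data) for the recurrent layer
        h = h.reshape(-1, sequence_length, h.shape[-1])
        h, recurrent_cell = self.recurrent(h, recurrent_cell)
        # Flatten again so both heads process every step at once
        h = self.hidden(h.reshape(-1, h.shape[-1]))
        return Categorical(logits=self.policy(h)), self.value(h).squeeze(-1), recurrent_cell


model = RecurrentActorCritic()
obs = torch.zeros(2 * 4, 3, 84, 84)  # 2 sequences of length 4
dist, value, cell = model(obs, torch.zeros(1, 2, 256), sequence_length=4)
```

For vector observations, the encoder would simply be dropped and the observation fed directly to the recurrent layer. With `layer_type="lstm"`, the recurrent cell is an `(h, c)` tuple instead of a single tensor.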
Follow these steps to train another environment:
- Extend the `create_env()` function in `utils.py` by adding another if-statement that queries the environment's name
  - At this point you could simply use `gym.make()` or use a custom environment that builds on top of the gym interface.
- Adjust the `"env"` key inside the config dictionary to match the name of the new environment
During training, tensorboard summaries are saved to `summaries/run-id/timestamp`.

Run `tensorboard --logdir=summaries` to watch the training statistics in your browser using the URL http://localhost:6006/.
The code for plotting the results can be found in the results directory. Results on Memory Gym can be found in our TransformerXL + PPO baseline repository.
(trained only on MinigridMemory-S9 using unlimited seeds)