Implementation of the paper "Real-time Policy Distillation in Deep Reinforcement Learning"
Python 3.8+ and Ray RLlib with the PyTorch backend are the two main requirements. The full conda environment specification (with the specific library versions) can be found in the environment.yml file.
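For example, the environment can typically be created from that file with:
conda env create -f environment.yml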
In file scripts/models.py the policy classes that encapsulate both the teacher and all student networks are defined. The same file also defines the custom loss (which combines the Q-loss, both versions of the KL loss, and the imitation Q-loss) and the custom evaluation function (which evaluates both the teacher and all student networks on each iteration). A hedged sketch of such a combined loss is given below.
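To give a rough idea of such a combined objective, here is a minimal PyTorch sketch, not the repository's actual code: the tensor names, the temperature tau, the loss weights, and the use of an MSE imitation term are illustrative assumptions, and only one of the two KL variants is shown.

```python
# Hedged sketch of the kind of combined loss described above; names, weighting
# and temperature handling are illustrative, not the repository's exact code.
import torch
import torch.nn.functional as F


def distillation_loss(student_q, teacher_q, td_target, actions,
                      tau=0.01, kl_weight=1.0, imitation_weight=1.0):
    """Combine the standard Q-loss with distillation terms.

    student_q, teacher_q: [batch, num_actions] Q-value estimates.
    td_target:            [batch] bootstrapped targets for the taken actions.
    actions:              [batch] actions taken by the behaviour policy.
    """
    # Standard (Huber) Q-loss on the student's Q-value of the taken action.
    q_taken = student_q.gather(1, actions.unsqueeze(1)).squeeze(1)
    q_loss = F.smooth_l1_loss(q_taken, td_target)

    # KL divergence between softened teacher and student action distributions
    # (one of the two KL variants; the other variant is analogous).
    teacher_dist = F.softmax(teacher_q / tau, dim=1)
    student_log_dist = F.log_softmax(student_q / tau, dim=1)
    kl_loss = F.kl_div(student_log_dist, teacher_dist, reduction="batchmean")

    # Imitation Q-loss: regress the student's Q-values towards the teacher's.
    imitation_loss = F.mse_loss(student_q, teacher_q)

    return q_loss + kl_weight * kl_loss + imitation_weight * imitation_loss
```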
In file scripts/trainer.py the trainer is defined; it inherits from the Ray implementation of the APEX algorithm (see the sketch below for one way such a trainer can be derived).
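As orientation only, the following is a minimal sketch of deriving such a trainer, assuming a Ray 1.x-era RLlib where policies and trainers built with build_torch_policy/build_trainer expose with_updates; the class names, the loss function body, and the exact keyword arguments are illustrative, not the repository's actual code.

```python
# Hedged sketch, assuming a Ray 1.x-era RLlib API; names are illustrative.
from ray.rllib.agents.dqn import ApexTrainer
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy, build_q_losses


def distillation_loss_fn(policy, model, dist_class, train_batch):
    # Start from RLlib's standard DQN Q-loss and add the distillation terms
    # (KL and imitation Q-losses between teacher and student outputs) on top.
    loss = build_q_losses(policy, model, dist_class, train_batch)
    # ...distillation terms computed from teacher/student Q-values would go here...
    return loss


# A DQN torch policy with the custom loss plugged in.
DistilledPolicy = DQNTorchPolicy.with_updates(
    name="DistilledPolicy",
    loss_fn=distillation_loss_fn,
)

# A trainer that reuses the APEX execution plan but with the custom policy.
DistilledApexTrainer = ApexTrainer.with_updates(
    name="DistilledApexTrainer",
    default_policy=DistilledPolicy,
    get_policy_class=lambda config: DistilledPolicy,
)
```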
File scripts/main.py is the entry point to the training process; a rough sketch of what such an entry point looks like is given below.
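The following is a hedged sketch of a typical entry point of this kind, not the actual scripts/main.py: it parses the --config flag used in the command further below, loads the YAML config, and hands it to an APEX-style trainer (the repository would use its own trainer class from scripts/trainer.py instead of the stock ApexTrainer).

```python
# Hedged sketch of a typical entry point; the real scripts/main.py may differ.
import argparse

import ray
import yaml
from ray import tune
from ray.rllib.agents.dqn import ApexTrainer  # stand-in for the custom APEX-derived trainer


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to the YAML RLlib config file")
    args = parser.parse_args()

    # The config file is a plain RLlib/APEX config serialized as YAML.
    with open(args.config) as f:
        config = yaml.safe_load(f)

    ray.init()
    # Launch training; Ray writes checkpoints and TensorBoard event files
    # to ~/ray_results by default.
    tune.run(ApexTrainer, config=config)
```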
File plots.ipynb contains the reproduction of all tables and figures from the report (and the code that can be used to compute these values for any other trained model/game).
To run a training process, first create a config file (the three config files that were used for the project can be found in the configs folder). It should be a YAML file containing a common ray.rllib config, so it accepts any field understood by ray.rllib and used to tune the behaviour of the ray.rllib.agents.dqn.ApexTrainer trainer; an illustrative sketch of typical fields is given after the command below. Run training with the following command from the root of the repository:
python -m scripts.main --config path_to_config_file
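The YAML file is simply a serialized RLlib config. As a rough, hedged illustration only (the key names are standard RLlib/APEX options, but the particular keys and values shown here are placeholders, not the project's actual settings), the loaded config corresponds to a Python dict such as:

```python
# Illustrative only: standard RLlib/APEX config keys with placeholder values.
config = {
    "env": "PongNoFrameskip-v4",      # any Gym/Atari environment id
    "framework": "torch",             # the project uses the PyTorch backend
    "num_workers": 8,                 # number of APEX rollout workers
    "num_gpus": 1,
    "lr": 1e-4,
    "n_step": 3,
    "target_network_update_freq": 50000,
}
```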
File plots.ipynb contains the bare minimum of code required to calculate the results presented in the report. It works with the TensorBoard log files that are generated during the training process. We attach the TensorBoard event files of our experiments, which can be accessed via this link (these files are too large for the git repo); to use them in plots.ipynb, download them to your computer and specify the variable EVENT_FILE_PATH in the second cell of the notebook. A hedged sketch of reading such event files is given below.
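For orientation, TensorBoard event files of this kind can be read with TensorBoard's EventAccumulator. The sketch below is an assumption about how one might do this outside the notebook; in particular, the scalar tag "episode_reward_mean" is only an example and the available tags depend on what RLlib logged.

```python
# Hedged sketch of reading scalars from a TensorBoard event file.
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

EVENT_FILE_PATH = "path/to/downloaded/events.out.tfevents.<...>"  # set to the downloaded file

acc = EventAccumulator(EVENT_FILE_PATH)
acc.Reload()  # parse the event file

print(acc.Tags()["scalars"])  # list the available scalar tags

# "episode_reward_mean" is an example tag; pick one from the printed list.
for event in acc.Scalars("episode_reward_mean"):
    print(event.step, event.value)
```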