Reinforcement Learning Model Training and Prediction for a Neuromatch group project: Relevance of sensory modalities for spatial navigation in foraging behaviours of the bee
This repository contains Python scripts and utilities for training a custom Reinforcement Learning (RL) model using the Twin Delayed Deep Deterministic Policy Gradient (TD3) algorithm. The trained model can be used for predicting actions in a custom Gym environment.
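TD3's central trick is the clipped double-Q target with target-policy smoothing. A minimal sketch of those two pieces in plain Python (this illustrates the algorithm only; it is not the Stable Baselines3 implementation, and all function names here are invented):

```python
import random

def td3_target(reward, next_q1, next_q2, gamma=0.99, done=False):
    """Clipped double-Q target: bootstrap from the minimum of the two
    target critics, which damps value overestimation."""
    bootstrap = 0.0 if done else gamma * min(next_q1, next_q2)
    return reward + bootstrap

def smoothed_target_action(mu, noise_std=0.2, noise_clip=0.5, act_limit=1.0):
    """Target-policy smoothing: add clipped Gaussian noise to the target
    actor's action before the target critics evaluate it."""
    noise = max(-noise_clip, min(noise_clip, random.gauss(0.0, noise_std)))
    return max(-act_limit, min(act_limit, mu + noise))
```

In Stable Baselines3 these details are handled internally by the `TD3` class; the sketch only shows why the algorithm uses two critics.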
The repository contains the following files:
- `bee.py`: Gym environment setup - agent actions, rewards, and space definitions for the RL model.
- `train_model.py`: The main script to train the RL model.
- `render_model.py`: A script to generate and display a video of the RL model's predictions.
- `config.yaml`: The configuration file for training the RL model.
- `model.py`: A Python module with utility functions for initializing and loading the RL model.
- `utils.py`: A Python module with utility functions for creating directories, saving the configuration, and more.
- `gym_run.py`: A demonstration file for testing Gym changes.
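The actual environment lives in `bee.py`; as a rough illustration, a Gymnasium-style 2-D foraging environment follows the `reset`/`step` pattern sketched below. All specifics here (observation layout, reward values, arena size) are invented for the sketch and do not reflect the real implementation:

```python
import math
import random

class BeeEnvSketch:
    """Toy 2-D arena: the agent moves toward a goal; reward is the
    negative distance to the goal plus a bonus on reaching it."""

    def __init__(self, arena_size=10.0, goal_radius=1.0):
        self.arena_size = arena_size
        self.goal_radius = goal_radius
        self.goal = (arena_size / 2, arena_size / 2)
        self.pos = None

    def reset(self, seed=None):
        rng = random.Random(seed)
        self.pos = [rng.uniform(0, self.arena_size),
                    rng.uniform(0, self.arena_size)]
        return tuple(self.pos), {}  # observation, info

    def step(self, action):
        dx, dy = action  # continuous 2-D velocity command
        # Clamp movement to the arena (a crude stand-in for walls).
        self.pos[0] = min(max(self.pos[0] + dx, 0.0), self.arena_size)
        self.pos[1] = min(max(self.pos[1] + dy, 0.0), self.arena_size)
        dist = math.hypot(self.pos[0] - self.goal[0],
                          self.pos[1] - self.goal[1])
        terminated = dist < self.goal_radius
        reward = -dist + (100.0 if terminated else 0.0)
        # observation, reward, terminated, truncated, info
        return tuple(self.pos), reward, terminated, False, {}
```

A real Gymnasium environment would subclass `gymnasium.Env` and declare `observation_space` and `action_space`; the sketch keeps only the control flow.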
- Configure `config.yaml` with the desired parameters for training the RL model.
- Run the training script: `python train_model.py --config_path config.yaml`
- After training, generate a video of the model's predictions with the `render_model.py` script: `python render_model.py --config_path config.yaml`
- You can also use the convenience Jupyter notebooks, which are useful for working in Google Colab.
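The scripts above take a `--config_path` flag and read the YAML file it points to. A minimal sketch of that pattern, with a tiny hand-rolled parser standing in for `yaml.safe_load` so the example stays dependency-free, and with all key names invented for illustration:

```python
import argparse

def parse_flat_yaml(text):
    """Tiny stand-in for yaml.safe_load: handles flat 'key: value' lines only."""
    config = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments and blanks
        if not line:
            continue
        key, _, value = line.partition(":")
        value = value.strip()
        try:
            value = float(value) if "." in value else int(value)
        except ValueError:
            pass  # keep non-numeric values as strings
        config[key.strip()] = value
    return config

def build_parser():
    parser = argparse.ArgumentParser(description="Train the RL model.")
    parser.add_argument("--config_path", default="config.yaml")
    return parser
```

In the real scripts, PyYAML does the parsing; the point here is only the `argparse` + config-file flow.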
Implemented features:
- Logging, early stopping, and model management
- Model re-training
- Walls/obstacles
- Testing different architectures
- Making the goal smaller
- Reward testing
- Adding senses
- Punishment for elapsed time
- Punishment for hitting walls
- Increasing/gradually shrinking the goal size during training
- Multiple goals
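Several of the listed features are reward-shaping terms plus a goal-size curriculum. A hedged sketch of how they might combine (the coefficients and function names are illustrative, not taken from `bee.py`):

```python
def shaped_reward(dist_to_goal, hit_wall, steps_elapsed,
                  time_penalty=0.01, wall_penalty=1.0,
                  goal_bonus=100.0, goal_radius=1.0):
    """Combine distance shaping with the time-elapsed and wall-hitting
    punishments listed above."""
    reward = -dist_to_goal                  # dense shaping toward the goal
    reward -= time_penalty * steps_elapsed  # punish slow episodes
    if hit_wall:
        reward -= wall_penalty              # punish wall collisions
    if dist_to_goal < goal_radius:
        reward += goal_bonus                # sparse success bonus
    return reward

def goal_radius_schedule(step, total_steps, start_radius=2.0, end_radius=0.5):
    """Linearly shrink the goal radius over training: start easy (large
    goal), end hard (small goal)."""
    frac = min(step / total_steps, 1.0)
    return start_radius + frac * (end_radius - start_radius)
```

Shrinking the goal during training is a simple curriculum: early episodes reach the large goal often enough to provide learning signal, and the task tightens as the policy improves.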
- Stable Baselines3: https://github.com/DLR-RM/stable-baselines3
- Gymnasium: https://github.com/Farama-Foundation/Gymnasium
This project is licensed under the MIT License - see the `LICENSE` file for details.