Diffusion Actor-Critic: Formulating Constrained Policy Iteration as Diffusion Noise Regression for Offline Reinforcement Learning.
This repository is the official implementation of Diffusion Actor-Critic: Formulating Constrained Policy Iteration as Diffusion Noise Regression for Offline Reinforcement Learning. Our implementation is built upon JAX and jaxrl.
To install requirements:
pip install -r requirements.txt
When running the code for the first time, the D4RL datasets need to be downloaded. For installation of the MuJoCo task environments, please refer to the guidelines.
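D4RL fetches each dataset on first use. Below is a minimal sketch of the standard D4RL loading call; this is the D4RL library API itself, not necessarily the exact loader used in this repo:

```python
import gym
import d4rl  # registers the D4RL environments with gym

# Building the environment and requesting its dataset triggers the
# download to the local D4RL cache on first use.
env = gym.make('walker2d-medium-v2')
dataset = d4rl.qlearning_dataset(env)

print(dataset['observations'].shape)  # (N, obs_dim)
print(dataset['actions'].shape)       # (N, act_dim)
print(dataset['rewards'].shape)       # (N,)
```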
We train our algorithm for 2 million gradient steps to ensure convergence.
For each environment, we run 8 independent training processes.
To train DAC using dual gradient ascent, run the code with `eta_lr > 0` and `bc_threshold > 0`.
Here is an example of training on the walker2d-medium-v2 dataset:
python main.py --env walker2d-medium-v2 --agent dac --eta 1 --eta_lr 0.001 --bc_threshold 1 --rho 1 --q_tar lcb --num_seed 8 --gpu 0
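With `eta_lr > 0`, the multiplier `eta` is adjusted during training. The following is a minimal sketch of one common form of the dual-ascent step, assuming `eta` acts as a Lagrange multiplier on the behavior-cloning (noise regression) constraint measured against `bc_threshold`; `update_eta` is a hypothetical helper, and the exact parameterization in DAC may differ:

```python
import jax.numpy as jnp

def update_eta(eta, bc_loss, bc_threshold, eta_lr):
    # Dual gradient ascent: the gradient of the Lagrangian with respect to
    # eta is the constraint violation (bc_loss - bc_threshold).
    new_eta = eta + eta_lr * (bc_loss - bc_threshold)
    # Keep the multiplier non-negative.
    return jnp.maximum(new_eta, 0.0)

# Example: eta increases while the BC loss exceeds the threshold.
eta = 1.0
eta = update_eta(eta, bc_loss=1.3, bc_threshold=1.0, eta_lr=1e-3)
```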
If using a fixed eta, run the code with `eta_lr = 0` and `eta > 0`.
Here is an example of training on the antmaze-umaze-v0 dataset:
python main.py --env antmaze-umaze-v0 --agent dac --maxQ --eta 0.1 --eta_lr 0 --rho 1 --q_tar lcb --num_seed 8 --gpu 0
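The `--q_tar lcb` and `--rho` options select a lower-confidence-bound Q target. Below is a minimal sketch of how such a target is commonly computed over an ensemble of critics, with `rho` scaling the ensemble standard deviation; `lcb_q_target` is a hypothetical helper, and the exact definition used by DAC may differ:

```python
import jax.numpy as jnp

def lcb_q_target(q_values, rho):
    # Lower confidence bound: ensemble mean minus rho times the ensemble std.
    mean = jnp.mean(q_values, axis=0)
    std = jnp.std(q_values, axis=0)
    return mean - rho * std

# Example with two critics (shape: num_critics x batch).
q_ensemble = jnp.array([[10.0, 5.0],
                        [12.0, 3.0]])
target = lcb_q_target(q_ensemble, rho=1.0)  # [10.0, 3.0]
```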
Evaluation is integrated into the training process: each process evaluates performance with 10 different seeds every 10,000 gradient steps. Optionally, add `--save_ckpt` to save the trained models. Training and evaluation results are saved to the `results/{env name}` folder.
To check the training statistics, you can use TensorBoard:
tensorboard --logdir 'results/{env name}/{results_folder}'
Here is an example that visualizes the training run on walker2d-medium-v2 launched above (the log directory is quoted because it contains `|` characters):
tensorboard --logdir 'results/walker2d-medium-v2/DAC_b=1.0|QTar=lcb|rho=1.0'
Our proposed algorithm DAC achieves the following normalized scores on the D4RL benchmark:
Dataset | DAC |
---|---|
halfcheetah-m | 59.1 |
hopper-m | 101.2 |
walker2d-m | 96.8 |
halfcheetah-m-r | 55.0 |
hopper-m-r | 103.1 |
walker2d-m-r | 96.8 |
halfcheetah-m-e | 99.1 |
hopper-m-e | 111.7 |
walker2d-m-e | 113.6 |
antmaze-u | 99.5 |
antmaze-u-div | 85.0 |
antmaze-m-play | 85.8 |
antmaze-m-div | 84.0 |
antmaze-l-play | 50.3 |
antmaze-l-div | 55.3 |
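The values above are D4RL normalized scores, where 0 corresponds to a random policy and 100 to an expert policy. For reference, raw episode returns can be converted with the standard D4RL helper (the return value below is purely illustrative, not a reported result):

```python
import gym
import d4rl  # registers the D4RL environments

env = gym.make('walker2d-medium-v2')
raw_return = 4500.0  # illustrative episode return, not a reported result
normalized = env.get_normalized_score(raw_return) * 100
print(f'normalized score: {normalized:.1f}')
```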
Training curves are provided for the Locomotion-v2 and Antmaze-v0 tasks.
Please cite this paper if you use this repo:
@misc{fang2024diffusion,
  title={Diffusion Actor-Critic: Formulating Constrained Policy Iteration as Diffusion Noise Regression for Offline Reinforcement Learning},
  author={Linjiajie Fang and Ruoxue Liu and Jing Zhang and Wenjia Wang and Bing-Yi Jing},
  year={2024},
  eprint={2405.20555},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
Please also cite JAXRL:
@misc{jaxrl,
  author = {Kostrikov, Ilya},
  doi = {10.5281/zenodo.5535154},
  month = {10},
  title = {{JAXRL: Implementations of Reinforcement Learning algorithms in JAX}},
  url = {https://github.com/ikostrikov/jaxrl},
  year = {2021}
}