
Representation Convergence:
Mutual Distillation is Secretly a Form of Regularization

Zhengpeng Xie*, Jiahang Cao*, Qiang Zhang, Jianxiong Zhang, Changwei Wang, Renjing Xu

arXiv

Main Results

Caption: (Left) Independently trained reinforcement learning policies may overfit to spurious features. (Right) Through mutual distillation via DML, two policies regularize each other to converge toward a more robust hypothesis space, ultimately improving generalization performance.

Caption: Generalization performance on the Procgen benchmark (trained on 500 levels) for different methods. The mean and standard deviation are shown across 3 seeds. Our MDPO achieves significant performance improvements over the baseline algorithms.

Caption: The features extracted by the MDPO encoder are highly stable and focused (red points), whereas the features extracted by the original PPO encoder are significantly dispersed (blue points).
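To make the mutual-distillation regularization described above concrete, the sketch below illustrates the core idea under simple assumptions: two policies are trained with their usual PPO losses, and a symmetric KL term (in the spirit of Deep Mutual Learning) pulls each policy's action distribution toward the other's on shared observations. Names such as mutual_distillation_loss, logits_a, logits_b, and distill_coef are illustrative only; they do not correspond to the repository's actual code, which splits MDPO across multiple components.

from torch.distributions import Categorical, kl_divergence

def mutual_distillation_loss(logits_a, logits_b):
    # Illustrative sketch, not the repository's implementation.
    # Symmetric KL between the two policies' action distributions on the
    # same batch of observations. Each direction detaches the "teacher",
    # so policy A is regularized toward a frozen copy of policy B and
    # vice versa.
    dist_a = Categorical(logits=logits_a)
    dist_b = Categorical(logits=logits_b)
    kl_a = kl_divergence(Categorical(logits=logits_b.detach()), dist_a).mean()
    kl_b = kl_divergence(Categorical(logits=logits_a.detach()), dist_b).mean()
    return kl_a + kl_b

# Inside a joint PPO update (ppo_loss_a and ppo_loss_b computed as in cleanrl),
# the regularizer would simply be added to the objective:
# total_loss = ppo_loss_a + ppo_loss_b + distill_coef * mutual_distillation_loss(logits_a, logits_b)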

Installation

To ensure the reproducibility of our main results, please follow the steps below to install the dependencies.

Create Anaconda environment:

conda create -n procgen_py310 python=3.10 --yes
conda activate procgen_py310

Install the requirements:

pip install -r requirements.txt

Install PyTorch, selecting the CUDA version that matches your system from the official PyTorch website (https://pytorch.org/), for example:

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
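As an optional sanity check (not part of the original setup steps), you can confirm that the installed build can see your GPU:

python -c "import torch; print(torch.cuda.is_available())"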

Train MDPO:

python main.py

Train PPO:

python ppo.py

Acknowledgement

The code is based on cleanrl. The implementation of Mutual Distillation Policy Optimization (MDPO) is divided into multiple components to enhance readability.
