[ICCV 2025] PyTorch code for the paper "Disentangled World Models: Learning to Transfer Semantic Knowledge from Distracting Videos for Reinforcement Learning"

qiwang067/DisWM

[ICCV 2025] Disentangled World Models: Learning to Transfer Semantic Knowledge from Distracting Videos for Reinforcement Learning

Qi Wang* · Zhipeng Zhang* · Baao Xie* · Xin Jin · Yunbo Wang · Shiyu Wang · Liaomo Zheng · Xiaokang Yang · Wenjun Zeng

Paper PDF | Project Page | Datasets

⚡ Quick Start | 📥 Datasets Download | 📝 Citation

Overview

Training visual reinforcement learning (RL) agents in practical scenarios presents a significant challenge: RL agents suffer from low sample efficiency in environments with variations. While various approaches have attempted to alleviate this issue through disentangled representation learning, these methods usually start learning from scratch, without prior knowledge of the world. This paper, in contrast, learns underlying semantic variations from distracting videos via offline-to-online latent distillation and flexible disentanglement constraints. To enable effective cross-domain semantic knowledge transfer, we introduce an interpretable model-based RL framework, dubbed Disentangled World Models (DisWM). Specifically, we pretrain an action-free video prediction model offline with disentanglement regularization to extract semantic knowledge from distracting videos. The disentanglement capability of the pretrained model is then transferred to the world model through latent distillation. For finetuning in the online environment, we exploit the knowledge of the pretrained model and introduce a disentanglement constraint to the world model. During the adaptation phase, the incorporation of actions and rewards from online environment interactions enriches the diversity of the data, which in turn strengthens disentangled representation learning. Experimental results validate the superiority of our approach on various benchmarks.
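As a rough illustration of the two ingredients described above, the β-weighted disentanglement constraint and the latent distillation term can be sketched as below. This is a simplified standalone sketch under our own naming, not the repository's implementation:

```python
import math

def gaussian_kl(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for a diagonal Gaussian."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv
                     for m, lv in zip(mu, logvar))

def beta_vae_loss(recon_error, mu, logvar, beta=4.0):
    """Reconstruction error plus beta-weighted KL; beta > 1 pressures the
    latent factors toward disentanglement (beta-VAE-style objective)."""
    return recon_error + beta * gaussian_kl(mu, logvar)

def distillation_loss(teacher_latent, student_latent):
    """Squared error pulling the online world model's latents toward the
    latents of the pretrained action-free video prediction model."""
    return sum((t - s) ** 2 for t, s in zip(teacher_latent, student_latent))
```

In this sketch, the pretraining stage would minimize `beta_vae_loss` on distracting videos, while finetuning would add `distillation_loss` (and the same disentanglement constraint) to the world-model objective.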

Showcases

DMC Humanoid Walk (comparison videos: DisWM vs. TD-MPC2 vs. ContextWM)

Drawerworld Open (comparison videos: DisWM vs. TD-MPC2 vs. ContextWM)

[Figure: evaluation results]

Quick Start

DisWM is implemented and tested on Ubuntu 22.04 with Python 3.9 and PyTorch 1.8.1:

  1. Create an environment:
conda create -n diswm python=3.9
conda activate diswm
  2. Install dependencies:
pip install -r requirements.txt
  3. Install Distracting Control Suite and dmc2gym. The distracting_control folder contains the Distracting Control Suite code, modified to create disjoint colour sets. The dmc2gym folder contains the dmc2gym code, revised to use the distracting_control wrappers.

  4. Install Drawerworld:

cd env/drawerworld
pip install -e .
  5. Collect distracting video datasets with DreamerV2 on DMC/MuJoCo Pusher.

  6. Record training data with either TensorBoard or Weights & Biases (wandb); both options are supported.

Train DisWM on DMC / MuJoCo Pusher / Drawerworld

  1. Pretrain the video prediction model with collected videos on DMC:
python dreamer.py --configs defaults dmc2gym \
    --device 'cuda:0' --task dmc2gym_reacher_easy \
    --logdir $log_directory \
    --seed 0 --beta_vae_pretrain True --beta_vae True \
    --pretrain_action_num $source_task_action_number \
    --pretrain_datasets_path $dataset_directory/train_eps
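The placeholder variables in the command above could be set like this before running it. All values are hypothetical; substitute your own paths, and set the action number to the action dimensionality of the source task:

```shell
# Hypothetical values -- adjust to your setup.
log_directory=./logdir/pretrain_reacher_easy
source_task_action_number=2          # e.g. 2 for DMC Reacher (2-D action space)
dataset_directory=./data/reacher_easy
```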

Put the pretrained checkpoints into the checkpoints folder.

  2. (a) Finetune the disentangled world model on DMC Walker Walk:
python dreamer.py --device 'cuda:0' --task dmc2gym_walker_walk \
    --logdir $log_directory \
    --pretrain_checkpoint_path ./checkpoints \
    --seed 0 --configs defaults dmc2gym --traverse True --beta_vae True \
    --method_name 'diswm' --cross_domain False \
    --source_action_num $source_task_action_number --action_num_gap $action_num_gap \
    --distillation True --color_distractor True
  2. (b) Finetune the disentangled world model on MuJoCo Pusher:
python dreamer.py --device 'cuda:0' --task gymnasium_Pusher-v5 \
    --logdir $log_directory \
    --pretrain_checkpoint_path ./checkpoints \
    --seed 0 --configs defaults gymnasium --traverse True --beta_vae True \
    --method_name 'diswm' --cross_domain True \
    --source_action_num $source_task_action_number --action_num_gap $action_num_gap \
    --distillation True --color_distractor True
  2. (c) Finetune the disentangled world model on Drawerworld Open:
python dreamer.py --device 'cuda:0' --task metaworld_drawer-open-v1 \
    --logdir $log_directory \
    --pretrain_checkpoint_path ./checkpoints \
    --seed 0 --configs defaults metaworld --traverse True --beta_vae True \
    --method_name 'diswm' --cross_domain True \
    --source_action_num $source_task_action_number --action_num_gap $action_num_gap \
    --distillation True --color_distractor True

Datasets Download

We provide datasets collected from four DMC tasks (Cheetah Run, Finger Spin, Reacher Easy, Walker Walk) for pretraining DisWM.

Task Name File Name
Cheetah Run cheetah_run.tar.xz
Finger Spin finger_spin.tar.xz
Reacher Easy reacher_easy.tar.xz
Walker Walk walker_walk.tar.xz
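After downloading, each .tar.xz archive can be unpacked into the directory later passed via --pretrain_datasets_path. The target path below is only an example; the loop simply extracts every archive found in the current directory:

```shell
# Unpack all downloaded archives into the pretraining data directory
# (./data/train_eps is an example path -- use your own).
mkdir -p ./data/train_eps
for f in *.tar.xz; do
  [ -e "$f" ] || continue          # skip cleanly if no archives are present
  tar -xJf "$f" -C ./data/train_eps
done
```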

Citation

If you find this repo useful, please cite our paper:

@inproceedings{wang2025disentangled,
    title={Disentangled World Models: Learning to Transfer Semantic Knowledge from Distracting Videos for Reinforcement Learning}, 
    author={Qi Wang and Zhipeng Zhang and Baao Xie and Xin Jin and Yunbo Wang and Shiyu Wang and Liaomo Zheng and Xiaokang Yang and Wenjun Zeng},
    booktitle={ICCV},
    year={2025}
}

Acknowledgement

The code is based on the implementation of dreamer-torch. Thanks to the authors!
