Boyuan Chen¹, Diego Martí Monsó², Yilun Du¹, Max Simchowitz¹, Russ Tedrake¹, Vincent Sitzmann¹
¹MIT ²Technical University of Munich
This is the v1.5 code base for our paper Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion. The main branch contains our latest reimplementation with temporal attention (recommended), while the paper branch contains the RNN code used by the original paper, for reproduction purposes.
Diffusion Forcing v2 is coming very soon! There is a stronger technique, uniquely enabled by diffusion forcing, that achieves infinite and consistent video generation. We are actively investigating it, so please stay tuned. We will also release latent diffusion code by then, which will allow you to scale up to higher-resolution and longer videos!
@misc{chen2024diffusionforcingnexttokenprediction,
title={Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion},
author={Boyuan Chen and Diego Marti Monso and Yilun Du and Max Simchowitz and Russ Tedrake and Vincent Sitzmann},
year={2024},
eprint={2407.01392},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2407.01392},
}
If you want to use our latest improved implementation for video and planning with temporal attention instead of RNN, stay on this branch. If you are instead interested in reproducing the claims of the original paper, switch to the branch used by the original paper via git checkout paper.
Run conda create python=3.10 -n diffusion-forcing to create the environment.
Run conda activate diffusion-forcing to activate this environment.
Install dependencies for time series, video and robotics:
pip install -r requirements.txt
Sign up for a wandb account for cloud logging and checkpointing. In the command line, run wandb login to log in.
Then modify the wandb entity in configurations/config.yaml to your wandb account.
Optionally, if you want to do maze planning, install the following complicated dependencies, which are needed because of the outdated dependencies of d4rl. This involves first installing mujoco 210 and then running
pip install -r extra_requirements.txt
Since the dataset is huge, we provide a mini subset and pre-trained checkpoints for you to quickly test out our model! To do so, download the mini dataset and checkpoints from here to the project root and extract them with tar -xzvf quickstart_atten.tar.gz. Files shall appear in data and outputs/xxx.ckpt. If you forked the repo before the checkpoint release, make sure you also git pull upstream to use the latest version of the code!
Then run the following commands and go to the wandb panel to see the results.
Our visualization is side by side, with the prediction on the left and the ground truth on the right. The ground truth is not expected to align with the prediction, since the sequence is highly stochastic; it is shown only to give an idea of the expected quality.
Autoregressively generate Minecraft video with the same length it was trained on (1x):
python -m main +name=sample_minecraft_pretrained load=outputs/minecraft.ckpt experiment.tasks=[validation]
To let the model roll out longer than it was trained on, simply append dataset.validation_multiplier=8 to the above command, and it will roll out 8x longer than the maximum sequence length it was trained on.
The above checkpoint is trained for 100K steps with a small number of frames. We've already verified that diffusion forcing works in the latent diffusion setting and can be extended to many more tokens without sacrificing compositionality (with some additional techniques outside this repo)! Stay tuned for our next project!
The maze planning setting has changed a bit as we gained more insights; please see the corresponding paragraphs in the training section for details. We haven't reimplemented MCTG yet, but you can already see nice visualizations in the wandb log.
Medium Maze
python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.action_mean=[] dataset.action_std=[] dataset.observation_mean=[3.5092521,3.4765592] dataset.observation_std=[1.3371079,1.52102] load=outputs/maze2d_medium_x.ckpt experiment.tasks=[validation] algorithm.guidance_scale=3 +name=maze2d_medium_x_sampling
Large Maze
python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_mean=[3.7296331,5.3047247] dataset.observation_std=[1.8070312,2.5687592] dataset.action_mean=[] dataset.action_std=[] load=outputs/maze2d_large_x.ckpt experiment.tasks=[validation] algorithm.guidance_scale=2 +name=maze2d_large_x_sampling
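The maze commands above pass per-dimension statistics through dataset.observation_mean and dataset.observation_std. The sketch below shows what those numbers would do under one assumption: that the dataset applies standard z-score normalization to the 2D maze position. This README does not spell out the exact transform. The helpers normalize and denormalize are hypothetical and are not functions from this repo.

```python
from typing import List, Sequence

# Per-dimension statistics copied from the Medium Maze command above.
MEDIUM_MEAN = [3.5092521, 3.4765592]
MEDIUM_STD = [1.3371079, 1.52102]


def normalize(obs: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    """Hypothetical helper (assumed transform): z-score each dimension, (x - mean) / std."""
    return [(x - m) / s for x, m, s in zip(obs, mean, std)]


def denormalize(z: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    """Hypothetical helper: invert normalize(), x = z * std + mean."""
    return [v * s + m for v, m, s in zip(z, mean, std)]


# A position exactly at the dataset mean maps to the origin of the model's space.
print(normalize(MEDIUM_MEAN, MEDIUM_MEAN, MEDIUM_STD))  # [0.0, 0.0]
```

Under this assumption, the Large Maze statistics simply replace the two constants.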
We also explored a couple more settings but haven't reimplemented everything in the original paper yet. If you are interested in those checkpoints, see the source code of this README file for checkpoint loading instructions, which are commented out.
Video prediction requires downloading giant datasets. First, if you downloaded the mini subset following the Quick start with pretrained checkpoints section, delete the mini subset folders data/minecraft and data/dmlab, because we have to download the whole dataset this time. The Python code will download the dataset for you if it doesn't already exist. Due to the slowness of the source, this may take a couple of days. If you prefer to do it yourself via bash script, please refer to the bash scripts dmlab.sh and minecraft.sh in the Dataset section of the original TECO dataset README, and maybe split them into parallel scripts.
Then just run the corresponding commands:
python -m main +name=your_experiment_name algorithm=df_video dataset=video_minecraft
python -m main +name=your_experiment_name algorithm=df_video dataset=video_dmlab algorithm.weight_decay=1e-3 algorithm.diffusion.architecture.network_size=48 algorithm.diffusion.architecture.attn_dim_head=32 algorithm.diffusion.architecture.attn_resolutions=[8,16,32,64] algorithm.diffusion.beta_schedule=cosine
To train a non-causal model, simply append algorithm.causal=False to your command.
Please take a look at the "Load a checkpoint to eval" paragraph to understand how to load a checkpoint with load=. Then run the exact training command with experiment.tasks=[validation] load={wandb_run_id} to load a checkpoint and experiment with sampling.
To see how you can roll out longer than the sequence length used in training, see the instructions in the Quick start with pretrained checkpoints section. Keep in mind that rolling out infinitely without a sliding window is a property of the original RNN implementation on the paper branch; this version has to use a sliding window, since it uses temporal attention.
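This toy sketch shows sliding-window rollout, which this temporal-attention version needs in order to generate beyond its training length (e.g. with dataset.validation_multiplier=8). The function sliding_window_rollout and the toy predictor are hypothetical stand-ins. The real model denoises frames with diffusion rather than calling a deterministic predictor, and its actual window size and stride may differ.

```python
from typing import Callable, List, TypeVar

Frame = TypeVar("Frame")


def sliding_window_rollout(
    context: List[Frame],
    predict_next: Callable[[List[Frame]], Frame],
    window: int,
    total_len: int,
) -> List[Frame]:
    """Extend `context` to `total_len` frames autoregressively.

    Each prediction only sees the most recent `window` frames, mimicking a
    temporal-attention model that cannot attend beyond its training length.
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    frames = list(context)
    while len(frames) < total_len:
        visible = frames[-window:]  # frames older than the window are dropped
        frames.append(predict_next(visible))
    return frames


# Toy usage: a hypothetical "model" that predicts last frame + 1, trained on
# 4-frame clips, rolled out 8x longer (32 frames) from a single context frame.
seen_lengths = []


def toy_predict(visible: List[int]) -> int:
    seen_lengths.append(len(visible))  # record how much context the model saw
    return visible[-1] + 1


video = sliding_window_rollout([0], toy_predict, window=4, total_len=32)
print(video[-3:])  # [29, 30, 31]
```

The model never sees more than `window` frames, so memory and compute per step stay bounded. In exchange, anything older than the window is forgotten.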
By default, we run autoregressive sampling with stabilization. To sample the next 2 tokens jointly, you can append the following to the above command: algorithm.scheduling_matrix=full_sequence algorithm.chunk_size=2.
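A scheduling matrix can be pictured as a table: each row is a sampling step, and each column gives the noise level of one token being generated. The sketch below assumes integer noise levels, where max_level is pure noise and 0 is clean. The helpers full_sequence_schedule and chunk_indices are hypothetical, and the repo's actual matrix construction (including stabilization) may differ. With chunk_size=2 and a full_sequence schedule, both tokens in a chunk are denoised jointly at the same level, and chunks are generated one after another.

```python
from typing import List


def full_sequence_schedule(chunk_size: int, max_level: int) -> List[List[int]]:
    """Rows are sampling steps; every token in the chunk shares one noise level.

    Starts at pure noise (max_level) and ends clean (0).
    """
    return [[level] * chunk_size for level in range(max_level, -1, -1)]


def chunk_indices(num_tokens: int, chunk_size: int) -> List[List[int]]:
    """Split token positions into consecutive chunks that are sampled in order."""
    return [
        list(range(start, min(start + chunk_size, num_tokens)))
        for start in range(0, num_tokens, chunk_size)
    ]


print(full_sequence_schedule(chunk_size=2, max_level=3))  # [[3, 3], [2, 2], [1, 1], [0, 0]]
print(chunk_indices(num_tokens=5, chunk_size=2))  # [[0, 1], [2, 3], [4]]
```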
For those who only wish to reproduce the original paper instead of using the transformer architecture, please check out the paper branch of the code instead.
Medium Maze
python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.action_mean=[] dataset.action_std=[] dataset.observation_mean=[3.5092521,3.4765592] dataset.observation_std=[1.3371079,1.52102] +name=maze2d_medium_x
Large Maze
python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_mean=[3.7296331,5.3047247] dataset.observation_std=[1.8070312,2.5687592] dataset.action_mean=[] dataset.action_std=[] +name=maze2d_large_x
Run planning after model is trained
Please take a look at the "Load a checkpoint to eval" paragraph to understand how to load a checkpoint with load=. To sample after training, simply append load={wandb_id_of_above_runs} experiment.tasks=[validation] algorithm.guidance_scale=2 +name=maze2d_sampling to the above command. Feel free to tune guidance_scale between 1 and 5.
This version of maze planning uses a different version of diffusion forcing from the original paper. While working on the follow-up to diffusion forcing, we realized that training with independent per-token noise also yields a smooth interpolation between causal and non-causal models: we can mask out the future with complete noise (fully causal) or with some noise (interpolation). Best of all, you can still account for causal uncertainty via pyramid sampling in this setting, by masking out tokens at different noise levels. You can also still plan with a flexible horizon, because you can tell the model that padded entries are pure noise, a unique ability of diffusion forcing.
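Here is a minimal sketch of the three ideas in the paragraph above: independent noise levels during training, a pyramid sampling schedule, and padding a short plan with pure noise. It assumes integer noise levels, where max_level means pure noise and 0 means clean. All function names are hypothetical. The repo works with continuous diffusion noise on tensors, not Python lists.

```python
import random
from typing import List


def independent_noise_levels(num_tokens: int, max_level: int, rng: random.Random) -> List[int]:
    """Training: every token draws its own noise level in [0, max_level]."""
    return [rng.randint(0, max_level) for _ in range(num_tokens)]


def pyramid_schedule(num_tokens: int, max_level: int) -> List[List[int]]:
    """Sampling: each token lags its predecessor by one denoising step.

    Rows are sampling steps. Within a row, earlier tokens are always at least
    as clean as later ones, so the far future stays noisier (more uncertain).
    """
    return [
        [min(max_level, max(0, max_level - step + i)) for i in range(num_tokens)]
        for step in range(max_level + num_tokens)
    ]


def pad_with_pure_noise(levels: List[int], horizon: int, max_level: int) -> List[int]:
    """Flexible horizon: positions past the real plan are marked as pure noise."""
    return levels + [max_level] * (horizon - len(levels))


print(pyramid_schedule(num_tokens=3, max_level=2))
# [[2, 2, 2], [1, 2, 2], [0, 1, 2], [0, 0, 1], [0, 0, 0]]
print(pad_with_pure_noise([0, 0, 1], horizon=5, max_level=2))  # [0, 0, 1, 2, 2]
```

Setting every future token to max_level recovers a fully causal mask. Intermediate levels interpolate between causal and non-causal conditioning.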
We also reflected a bit on the environment and concluded that the original metric isn't necessarily a good one. Maze planning should reward agents that plan the fastest route to the goal, not a slow-walking agent that arrives at the end of the episode. The dataset never contains data of staying at the goal, so agents are expected to walk away after reaching it. I think Diffuser had an unfair advantage from simply generating slow plans. These happened to keep the agent in the neighborhood of the goal for longer and earned very high reward, exploiting flaws in the environment design (a good design would penalize taking longer to reach the goal). So, in this version of the code, we optimize for flexible-horizon planning that tries to reach the goal as soon as possible. The planner will automatically come back to the goal if it leaves, since staying is never in the dataset. You can see the new metrics we designed in the wandb logging interface.
Please check out the paper branch for the code used by the original paper. If I have time later, I will reimplement these two domains with transformers as well to complete this branch.
Date | Notes |
---|---|
Jul/30/24 | Upgrade RNN to temporal attention, move original code to 'paper' branch |
Jul/03/24 | Initial release of the code. Email me if you have questions or find any errors in this version. |
This repo is forked from Boyuan Chen's research template repo. By its MIT license, you must keep the above sentence in README.md and the LICENSE file to credit the author.
All experiments can be launched via python -m main +name=xxxx {options}, where you can find more details later in this article.
The code base will automatically use CUDA or your MacBook M1 GPU when available.
For Slurm clusters, e.g. MIT SuperCloud, you can run python -m main cluster=mit_supercloud {options} on the login node.
It will automatically generate Slurm scripts and run them for you on a compute node. Even if compute nodes are offline, the script will still automatically sync wandb logging to the cloud with <1 min latency. It's also easy to add your own Slurm cluster by following the Add slurm clusters section.
First, create a new repository with this template. Make sure the new repository has the name you want to use for wandb logging.
Add your method and baselines in algorithms following algorithms/README.md as well as the example code in algorithms/diffusion_forcing/df_video.py. For PyTorch experiments, write your algorithm as a PyTorch Lightning pl.LightningModule, which has extensive documentation. For a quick start, read "Define a LightningModule" in the PyTorch Lightning documentation. Finally, add a yaml config file to configurations/algorithm, imitating configurations/algorithm/df_video.yaml, for each algorithm you added.
Add your dataset in datasets following datasets/README.md as well as the example code in datasets/video. Finally, add a yaml config file to configurations/dataset, imitating configurations/dataset/video_dmlab.yaml, for each dataset you added.
Add your experiment in experiments following experiments/README.md or the example code in experiments/exp_video.py. Then register your experiment in experiments/__init__.py.
Finally, add a yaml config file to configurations/experiment, imitating configurations/experiment/exp_video.yaml, for each experiment you added.
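Registering an experiment typically means mapping the name used in the config to the class that implements it. Here is a minimal sketch of that general pattern. EXPERIMENT_REGISTRY, build_experiment and the classes below are hypothetical, so check experiments/__init__.py for how this repo actually does it.

```python
from typing import Any, Dict, Type


class BaseExperiment:
    """Hypothetical minimal base class; the repo's own base classes differ."""

    def __init__(self, cfg: Dict[str, Any]):
        self.cfg = cfg


class VideoExperiment(BaseExperiment):
    """Hypothetical stand-in for the class in experiments/exp_video.py."""


# Hypothetical registry: keys match the yaml names in configurations/experiment
# (without the .yaml suffix), values are the classes that implement them.
EXPERIMENT_REGISTRY: Dict[str, Type[BaseExperiment]] = {
    "exp_video": VideoExperiment,
}


def build_experiment(name: str, cfg: Dict[str, Any]) -> BaseExperiment:
    """Look up a registered experiment by its config name and instantiate it."""
    try:
        cls = EXPERIMENT_REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(EXPERIMENT_REGISTRY))
        raise ValueError(f"unknown experiment {name!r}; registered: {known}") from None
    return cls(cfg)


exp = build_experiment("exp_video", {"tasks": ["training"]})
print(type(exp).__name__)  # VideoExperiment
```

Failing loudly on an unregistered name surfaces a typo in experiment=... immediately, instead of partway through a run.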
Modify configurations/config.yaml to set algorithm to the yaml file you want to use in configurations/algorithm; set experiment to the yaml file you want to use in configurations/experiment; and set dataset to the yaml file you want to use in configurations/dataset, or to null if no dataset is needed. Note that these fields should not contain the .yaml suffix.
You are all set!
cd into your project root. Now you can launch your new experiment with python main.py +name=<name_your_experiment>. You can run baselines or different datasets by adding arguments like algorithm=xxx or dataset=xxx. You can also override any yaml configuration by following the next section.
One special note: if you want to define a new task for your experiment (e.g. other than training and test), you can define it as a method in your experiment class and use experiment.tasks=[task_name] to run it. Let's say you have a generate_dataset task that should run before the training task, and you implemented it in your experiment class. You can then run python -m main +name=xxxx experiment.tasks=[generate_dataset,training] to execute it before training.
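Because task names are method names, dispatch can be pictured as looking up each name on the experiment object and calling it in order. Here is a minimal sketch; MyExperiment and run_tasks are hypothetical and are not the repo's actual dispatch code.

```python
from typing import List


class MyExperiment:
    """Hypothetical experiment class whose method names double as task names."""

    def __init__(self) -> None:
        self.history: List[str] = []

    def generate_dataset(self) -> None:
        self.history.append("generate_dataset")

    def training(self) -> None:
        self.history.append("training")


def run_tasks(experiment: object, tasks: List[str]) -> None:
    """Call each task method in order, e.g. for experiment.tasks=[generate_dataset,training]."""
    for task in tasks:
        method = getattr(experiment, task, None)
        if not callable(method):
            raise ValueError(f"{type(experiment).__name__} has no task {task!r}")
        method()


exp = MyExperiment()
run_tasks(exp, ["generate_dataset", "training"])
print(exp.history)  # ['generate_dataset', 'training']
```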
We use hydra instead of argparse to configure arguments at every code level. You can both write a static config in the configurations folder and, at runtime, override part of your static config with command line arguments.
For example, the arguments algorithm=example_classifier experiment.lr=1e-3 will override the lr variable in configurations/experiment/example_classifier.yaml. The argument wandb.mode will override the mode under the wandb namespace in the file configurations/config.yaml.
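This toy sketch shows how a dotted command-line override such as wandb.mode=offline maps onto nested config namespaces. apply_override is hypothetical and only mirrors the basic semantics of Hydra's override grammar. A plain override modifies an existing key, and a + prefix adds a new one, which is why the commands here use +name=.... In practice, Hydra does this for you.

```python
import copy
from typing import Any, Dict


def apply_override(cfg: Dict[str, Any], override: str) -> Dict[str, Any]:
    """Return a copy of `cfg` with one dotted `key=value` override applied.

    Like Hydra, a plain override must target an existing key, while a
    leading '+' adds a new key (e.g. +name=...). Values stay strings here;
    Hydra itself also parses them into ints, floats, lists, etc.
    """
    key, sep, value = override.partition("=")
    if not sep:
        raise ValueError(f"expected key=value, got {override!r}")
    append = key.startswith("+")
    *parents, leaf = key.lstrip("+").split(".")
    out = copy.deepcopy(cfg)  # never mutate the caller's config
    node = out
    for part in parents:
        node = node[part]  # walk into the nested namespace, e.g. "wandb"
    if not append and leaf not in node:
        raise KeyError(f"{key!r} is not in the config; use +{key}=... to add it")
    node[leaf] = value
    return out


base = {"wandb": {"mode": "online", "entity": "my_team"}}
print(apply_override(base, "wandb.mode=offline"))
# {'wandb': {'mode': 'offline', 'entity': 'my_team'}}
print(apply_override(base, "+name=my_run")["name"])  # my_run
```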
All static config and runtime override will be logged to cloud automatically.
For machine learning experiments, all checkpoints and logs are logged to the cloud automatically, so you can resume them on another server. Simply append resume={wandb_run_id} to your command line arguments to resume a run. The run_id can be found in the URL of a wandb run in the wandb dashboard. By default, the latest checkpoint in a run is stored indefinitely, and earlier checkpoints in the run will be deleted after 5 days to save storage.
On the other hand, sometimes you may want to start a new run with a different run id but still load a prior checkpoint. This can be done by setting the load={wandb_run_id / ckpt path} flag.
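The checkpoint retention rule described in this section keeps the latest checkpoint indefinitely and deletes earlier ones after 5 days. It can be written as a small pure function. This sketch assumes the 5 days are counted from when each checkpoint was saved. checkpoints_to_delete and the file names are hypothetical, and this is not the repo's cleanup code.

```python
from datetime import datetime, timedelta
from typing import List, Tuple


def checkpoints_to_delete(
    ckpts: List[Tuple[str, datetime]], now: datetime, keep_days: int = 5
) -> List[str]:
    """Return checkpoint names eligible for deletion under the retention rule.

    The newest checkpoint is always kept; any older checkpoint saved more
    than `keep_days` days before `now` is returned for deletion.
    """
    if not ckpts:
        return []
    latest_name = max(ckpts, key=lambda c: c[1])[0]
    cutoff = now - timedelta(days=keep_days)
    return [name for name, saved_at in ckpts if name != latest_name and saved_at < cutoff]


now = datetime(2024, 7, 30)
ckpts = [
    ("step_10000.ckpt", datetime(2024, 7, 20)),
    ("step_20000.ckpt", datetime(2024, 7, 28)),
    ("step_30000.ckpt", datetime(2024, 7, 29)),
]
print(checkpoints_to_delete(ckpts, now))  # ['step_10000.ckpt']
```

Since the newest checkpoint is never deleted, resume={wandb_run_id} always has something to restore, even for a run that stopped long ago.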
The argument experiment.tasks=[task_name1,task_name2] (note that the [] brackets are needed here) allows you to select a sequence of tasks to execute, such as training, validation and test. Therefore, to test a machine learning checkpoint, you may run python -m main load={your_wandb_run_id} experiment.tasks=[test].
More generally, the task names are the corresponding method names of your experiment class. For BaseLightningExperiment, we already defined three methods, training, validation and test, for you, but you can also define your own tasks by adding methods to your experiment class under the intended task names.
We provide a useful debug flag, which you can enable with python main.py debug=True. This enables numerical error tracking and sets cfg.debug to True for your experiment, algorithm and dataset classes. However, this debug flag makes ML code very slow, as it automatically tracks all parameters and gradients!
It's very easy to add your own Slurm clusters by adding a yaml file in configurations/cluster. You can take a look at configurations/cluster/mit_vision.yaml for an example.
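Conceptually, a cluster yaml supplies the resources that end up in the generated Slurm batch script. This sketch shows that rendering step using standard sbatch options. The keys partition, gpus, cpus and time are hypothetical and may not match the schema of configurations/cluster/mit_vision.yaml. render_sbatch is not a function from this repo.

```python
from typing import Any, Dict


def render_sbatch(cluster: Dict[str, Any], job_name: str, command: str) -> str:
    """Render a minimal Slurm batch script from hypothetical cluster settings."""
    lines = [
        "#!/bin/bash",
        f"#SBATCH --job-name={job_name}",
        f"#SBATCH --partition={cluster['partition']}",
        f"#SBATCH --gres=gpu:{cluster['gpus']}",
        f"#SBATCH --cpus-per-task={cluster['cpus']}",
        f"#SBATCH --time={cluster['time']}",
        "",
        command,  # the experiment command the compute node will execute
    ]
    return "\n".join(lines) + "\n"


# Hypothetical settings; the real keys live in configurations/cluster/*.yaml.
my_cluster = {"partition": "gpu", "gpus": 1, "cpus": 8, "time": "24:00:00"}
script = render_sbatch(my_cluster, "df_video", "python -m main +name=df_video algorithm=df_video")
print(script)
```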