TinyDP

A lightweight adaptation of Diffusion Policy used in Real2Render2Real.

DISCLAIMER: Much of the tangential code was removed from the original repo to keep this codebase lightweight. There may still be issues or bugs; please report any you find.

Installation

# we use python=3.10.15
conda create -n tinydp python=3.10.15
conda activate tinydp
pip install -e .

Data Conversion

To improve training throughput, the h5 files generated by the real2render2real repo are converted to zarr files and JPG images:

python script/convert_dataset_to_mp4.py --root-dir /PATH/TO/DATASET 

Training

We provide example training scripts below:

EPOCHS=200
DATASET_ROOT=/PATH/TO/DATASET
LOG_NAME=YOUR_LOG_NAME
OUTPUT_DIR=./dpgs_checkpoints
python script/train.py --dataset-cfg.dataset-root $DATASET_ROOT --logging-cfg.log-name $LOG_NAME --logging-cfg.output-dir $OUTPUT_DIR --trainer-cfg.epochs $EPOCHS

The code can be run on multiple GPUs via DDP:

MASTERPORT=2222
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node=2 --master_port=$MASTERPORT script/train.py --dataset-cfg.dataset-root $DATASET_ROOT --logging-cfg.log-name $LOG_NAME --logging-cfg.output-dir $OUTPUT_DIR --trainer-cfg.epochs $EPOCHS

Please see other knobs you can tune via

python script/train.py --help

Important Flags:

If you are using sim data, please add the following flag:

--dataset-cfg.is-sim-data

If only one arm in a bimanual setup is used, please add ONE of the following flags:

--model-cfg.policy-cfg.pred-left-only
--model-cfg.policy-cfg.pred-right-only

If you want to subsample the data, you can add the following flag:

SUBSAMPLE_NUM_TRAJ=100
--dataset-cfg.data-subsample-num-traj $SUBSAMPLE_NUM_TRAJ
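
For example, a run on sim data that predicts only the right arm and subsamples 100 trajectories would combine these flags with the training command above (the particular combination and values are illustrative):

SUBSAMPLE_NUM_TRAJ=100
python script/train.py --dataset-cfg.dataset-root $DATASET_ROOT --logging-cfg.log-name $LOG_NAME --logging-cfg.output-dir $OUTPUT_DIR --trainer-cfg.epochs $EPOCHS --dataset-cfg.is-sim-data --model-cfg.policy-cfg.pred-right-only --dataset-cfg.data-subsample-num-traj $SUBSAMPLE_NUM_TRAJ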

Evaluation

We provide an example snippet below:

from tinydp.policy.diffusion_wrapper import DiffusionWrapper
model_ckpt_folder = "/PATH/TO/MODEL/CHECKPOINT"
ckpt_id = 50
device = "cuda"
inferencer = DiffusionWrapper(model_ckpt_folder, ckpt_id, device=device)

while True:
    # nbatch: batch dictionary containing:
    # - observation: torch.Tensor of images, shape (B, T, num_cameras, C, H, W)
    # - proprio: torch.Tensor of proprioceptive data, shape (B, T, D)
    nbatch = {
        "observation": ...,
        "proprio": ...,
    }
    nbatch = {k: v.to(device) for k, v in nbatch.items()}
    pred_action = inferencer(nbatch)  # (B, action_horizon, action_dim), denormalized with the dataset statistics

A more complete example can be found in script/inference.py.
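
As a quick smoke test, you can feed random tensors shaped like the expected inputs. The batch size, horizon, camera count, image resolution, and proprioception dimension below are illustrative assumptions, not values fixed by the repo; adjust them to match your trained checkpoint's config.

import torch

from tinydp.policy.diffusion_wrapper import DiffusionWrapper

# Illustrative shapes only -- adjust to your checkpoint's config (assumptions, not repo defaults).
B, T = 1, 2            # batch size, observation horizon
NUM_CAMERAS = 2        # number of camera views
C, H, W = 3, 224, 224  # image channels and resolution
D = 20                 # proprioception dimension

device = "cuda" if torch.cuda.is_available() else "cpu"
inferencer = DiffusionWrapper("/PATH/TO/MODEL/CHECKPOINT", 50, device=device)

nbatch = {
    "observation": torch.rand(B, T, NUM_CAMERAS, C, H, W),  # random images in [0, 1]
    "proprio": torch.randn(B, T, D),                        # random proprioceptive state
}
nbatch = {k: v.to(device) for k, v in nbatch.items()}

pred_action = inferencer(nbatch)  # (B, action_horizon, action_dim)
print(pred_action.shape)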

Acknowledgements

Much of the code is borrowed from Diffusion Policy.
