open-pi-zero

This repo implements the pi0 model from Physical Intelligence (Pi) based on my knowledge of the paper.

The model adopts a MoE-like architecture (or the recent MoT: each expert has its own set of parameters and the experts interact only through attention), using a pre-trained 3B PaliGemma VLM (2.291B parameters to be fine-tuned) and a new set of action expert parameters (0.315B). Block-wise causal masking is used such that the VLM block attends to itself, the proprioception block (sharing weights with the action expert) attends to itself and the VLM, and the action block attends to all; attention is fully bidirectional within each block. The model is trained with a flow matching loss on the action chunk output by the action expert.
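For reference, here is a minimal sketch of how such a block-wise causal mask can be constructed (boolean mask, True = query may attend to key); the block sizes and helper name are illustrative, not the repo's exact implementation.

```python
import torch

def blockwise_causal_mask(len_vlm: int, len_proprio: int, len_action: int) -> torch.Tensor:
    """Block-wise causal attention mask.

    Blocks (in order): VLM tokens, proprio token(s), action tokens.
    Attention is fully bidirectional within a block and causal across blocks:
      VLM     -> VLM
      proprio -> VLM + proprio
      action  -> VLM + proprio + action
    """
    sizes = [len_vlm, len_proprio, len_action]
    total = sum(sizes)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for size in sizes:
        end = start + size
        # queries in this block attend to all keys up to the end of their own block
        mask[start:end, :end] = True
        start = end
    return mask

# Example: 256 image + 20 text tokens, 1 proprio token, 4 action tokens
mask = blockwise_causal_mask(276, 1, 4)  # shape (281, 281)
```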

If you find a bug or think I may have misunderstood part of the architecture based on the paper, please raise an issue or email me.

[Overview figure: open-pi-zero architecture]

Installation

Clone this repository into your working directory. If running Simpler eval or trying out trained checkpoints, also clone my SimplerEnv fork (added proprio support) into the same parent directory:

git clone https://github.com/allenzren/SimplerEnv --recurse-submodules

Install uv and run the following in the repo directory

uv sync
uv pip install -e . ../SimplerEnv
uv pip install -e . ../SimplerEnv/ManiSkill2_real2sim

Alternatively, use a venv or conda environment instead of uv and run pip install -e . in all three directories.

Set environment variables VLA_DATA_DIR (if downloading datasets for training), VLA_LOG_DIR, and VLA_WANDB_ENTITY by running source scripts/set_path.sh

Download the PaliGemma weights into your TRANSFORMERS_CACHE directory:

git clone https://huggingface.co/google/paligemma-3b-pt-224

Test text generation with pre-trained weights

uv run src/model/vla/pizero.py --text_only --load_pretrained_weights --use_bf16

Try checkpoints

So far I have only trained with either the fractal or the bridge dataset (unlabeled data skipped); training with mixed OXE data is coming soon. Links to the models: Bridge-Uniform | Bridge-Beta | Fractal-Uniform | Fractal-Beta

Uniform and Beta denote the mode for sampling flow matching timesteps during training: Uniform samples uniformly between 0 and 1, while Beta, proposed by Pi0, samples with higher density at earlier timesteps.
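A minimal sketch of the two sampling modes in PyTorch; the Beta shape parameters and the cap below 1 are illustrative assumptions, not necessarily the exact values used in this repo or the Pi0 paper.

```python
import torch

def sample_flow_timesteps(mode: str, batch_size: int) -> torch.Tensor:
    """Sample flow matching timesteps t in [0, 1).

    'uniform': t ~ U(0, 1).
    'beta':    draw from a Beta distribution and flip it so more mass lands
               at earlier timesteps. The shape parameters (1.5, 1.0) and the
               0.999 cap are assumptions for illustration.
    """
    if mode == "uniform":
        return torch.rand(batch_size)
    if mode == "beta":
        z = torch.distributions.Beta(1.5, 1.0).sample((batch_size,))
        return (1.0 - z) * 0.999  # higher density near t = 0
    raise ValueError(f"unknown sampling mode: {mode}")
```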

Run a trial in Simpler after downloading a checkpoint (see the list of tasks in the script):

uv run scripts/try_checkpoint_in_simpler.py \
    --task google_robot_pick_horizontal_coke_can \
    --checkpoint_path ...fractal_beta.pt \
    --recording \
    --use_bf16 \
    --use_torch_compile # first batch will be slow

Training details

The models were trained with learning rate 5e-5, global batch size 1024, and roughly 19k gradient steps for bridge (~12 epochs) and 30k for fractal (~8 epochs). Input to the model consists of a single image (256 tokens, no history), at most 20 text tokens, 1 proprio token (no history), and 4 action tokens (chunk size 4). Training took roughly 2-3 days on one L40 node (per-GPU batch size 16, hence gradient accumulation over 8 steps), or 12-18 hours with H100s (batch size 32). torch.compile, bfloat16, and an 8-bit optimizer were used to reduce VRAM usage (peak 40GB with batch size 16). Action and proprioception data were normalized to [-1, 1].
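For illustration, normalizing actions and proprioception to [-1, 1] with per-dimension dataset statistics might look like the sketch below; the choice of min/max versus quantile bounds here is an assumption, not necessarily what this repo uses.

```python
import numpy as np

def normalize(x: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
    """Map raw action/proprio values to [-1, 1] using per-dimension bounds.

    `low`/`high` would come from dataset statistics (e.g. min/max or
    1st/99th percentiles); the exact choice is an assumption.
    """
    x = 2.0 * (x - low) / np.maximum(high - low, 1e-8) - 1.0
    return np.clip(x, -1.0, 1.0)

def unnormalize(x: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
    """Inverse map from [-1, 1] back to the raw action range."""
    return 0.5 * (x + 1.0) * (high - low) + low
```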

Inference speed and VRAM usage

Inference involves one forward pass through PaliGemma (saving the KV cache), followed by 10 flow matching steps through the action expert. With an RTX 4090:

| Setup | Time | Peak VRAM |
|---|---|---|
| float32 | 237ms | 13.6GB |
| bf16 | 245ms | 6.7GB |
| float32 + torch.compile | 89ms | 13.6GB |
| bf16 + torch.compile | 75ms | 6.7GB |
| Pi0 paper | 73ms* | - |

torch.compile mode is set to default. I also tried torch_tensorrt, but its compilation currently fails silently.

*From Table I of the paper. Note that my numbers are with single image input and action chunk size 4, while Pi uses three images and chunk size 50 according to the paper appendix.
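To make the two-stage inference concrete, here is a rough sketch of the rollout described above: one VLM/proprio forward pass that fills the KV cache, then Euler integration over 10 flow matching steps through the action expert. The method names infer_vlm_kv_cache and action_expert_velocity are hypothetical placeholders, not the repo's actual API.

```python
import torch

@torch.no_grad()
def sample_action_chunk(model, obs, num_steps: int = 10, horizon: int = 4, action_dim: int = 7):
    """Flow matching inference: cache VLM/proprio KV once, then integrate.

    `infer_vlm_kv_cache` and `action_expert_velocity` are hypothetical names
    standing in for the actual PiZero methods.
    """
    kv_cache = model.infer_vlm_kv_cache(obs)       # one PaliGemma forward pass
    action = torch.randn(1, horizon, action_dim)   # start from noise at t = 0
    dt = 1.0 / num_steps
    t = torch.zeros(1)
    for _ in range(num_steps):
        velocity = model.action_expert_velocity(action, t, kv_cache)
        action = action + dt * velocity            # Euler step toward t = 1
        t = t + dt
    return action                                  # normalized action chunk in [-1, 1]
```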

Eval results

For both sets of environments I tried executing either the entire action chunk (size 4) or only the first 2 steps. Bridge policies work better when running the full chunk, while fractal policies work better when running 2 out of 4 (see the sketch below). Note that bridge data is 5Hz and fractal data is 3Hz.
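A sketch of executing only the first k steps of each predicted chunk before re-planning (with k equal to the chunk size this reduces to running the whole chunk open-loop); the policy/env interface below follows the usual gym-style convention and is only illustrative.

```python
def run_episode(env, policy, execute_steps: int = 2, chunk_size: int = 4, max_steps: int = 120):
    """Roll out a policy that predicts chunks of `chunk_size` actions but only
    executes the first `execute_steps` of each chunk before re-planning.
    """
    obs = env.reset()
    done, step, info = False, 0, {}
    while not done and step < max_steps:
        chunk = policy.predict_action_chunk(obs)   # shape (chunk_size, action_dim)
        for action in chunk[:execute_steps]:
            obs, reward, done, info = env.step(action)
            step += 1
            if done or step >= max_steps:
                break
    return info
```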

Success rates in the visual matching setting in Simpler (results in the variant aggregation setting coming soon):

| Policy | Dtype | Carrot on plate | Eggplant in basket | Spoon on towel | Stack cube |
|---|---|---|---|---|---|
| Bridge-Uniform | float32 | 58.8% | 79.2% | 63.3% | 21.3% |
| ^ | bf16 | 58.8% | 81.3% | 61.7% | 23.8% |
| Bridge-Beta | float32 | 55.8% | 85.4% | 84.6% | 47.9% |
| ^ | bf16 | 52.5% | 87.9% | 83.8% | 52.5% |

| Policy | Dtype | Pick up Coke | Move Near | Close Drawer | Open Drawer | Open Top Drawer and Put Apple In |
|---|---|---|---|---|---|---|
| Fractal-Uniform | float32 | 88.0% | 80.3% | 66.7% | 45.2% | 52.2% |
| ^ | bf16 | 88.9% | 80.5% | 65.4% | 45.3% | 53.0% |
| Fractal-Beta | float32 | 97.9% | 78.7% | 75.0% | 49.5% | 46.6% |
| ^ | bf16 | 97.8% | 78.4% | 74.7% | 51.7% | 46.1% |

All numbers are averaged over 10 trials on top of the prepackaged variations (robot/object locations, URDFs, rgb_overlays) of each task in Simpler (240-2400 trials per task in total; I see significant variance with <=3 seeds). Also note that these numbers may vary significantly between checkpoints that are all mostly converged but saved at different epochs; training with mixed datasets or using EMA might help.

Reason for evaluating with both bf16 and float32: the model is trained with bf16 mixed precision and without KV caching, but during inference the KV cache of the VLM/proprio blocks is used. With bf16 this leads to a distribution shift in the policy output compared to not using the KV cache (discussion), estimated at roughly 5e-4 to 2.5e-3 average L1 distance (relative to the [-1, 1] normalization range); the difference is negligible with float32.
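One way to estimate this drift is to run the same observation (and the same initial noise) through the policy with and without the KV cache and compare the average L1 distance of the normalized action chunks; the infer_action interface and its use_kv_cache flag below are hypothetical.

```python
import torch

@torch.no_grad()
def kv_cache_drift(model, obs, num_steps: int = 10) -> float:
    """Average L1 distance between action chunks produced with and without
    the VLM/proprio KV cache, in the [-1, 1] normalized action space.
    `infer_action` with a `use_kv_cache` flag is a hypothetical interface.
    """
    torch.manual_seed(0)  # same initial noise for both runs
    a_cached = model.infer_action(obs, num_steps=num_steps, use_kv_cache=True)
    torch.manual_seed(0)
    a_full = model.infer_action(obs, num_steps=num_steps, use_kv_cache=False)
    return (a_cached - a_full).abs().mean().item()
```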

Disclaimer: Please do not associate the results here with possible results from Pi.

Run training

Download data

Download fractal data (following OXE)

uv run gsutil -m cp -r gs://gresearch/robotics/fractal20220817_data/0.1.0   # at $VLA_DATA_DIR

Download the bridge data (bridge_dataset folder) from the RAIL link as suggested by OXE.

Data pre-processing

Run slurm/modify_rlds.sh (taken from rlds_dataset_mod), which resizes the images to 224x224 for PaliGemma. The processed data is saved at $VLA_DATA_DIR/resize_224. This takes about an hour.
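Conceptually, the transform amounts to mapping every image observation to 224x224; below is a minimal TensorFlow sketch under the assumption that the observation key is 'image' (the actual rlds_dataset_mod transform and its interpolation settings may differ).

```python
import tensorflow as tf

def resize_images(step, size=(224, 224)):
    """Resize RGB observations to the PaliGemma input resolution.

    The observation key 'image' is an assumption; dataset field names vary.
    """
    image = step["observation"]["image"]
    image = tf.image.resize(image, size, method="lanczos3", antialias=True)
    step["observation"]["image"] = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8)
    return step
```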

Possible error | Conventions on proprio / gripper actions

Training scripts

See examples in the slurm folder. TFDS dataloading uses a growing amount of CPU RAM and peaks at roughly 300-400GB per node, as each DDP process spawns its own dataloader.

To further reduce VRAM usage, you may use (Q)LoRA by setting lora=True (and quantize=True for QLoRA); however, training performance may suffer. You may also try optimizer offloading.
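For intuition, a LoRA-wrapped linear layer keeps the pre-trained weight frozen and learns only a low-rank update; the sketch below illustrates the idea with assumed hyperparameters, not this repo's exact implementation.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Linear layer with a frozen base weight and a trainable low-rank update:
    y = base(x) + (alpha / r) * B(A(x)). Hyperparameters are illustrative.
    """
    def __init__(self, base: nn.Linear, r: int = 16, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)   # freeze pre-trained weight
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)       # zero-init: layer starts equal to the frozen base
        self.scale = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))
```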

Discussion on RAM | Possible error if running quantization | Observations / lessons from training

Run evaluation

Simpler

See examples in the slurm folder. Currently they use the dataset statistics generated by my training; you may update env.adapter.dataset_statistics_path in the config to point to the dataset statistics JSON file generated by your own training run, located in the dataset folder.

File structure / naming

I implemented the model in a fairly modular way, so it is straightforward to add or remove any of the mixtures. The PiZero class follows the original Pi0 architecture, while the underlying JointModel and Mixture classes form a general instantiation of the MoE + block-attention architecture.

src
 ├── agent
 │     ├── train.py             # training workspace
 │     ├── eval.py              # eval workspace (with Simpler)
 │     └── env_adapter
 │          ├── base.py
 │          └── simpler.py      # env obs pre-processing and action post-processing
 ├── data                       # training data pre-processing, mostly from Octo and dlimp
 └── model
       ├── paligemma            # mostly from open-source paligemma
       │    ├── siglip.py       # SigLIP and projector
       │    └── modules.py      # RoPE, MLP, RMSNorm
       └── vla
            ├── pizero.py       # PiZero: `joint_model` (Gemma and action expert), SigLIP, en/decoders
            ├── joint_model.py  # mixture of experts, run each expert and global attention
            ├── mixture.py      # individual `expert` following PaliGemma layout
            ├── modules.py      # action encoder, time embedding
            └── processing.py   # text tokenization and image normalization for PaliGemma

Things to implement / try

- Tune sinusoidal embedding and RoPE parameters to better encode the relatively small number of action tokens and the timestep.
- Multi-image (history) input.
- Use EMA.
- Switch to GPU Simpler.
- Fine-tuning by adding a new expert (e.g., a second camera view through pre-trained DINO/SigLIP) and gradual unmasking.
- Co-training with (self-)supervision on modalities other than action.

Acknowledgement

PaliGemma setup is largely adopted from Open-source PaliGemma. Dataset loading is adopted from Octo and dlimp. Dataset pre-processing is adopted from rlds_dataset_mod. Other references: Pi0, OpenVLA, Flow matching, Implementation by lucidrains, SimplerEnv, QLoRA-LLM

Special thanks to Asher Hancock for the discussion on block-wise causal masking.
