CompVis @ LMU Munich, MCML
ICCV 2025
Flow Poke Transformer (FPT) directly models the uncertainty of the world by predicting distributions of how objects (×) may move, conditioned on some input movements (pokes, →). We see that whether the hand (below the paw) or the paw (above the hand) moves downwards directly influences the other's movement. Left: the paw pushing down forces the hand downwards, resulting in a unimodal distribution. Right: the hand moving down results in two modes: the paw following along or staying put.
This codebase is a minimal PyTorch implementation covering training & various inference settings.
The easiest way to try FPT is via our interactive demo, which you can launch as:
python -m scripts.demo.app --compile True --warmup_compiled_paths True
Compilation is optional, but recommended for a smoother experience in the UI. If no checkpoint is explicitly specified via the CLI, one will be downloaded from Hugging Face by default.
When using FPT in your own code, the simplest way to load it is via torch.hub:
model = torch.hub.load("CompVis/flow_poke_transformer", "fpt_base")
If you want to fully integrate FPT into your own codebase, copy model.py and dinov2.py into it and you should effectively be good to go. Then instantiate the model as
# model.py and dinov2.py copied from this repository
from model import FlowPokeTransformer, FlowPokeTransformer_Base
import torch

model: FlowPokeTransformer = FlowPokeTransformer_Base()
state_dict = torch.load("fpt_base.pt")
model.load_state_dict(state_dict)
model.requires_grad_(False)
model.eval()
The FlowPokeTransformer class contains all the methods that you should need to use FPT in various applications. For high-level usage, use the FlowPokeTransformer.predict_*() methods. For low-level usage, the module's forward() can be used.
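If you are unsure which high-level entry points your checked-out version exposes, you can enumerate them directly; a small sketch, assuming model is an instantiated FlowPokeTransformer as above:

import inspect

# enumerate the high-level prediction entry points exposed by the model
predict_methods = [name for name in dir(model) if name.startswith("predict_")]
print(predict_methods)

# inspect the signature of one of them before calling it
print(inspect.signature(getattr(model, predict_methods[0])))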
The only dependencies you should need are a recent torch (required for flex attention, though it could be patched out with some effort to support older torch versions) and any versions of einops, tqdm, and jaxtyping (the latter can be removed by deleting the type hints).
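A minimal environment can be set up with pip, for example (version pins are up to you; a recent torch release is the only hard requirement):

pip install torch einops tqdm jaxtyping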
Code files are separated into major blocks with extensive comments explaining relevant choices, details, and conventions.
For all public-facing APIs involving tensors, type hints with jaxtyping are provided, which might look like this: img: Float[torch.Tensor, "b c h w"]. They annotate the dtype (Float), the tensor type (torch.Tensor), and the shape (b c h w), and should (hopefully) make the code fully self-explanatory.
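As a toy illustration of this annotation style (the function below is not part of the FPT API):

from jaxtyping import Float
import torch

def channel_mean(img: Float[torch.Tensor, "b c h w"]) -> Float[torch.Tensor, "b c"]:
    # average over the spatial dimensions, keeping batch and channel
    return img.mean(dim=(2, 3))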
Coordinate & Image Conventions.
We represent coordinates in (x, y) order with image coordinates normalized in Attention & RoPE Utilities
section in model.py
for further details
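As a hypothetical helper for converting pixel indices to this convention (both the helper and its [0, 1] scaling are assumptions for illustration; the actual normalization is defined in model.py):

def pixel_to_xy(row: int, col: int, height: int, width: int) -> tuple[float, float]:
    # note the (x, y) order: x comes from the column index, y from the row index
    # the [0, 1] scaling is an assumption; check the Attention & RoPE Utilities
    # section in model.py for the actual normalization convention
    return col / width, row / height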
Data Preprocessing. For data preprocessing instructions, please refer to the corresponding readme.
Launching Training. Single-GPU training can be launched via
python train.py --data_tar_base /path/to/preprocessed/shards --out_dir output/test --compile True
Similarly, multi-GPU training, e.g., on 2 GPUs, can be launched using torchrun:
torchrun --nnodes 1 --nproc-per-node 2 train.py [...]
Training can be continued from a previous checkpoint by specifying, e.g., --load_checkpoint output/test/checkpoints/checkpoint_0100000.pt.
Remove --compile True for a significantly faster startup time, at the cost of slower training and significantly increased VRAM usage.
For a full list of available arguments, refer to the train.train() method. We use fire, so every argument of the main train function is directly available as a CLI argument.
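For reference, this is the general fire pattern (a minimal sketch, not the actual train.py signature):

import fire

def train(data_tar_base: str, out_dir: str, compile: bool = False):
    # every keyword argument becomes a CLI flag, e.g.
    # python train.py --data_tar_base /path --out_dir output/test --compile True
    print(data_tar_base, out_dir, compile)

if __name__ == "__main__":
    fire.Fire(train)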
We release the weights of our open-set model via Hugging Face at https://huggingface.co/CompVis (under the CC BY-NC 4.0 license) and will potentially release further variants (scaled up or with other improvements). Due to copyright concerns surrounding the WebVid dataset, we will not distribute the weights of the model trained on it. Both models perform approximately equally (see Tab. 1 in the paper), although this will vary on a case-by-case basis due to the different training data.
- Some model code is adapted from k-diffusion by Katherine Crowson (MIT)
- The DINOv2 code is adapted from minDinoV2 by Simo Ryu, which is in turn adapted from the official implementation by Oquab et al. (Apache 2.0)
If you find our model or code useful, please cite our paper:
@inproceedings{baumann2025whatif,
title={What If: Understanding Motion Through Sparse Interactions},
author={Stefan Andreas Baumann and Nick Stracke and Timy Phan and Bj{\"o}rn Ommer},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
year={2025}
}