A lean, optimized implementation of spatiotemporal masked autoencoders (ST-MAEs). The skeleton of the code is adapted from Facebook's ST-MAE repository, with various simplifications. The following optimizations are implemented:
- FlashAttention-2
- `torch.compile`
- fused AdamW
- mixed precision training (`torch.cuda.amp`)
- DDP for distributed training
- selective decoding of videos
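Selective decoding presumably refers to decoding only the frames that actually end up in a sampled clip, rather than the whole video. The sketch below computes those frame indices, assuming frames are taken at a fixed stride of `sampling_rate` from a random start (matching the `--num_frames` and `--sampling_rate` flags in the training commands). `clip_frame_indices` is a hypothetical helper, not a function from this repository, and the fallback for short videos is our own assumption.

```python
import random


def clip_frame_indices(total_frames, num_frames, sampling_rate, rng):
    """Hypothetical helper: indices of the only frames that need decoding for one clip."""
    if total_frames <= 0:
        raise ValueError("video has no frames")
    span = num_frames * sampling_rate  # number of frames covered by one clip
    if total_frames >= span:
        # Random start, then every `sampling_rate`-th frame
        start = rng.randint(0, total_frames - span)
        return [start + i * sampling_rate for i in range(num_frames)]
    # Assumed fallback for short videos: spread the indices evenly over the video
    step = total_frames / num_frames
    return [min(int(i * step), total_frames - 1) for i in range(num_frames)]


rng = random.Random(0)
idx = clip_frame_indices(300, num_frames=16, sampling_rate=8, rng=rng)
print(len(idx), idx[1] - idx[0])  # 16 8
```

A decoder that supports random frame access can then be asked for just these 16 frames instead of all 128 frames spanned by the clip.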
These optimizations allow us to achieve a very high training throughput. For example, on just 4 H100 GPUs, we were able to complete over 160 epochs of training on Kinetics-700 (~536K videos) in roughly 1 week, with a ViT-H encoder (~633M parameters), 8x16x16=2048 spatiotemporal input "tokens" (i.e. 8 tokens in the temporal dimension and 16x16 tokens in the spatial dimensions), a masking ratio of 90%, and an effective batch size of 256 videos (64 videos on each GPU).
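The token count follows from the pretraining flags: 16 frames with a temporal patch size of 2 give 8 temporal tokens, and a 224x224 frame with 14x14 patches gives 16x16 spatial tokens. At a 90% masking ratio, the encoder only processes about a tenth of them. The pure-Python sketch below illustrates MAE-style per-sample random masking (argsort of uniform noise, keep the lowest-noise tokens); `num_tokens` and `random_masking` are hypothetical names for illustration, not this repository's tensor implementation.

```python
import random


def num_tokens(num_frames=16, t_patch_size=2, img_size=224, patch_size=14):
    # 16 // 2 = 8 temporal tokens; 224 // 14 = 16 tokens per spatial side
    return (num_frames // t_patch_size) * (img_size // patch_size) ** 2


def random_masking(n, mask_ratio, rng):
    """MAE-style random masking for one sample: returns (visible token ids, binary mask)."""
    len_keep = int(n * (1 - mask_ratio))  # number of visible tokens fed to the encoder
    noise = [rng.random() for _ in range(n)]
    # "Argsort" the noise; the len_keep tokens with the smallest noise stay visible
    keep = sorted(sorted(range(n), key=noise.__getitem__)[:len_keep])
    mask = [1] * n  # 1 = masked (to be reconstructed by the decoder), 0 = visible
    for i in keep:
        mask[i] = 0
    return keep, mask


n = num_tokens()
keep, mask = random_masking(n, 0.9, random.Random(0))
print(n, len(keep), sum(mask))  # 2048 204 1844
```

Because the encoder sees only the 204 visible tokens rather than all 2048, most of the ViT-H compute per clip is avoided.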
Dependence of model definitions on the `timm` library is also removed in this implementation, so the code is self-contained apart from standard dependencies such as PyTorch and torchvision. The code was tested with `pytorch==2.2.0` and `torchvision==0.17.0`.
- Training: To train a spatiotemporal MAE model with a ViT-H/14 architecture from scratch on your data, use `pretrain.py`, e.g.:
```bash
python -u pretrain.py \
--data_dirs DATA_DIRS \
--datafile_dir DATAFILE_DIR \
--save_prefix INFORMATIVE_SAVE_PREFIX \
--output_dir OUTPUT_DIR \
--model 'mae_vit_huge_patch14' \
--batch_size_per_gpu 1 \
--accum_iter 1 \
--epochs 100000 \
--num_frames 16 \
--img_size 224 \
--decoder_embed_dim 512 \
--decoder_depth 4 \
--pin_mem \
--t_patch_size 2 \
--repeat_aug 16 \
--sampling_rate 8 \
--lr 0.0001 \
--weight_decay 0.05 \
--mask_ratio 0.9 \
--pred_t_dim 16 \
--clip_grad 0.1
```
Here, `DATA_DIRS` is a list of directories containing the video files, `DATAFILE_DIR` is the directory where a `.csv` file listing all the training video file paths (optionally with the corresponding class labels) will be saved, and `OUTPUT_DIR` is the directory where the checkpoints and training logs will be saved.
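To make the role of `DATAFILE_DIR` concrete, here is a minimal sketch of building such a `.csv` datafile from a list of video directories. The file name `train.csv`, the `(path, label)` column layout, the accepted video extensions, and taking the label from the parent folder name are all assumptions; the file that `pretrain.py` actually writes may be laid out differently, and `write_datafile` is a hypothetical helper.

```python
import csv
import os

VIDEO_EXTS = (".mp4", ".avi", ".mkv", ".webm")  # assumed set of video extensions


def write_datafile(data_dirs, datafile_dir, name="train.csv"):
    """Hypothetical sketch: collect every video under data_dirs into one CSV file.

    Each row is (path, label), with the label taken from the parent folder name.
    That only makes sense for class-per-folder layouts and is an assumption.
    """
    os.makedirs(datafile_dir, exist_ok=True)
    out_path = os.path.join(datafile_dir, name)
    with open(out_path, "w", newline="") as f:
        writer = csv.writer(f)
        for data_dir in data_dirs:
            for root, _, files in os.walk(data_dir):
                for fname in sorted(files):
                    if fname.lower().endswith(VIDEO_EXTS):
                        label = os.path.basename(root)
                        writer.writerow([os.path.join(root, fname), label])
    return out_path
```

Writing the list once and reading it back at startup avoids re-scanning large video directories on every run.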
- Finetuning on videos: To finetune a ViT-H/14 model on a downstream video recognition task, use `finetune.py`, e.g.:
```bash
python -u finetune.py \
--train_dir TRAIN_DIR \
--val_dir VAL_DIR \
--datafile_dir DATAFILE_DIR \
--save_prefix INFORMATIVE_SAVE_PREFIX \
--output_dir OUTPUT_DIR \
--finetune SPATIOTEMPORAL_MAE_CHECKPOINT \
--num_classes 174 \
--model 'vit_huge_patch14' \
--batch_size_per_gpu 4 \
--accum_iter 1 \
--epochs 100000 \
--num_frames 16 \
--input_size 224 \
--pin_mem \
--t_patch_size 2 \
--repeat_aug 1 \
--sampling_rate 8 \
--blr 0.0024 \
--clip_grad 5.0 \
--mixup 0 \
--cutmix 0.0
```
Here, `TRAIN_DIR` and `VAL_DIR` are the directories containing the training and validation videos, respectively, and `SPATIOTEMPORAL_MAE_CHECKPOINT` is the path to the pretrained spatiotemporal MAE checkpoint used to initialize the model (pass `""` here to finetune the model from scratch without any pretraining).
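The finetuning commands set `--blr` rather than `--lr`. In the original MAE code base, `blr` is a base learning rate that is scaled linearly with the effective batch size, `lr = blr * eff_batch_size / 256`. Assuming this repository follows the same convention, the sketch below computes the resulting learning rate; `absolute_lr` is a hypothetical helper, counting `repeat_aug` toward the effective batch is a further assumption, and the 4-GPU setup is only an example.

```python
def absolute_lr(blr, batch_size_per_gpu, accum_iter, world_size, repeat_aug=1):
    """Scale a base LR by the effective batch size (assumed MAE-style linear scaling)."""
    # Samples contributing to one optimizer step across all GPUs
    eff_batch_size = batch_size_per_gpu * accum_iter * world_size * repeat_aug
    return blr * eff_batch_size / 256


# The finetune.py flags above, run on a hypothetical 4-GPU machine
lr = absolute_lr(0.0024, batch_size_per_gpu=4, accum_iter=1, world_size=4)
print(f"{lr:.2e}")  # 1.50e-04
```

Under this rule, changing the number of GPUs or `--batch_size_per_gpu` rescales the learning rate automatically, so `--blr` can stay fixed.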
- Finetuning on images: To finetune a ViT-H/14 model on a downstream image recognition task (e.g. ImageNet), use `finetune_on_image.py`, e.g.:
```bash
python -u finetune_on_image.py \
--train_data_path TRAIN_DATA_PATH \
--val_data_path VAL_DATA_PATH \
--save_prefix INFORMATIVE_SAVE_PREFIX \
--output_dir OUTPUT_DIR \
--finetune SPATIOTEMPORAL_MAE_CHECKPOINT \
--num_classes 1000 \
--model 'vit_huge_patch14' \
--batch_size_per_gpu 4 \
--accum_iter 1 \
--epochs 100000 \
--num_frames 16 \
--input_size 224 \
--pin_mem \
--t_patch_size 2 \
--blr 0.0024 \
--clip_grad 5.0 \
--mixup 0 \
--cutmix 0.0
```

Here, `TRAIN_DATA_PATH` and `VAL_DATA_PATH` are the directories containing the training and validation images, respectively, and `SPATIOTEMPORAL_MAE_CHECKPOINT` is the path to the pretrained spatiotemporal MAE checkpoint used to initialize the model. The script effectively turns each image into a static video clip by repeating it `num_frames` (here 16) times. This lets us use the pretrained spatiotemporal MAE model as is, without any modifications to the architecture.
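The image-to-clip conversion can be sketched as follows, assuming a channels-first C x H x W image that becomes a C x T x H x W clip. Nested lists keep the example dependency-free; `image_to_static_clip` is a hypothetical helper, and the PyTorch one-liner in the docstring is an equivalent formulation, not necessarily what `finetune_on_image.py` does.

```python
def image_to_static_clip(image, num_frames=16):
    """Turn a C x H x W image (nested lists) into a C x T x H x W static clip.

    Each of the num_frames time steps holds an independent copy of the image.
    With a PyTorch tensor the same result is image.unsqueeze(1).repeat(1, num_frames, 1, 1).
    """
    return [
        [[list(row) for row in channel] for _ in range(num_frames)]
        for channel in image
    ]


image = [[[0.1, 0.2], [0.3, 0.4]] for _ in range(3)]  # toy 3 x 2 x 2 "image"
clip = image_to_static_clip(image, num_frames=16)
print(len(clip), len(clip[0]), len(clip[0][0]), len(clip[0][0][0]))  # 3 16 2 2
```

With `--t_patch_size 2`, every temporal patch then contains two identical frames, so the 8 temporal token positions all carry the same image content.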