
An optimized implementation of spatiotemporal masked autoencoders


eminorhan/optimized-stmae


Optimized Spatiotemporal Masked Autoencoders (ST-MAEs)

A lean, optimized implementation of spatiotemporal masked autoencoders (ST-MAEs). The code skeleton is adapted from Facebook's ST-MAE repository, with various simplifications. The following optimizations are implemented:

  • FlashAttention-2
  • torch.compile
  • fused AdamW
  • mixed precision training (torch.cuda.amp)
  • DDP for distributed training
  • selective decoding of videos

These optimizations give a very high training throughput. For example, on only 4 H100 GPUs we completed over 160 epochs of training on Kinetics-700 (~536K videos) in roughly one week, with a ViT-H encoder (~633M parameters), 8x16x16 = 2048 spatiotemporal input "tokens" (8 tokens in the temporal dimension and 16x16 tokens in the spatial dimensions), a masking ratio of 90%, and an effective batch size of 256 videos (64 videos on each GPU).
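The token count and batch size quoted above follow from the --num_frames 16, --t_patch_size 2 and --img_size 224 flags used in the usage examples, the 14-pixel patches of ViT-H/14, and the GPU count. The short Python sketch reproduces that arithmetic and also shows how few tokens the encoder actually processes at a 90% masking ratio. The helper name st_mae_token_budget is hypothetical, and computing the visible count as int(N * (1 - mask_ratio)) is an assumption borrowed from the reference MAE code, not something this README states.

```python
def st_mae_token_budget(num_frames=16, t_patch_size=2, img_size=224, patch_size=14,
                        mask_ratio=0.9, num_gpus=4, batch_size_per_gpu=64, accum_iter=1):
    """Hypothetical helper: token grid, visible tokens and effective batch size."""
    t_tokens = num_frames // t_patch_size   # 16 frames / 2 = 8 temporal tokens
    s_tokens = img_size // patch_size       # 224 / 14 = 16 tokens per spatial side
    total = t_tokens * s_tokens * s_tokens  # 8 x 16 x 16 tokens per clip
    # Assumption: visible count computed as in the reference MAE code.
    visible = int(total * (1 - mask_ratio))
    effective_batch = num_gpus * batch_size_per_gpu * accum_iter
    return {"temporal": t_tokens, "spatial_side": s_tokens, "total": total,
            "visible": visible, "effective_batch": effective_batch}


print(st_mae_token_budget())
# {'temporal': 8, 'spatial_side': 16, 'total': 2048, 'visible': 204, 'effective_batch': 256}
```

With 90% masking, the encoder sees only about 204 of the 2048 tokens per clip, which is a large part of why MAE-style pretraining of a ViT-H is affordable on a few GPUs.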

The model definitions also no longer depend on the timm library, so the code is self-contained apart from standard libraries. The code was tested with pytorch==2.2.0 and torchvision==0.17.0.
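For readers unfamiliar with these optimizations, here is a minimal sketch of how several of them can fit together in one PyTorch training step. It is not the repository's code: the Attention module, train_step, the dummy loss and the choice of bfloat16 autocast are illustrative assumptions. Attention goes through torch.nn.functional.scaled_dot_product_attention, which PyTorch 2.2 can dispatch to a FlashAttention-2 kernel on supported GPUs; torch.compile and fused AdamW are enabled only when CUDA is available, since fused AdamW is CUDA-only in PyTorch 2.2. The lr and weight decay values are taken from the pretraining example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Multi-head self-attention built on PyTorch's fused SDPA kernel (illustrative)."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # On supported GPUs with fp16/bf16 inputs, PyTorch 2.2+ can dispatch this
        # call to a FlashAttention-2 kernel; otherwise it uses another backend.
        out = F.scaled_dot_product_attention(q, k, v)
        return self.proj(out.transpose(1, 2).reshape(B, N, C))


def train_step(model, optimizer, x, device_type):
    # Mixed precision via autocast. Assumption: bfloat16; float16 would also
    # need a gradient scaler.
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        loss = model(x).float().pow(2).mean()  # dummy loss, for illustration only
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return loss.item()


device = "cuda" if torch.cuda.is_available() else "cpu"
model = Attention(dim=64, num_heads=4).to(device)
if device == "cuda":
    model = torch.compile(model)  # compile only on GPU to keep the CPU fallback simple
# fused=True selects the single-kernel AdamW implementation (CUDA-only in PyTorch 2.2).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05,
                              fused=(device == "cuda"))
x = torch.randn(2, 16, 64, device=device)
loss = train_step(model, optimizer, x, device)
```

DDP wrapping is omitted to keep the sketch single-process; in a multi-GPU run the model would additionally be wrapped in torch.nn.parallel.DistributedDataParallel.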

Usage examples

  • Training: To train a spatiotemporal MAE model with a ViT-H/14 architecture from scratch on your data, use pretrain.py, e.g.:
python -u pretrain.py \
    --data_dirs DATA_DIRS \
    --datafile_dir DATAFILE_DIR \
    --save_prefix INFORMATIVE_SAVE_PREFIX \
    --output_dir OUTPUT_DIR \
    --model 'mae_vit_huge_patch14' \
    --batch_size_per_gpu 1 \
    --accum_iter 1 \
    --epochs 100000 \
    --num_frames 16 \
    --img_size 224 \
    --decoder_embed_dim 512 \
    --decoder_depth 4 \
    --pin_mem \
    --t_patch_size 2 \
    --repeat_aug 16 \
    --sampling_rate 8 \
    --lr 0.0001 \
    --weight_decay 0.05 \
    --mask_ratio 0.9 \
    --pred_t_dim 16 \
    --clip_grad 0.1

Here, DATA_DIRS is a list of directories containing the video files, DATAFILE_DIR is the directory where a .csv file listing the paths of all training videos (optionally with their class labels) will be saved, and OUTPUT_DIR is the directory where checkpoints and training logs will be saved.
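The exact .csv layout the repository writes is not documented here, so the following is only a hypothetical sketch of what building such a datafile from DATA_DIRS can look like. The helper write_datafile, the accepted extensions, the output name train.csv, and deriving each label from the parent folder name are all assumptions.

```python
import csv
import os
from pathlib import Path

# Assumption: which file extensions count as videos.
VIDEO_EXTS = {".mp4", ".avi", ".webm", ".mkv"}


def write_datafile(data_dirs, datafile_dir, filename="train.csv", with_labels=True):
    """Hypothetical helper: write one CSV row per video found under data_dirs."""
    rows = []
    for d in data_dirs:
        for path in sorted(Path(d).rglob("*")):
            if path.is_file() and path.suffix.lower() in VIDEO_EXTS:
                # Assumption: the class label is the name of the video's parent folder.
                rows.append([str(path), path.parent.name] if with_labels else [str(path)])
    os.makedirs(datafile_dir, exist_ok=True)
    out_path = os.path.join(datafile_dir, filename)
    with open(out_path, "w", newline="") as f:
        csv.writer(f).writerows(rows)
    return out_path, len(rows)
```

Writing the file list once means later runs can read a single file instead of rescanning every data directory.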

  • Finetuning on videos: To finetune a ViT-H/14 model on a downstream video recognition task, use finetune.py, e.g.:
python -u finetune.py \
    --train_dir TRAIN_DIR \
    --val_dir VAL_DIR \
    --datafile_dir DATAFILE_DIR \
    --save_prefix INFORMATIVE_SAVE_PREFIX \
    --output_dir OUTPUT_DIR \
    --finetune SPATIOTEMPORAL_MAE_CHECKPOINT \
    --num_classes 174 \
    --model 'vit_huge_patch14' \
    --batch_size_per_gpu 4 \
    --accum_iter 1 \
    --epochs 100000 \
    --num_frames 16 \
    --input_size 224 \
    --pin_mem \
    --t_patch_size 2 \
    --repeat_aug 1 \
    --sampling_rate 8 \
    --blr 0.0024 \
    --clip_grad 5.0 \
    --mixup 0 \
    --cutmix 0.0

Here, TRAIN_DIR and VAL_DIR are the directories containing the training and validation videos, respectively, and SPATIOTEMPORAL_MAE_CHECKPOINT is the path to the pretrained spatiotemporal MAE checkpoint the model is initialized with (pass "" instead to train the model from scratch without any pretraining).
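One common way to initialize a classifier from an MAE checkpoint is to copy the encoder weights and drop everything used only during pretraining (the decoder and the mask token). The sketch is a hypothetical illustration of that idea, not the repository's loader: the key prefixes, the optional "model" nesting in the checkpoint, and the helper names are assumptions. An empty path skips loading, mirroring the from-scratch option described above.

```python
import torch
import torch.nn as nn

# Assumption: hypothetical key prefixes for pretraining-only parameters.
PRETRAIN_ONLY_PREFIXES = ("decoder", "mask_token")


def encoder_state_from_mae(mae_state: dict) -> dict:
    """Drop decoder-side entries so only encoder weights remain."""
    return {k: v for k, v in mae_state.items() if not k.startswith(PRETRAIN_ONLY_PREFIXES)}


def init_from_checkpoint(model: nn.Module, ckpt_path: str) -> nn.Module:
    """Hypothetical helper: load encoder weights into model, or do nothing for ""."""
    if not ckpt_path:  # empty string: keep the random init, i.e. train from scratch
        return model
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state = ckpt.get("model", ckpt)  # assumption: weights may be nested under "model"
    result = model.load_state_dict(encoder_state_from_mae(state), strict=False)
    # A new classification head has no pretrained counterpart, so it is reported as missing.
    print("missing keys:", result.missing_keys)
    print("unexpected keys:", result.unexpected_keys)
    return model
```

Loading with strict=False lets the randomly initialized head show up as missing keys instead of raising an error; a real loader may also need to adapt position embeddings if the finetuning input size differs from pretraining.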

  • Finetuning on images: To finetune a ViT-H/14 model on a downstream image recognition task (e.g. ImageNet), use finetune_on_image.py, e.g.:
python -u finetune_on_image.py \
    --train_data_path TRAIN_DATA_PATH \
    --val_data_path VAL_TRAIN_DATA_PATH \
    --save_prefix INFORMATIVE_SAVE_PREFIX \
    --output_dir OUTPUT_DIR \
    --finetune SPATIOTEMPORAL_MAE_CHECKPOINT \
    --num_classes 1000 \
    --model 'vit_huge_patch14' \
    --batch_size_per_gpu 4 \
    --accum_iter 1 \
    --epochs 100000 \
    --num_frames 16 \
    --input_size 224 \
    --pin_mem \
    --t_patch_size 2 \
    --blr 0.0024 \
    --clip_grad 5.0 \
    --mixup 0 \
    --cutmix 0.0

Here, TRAIN_DATA_PATH and VAL_TRAIN_DATA_PATH are the directories containing the training and validation images, respectively, and SPATIOTEMPORAL_MAE_CHECKPOINT is the path to the pretrained spatiotemporal MAE checkpoint the model is initialized with. The script turns each image into a static video clip by repeating it 16 times (num_frames), which lets the pretrained spatiotemporal MAE model be used as-is, without any changes to its architecture.
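The image-to-clip conversion can be sketched in a few lines of PyTorch. This assumes a (B, C, T, H, W) video layout and a hypothetical helper name, images_to_static_clips; the script's actual implementation may differ. With --num_frames 16 and --t_patch_size 2, such a clip yields the same 8x16x16 token grid the model saw during pretraining.

```python
import torch


def images_to_static_clips(images: torch.Tensor, num_frames: int = 16) -> torch.Tensor:
    """Hypothetical helper: (B, C, H, W) images -> (B, C, T, H, W) clips of identical frames."""
    # expand() returns a broadcast view, so no frame data is copied;
    # call .contiguous() if a later op needs a dense tensor.
    return images.unsqueeze(2).expand(-1, -1, num_frames, -1, -1)


imgs = torch.randn(2, 3, 224, 224)
clips = images_to_static_clips(imgs)
print(clips.shape)  # torch.Size([2, 3, 16, 224, 224])
```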
