lucidrains/make-a-video-pytorch

Make-A-Video - Pytorch (wip)

Implementation of Make-A-Video, the new SOTA text-to-video generator from Meta AI, in PyTorch. The authors combine pseudo-3d convolutions (axial convolutions) with temporal attention and report much better temporal fusion.

Pseudo-3d convolutions are not a new concept. They have been explored before in other contexts, for example in protein contact prediction as "dimensional hybrid residual networks".
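
To make the appeal of the factorization concrete, here is a minimal sketch that compares the parameter count of a full 3d convolution with a 2d spatial convolution plus a 1d temporal convolution at the same width. It uses plain `torch.nn` layers, not this library's `PseudoConv3d`, and the `count_params` helper is only an illustration.

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    # total number of learnable scalars in the module
    return sum(p.numel() for p in module.parameters())

dim, kernel_size = 256, 3
padding = kernel_size // 2

# a full 3d convolution mixes frames, height and width in one kernel
full_3d = nn.Conv3d(dim, dim, kernel_size, padding=padding)

# the pseudo-3d (axial) alternative: a 2d kernel over (height, width)
# followed by a 1d kernel over frames
spatial_2d = nn.Conv2d(dim, dim, kernel_size, padding=padding)
temporal_1d = nn.Conv1d(dim, dim, kernel_size, padding=padding)

full_params = count_params(full_3d)
pseudo_params = count_params(spatial_2d) + count_params(temporal_1d)

print(full_params, pseudo_params)  # prints 1769728 786944
```

At this width and kernel size the factorized pair needs well under half the parameters of the full 3d kernel.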

The gist of the paper: take a SOTA text-to-image model (here they use DALL-E2, but the same lessons would easily apply to Imagen), make a few minor modifications to add attention across time and to cut compute cost in other ways, do frame interpolation correctly, and get a great video model out.
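
As a rough illustration of "attention across time", the sketch below lets every spatial position attend over its own sequence of frames. It is not this library's `SpatioTemporalAttention`: the class name `TemporalAttentionSketch` is hypothetical, it is built on `torch.nn.MultiheadAttention`, and zeroing the output projection is just one assumed way of making the block start out as the identity.

```python
import torch
import torch.nn as nn

class TemporalAttentionSketch(nn.Module):
    """Hypothetical sketch: self-attention along the frame axis only."""

    def __init__(self, dim: int, heads: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        # assumption: zero the output projection so the residual block
        # returns its input unchanged before any video training
        nn.init.zeros_(self.attn.out_proj.weight)
        nn.init.zeros_(self.attn.out_proj.bias)

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        b, c, f, h, w = video.shape
        # (b, c, f, h, w) -> (b * h * w, f, c): one sequence of frames per pixel
        tokens = video.permute(0, 3, 4, 2, 1).reshape(b * h * w, f, c)
        normed = self.norm(tokens)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        tokens = tokens + attended  # residual connection
        # (b * h * w, f, c) -> (b, c, f, h, w)
        return tokens.reshape(b, h, w, f, c).permute(0, 4, 3, 1, 2)

video = torch.randn(2, 32, 4, 3, 3)  # (batch, features, frames, height, width)
out = TemporalAttentionSketch(dim=32, heads=4)(video)
```

Because attention runs over frames only, the cost grows with the square of the frame count rather than the square of frames times pixels.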

AI Coffee Break explanation

Appreciation

  • Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research

  • Jonathan Ho for bringing about a revolution in generative artificial intelligence through his seminal paper

  • Alex for einops, an abstraction that is simply genius. No other word for it.

Install

$ pip install make-a-video-pytorch

Usage

Passing in video features

import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention

conv = PseudoConv3d(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)

video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)

conv_out = conv(video) # (1, 256, 8, 16, 16)
attn_out = attn(video) # (1, 256, 8, 16, 16)

When images are passed in (say, if one were to pretrain on images first), both temporal convolution and temporal attention are skipped automatically. In other words, you can use these modules directly in your 2d Unet and then port them over to a 3d Unet once that phase of training is done. The temporal modules are initialized to output identity, as in the paper.

import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention

conv = PseudoConv3d(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)

images = torch.randn(1, 256, 16, 16) # (batch, features, height, width)

conv_out = conv(images) # (1, 256, 16, 16)
attn_out = attn(images) # (1, 256, 16, 16)
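
The identity initialization mentioned above can be sketched with a temporal 1d convolution whose kernel starts as a Dirac delta. This is a minimal sketch, not the library's code: the class name `TemporalConvSketch` is hypothetical, and the use of `nn.init.dirac_` with a zero bias is an assumption about how such an initialization might be done.

```python
import torch
import torch.nn as nn

class TemporalConvSketch(nn.Module):
    """Hypothetical sketch: 1d convolution over frames, identity at init."""

    def __init__(self, dim: int, kernel_size: int = 3):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size, padding=kernel_size // 2)
        # assumption: Dirac kernel and zero bias, so each channel passes through unchanged
        nn.init.dirac_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim == 4:
            # images have no frame axis, so the temporal layer is skipped
            return x
        b, c, f, h, w = x.shape
        # (b, c, f, h, w) -> (b * h * w, c, f): convolve along frames per pixel
        seq = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, f)
        seq = self.conv(seq)
        # (b * h * w, c, f) -> (b, c, f, h, w)
        return seq.reshape(b, h, w, c, f).permute(0, 3, 4, 1, 2)

temporal = TemporalConvSketch(dim=8)
video = torch.randn(2, 8, 5, 4, 4)   # (batch, features, frames, height, width)
images = torch.randn(2, 8, 4, 4)     # (batch, features, height, width)

video_out = temporal(video)
images_out = temporal(images)
```

With this starting point, a network pretrained on images behaves identically on each video frame at first, and training on video only has to learn the temporal deviation.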

You can also control the two modules so that, when fed 3-dimensional (video) features, they train only spatially

import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention

conv = PseudoConv3d(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)

video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)

# below it will not train across time

conv_out = conv(video, enable_time = False) # (1, 256, 8, 16, 16)
attn_out = attn(video, enable_time = False) # (1, 256, 8, 16, 16)
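
One way to picture spatial-only processing of a video is to fold the frame axis into the batch axis, run an ordinary 2d module, and unfold again. The helper below, `apply_per_frame`, is a hypothetical illustration under that assumption and not a description of how `enable_time = False` is implemented in this library.

```python
import torch
import torch.nn as nn

def apply_per_frame(module_2d: nn.Module, video: torch.Tensor) -> torch.Tensor:
    """Apply a 2d module to every frame independently (hypothetical helper)."""
    b, c, f, h, w = video.shape
    # (b, c, f, h, w) -> (b * f, c, h, w): frames become extra batch entries
    frames = video.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
    out = module_2d(frames)
    _, c_out, h_out, w_out = out.shape
    # (b * f, c_out, h_out, w_out) -> (b, c_out, f, h_out, w_out)
    return out.reshape(b, f, c_out, h_out, w_out).permute(0, 2, 1, 3, 4)

conv_2d = nn.Conv2d(4, 6, kernel_size=3, padding=1)
clip = torch.randn(2, 4, 3, 8, 8)  # (batch, features, frames, height, width)

per_frame_out = apply_per_frame(conv_2d, clip)  # shape (2, 6, 3, 8, 8)
```

No information flows between frames in this path, so each output frame depends only on the matching input frame.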

Full SpaceTimeUnet that is agnostic to image or video training, and where time can be ignored even if video is passed in

import torch
from make_a_video_pytorch import SpaceTimeUnet

unet = SpaceTimeUnet(
    dim = 64,
    channels = 3,
    dim_mult = (1, 2, 4, 8),
    temporal_compression = (False, False, False, True),
    self_attns = (False, False, True, True)
)

video = torch.randn(1, 3, 16, 256, 256) # (batch, channels, frames, height, width)
pred = unet(video)

assert video.shape == pred.shape

unet(video, enable_time = False) # treat all frames of video as images
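
To reason about what the `dim_mult` and `temporal_compression` tuples above imply, the following sketch tabulates feature shapes per downsampling stage. It rests on assumptions that are not confirmed by this README: each stage halves height and width, halves the frame count only where its flag is `True`, and outputs `dim * mult` channels. The helper `stage_shapes` is hypothetical and does not call the library.

```python
def stage_shapes(dim, dim_mult, temporal_compression, frames, height, width):
    """Return (channels, frames, height, width) after each assumed downsampling stage."""
    shapes = []
    for mult, compress_time in zip(dim_mult, temporal_compression):
        # assumption: every stage halves the spatial resolution
        height, width = height // 2, width // 2
        # assumption: frames are halved only where the flag is True
        if compress_time:
            frames //= 2
        shapes.append((dim * mult, frames, height, width))
    return shapes

shapes = stage_shapes(
    dim=64,
    dim_mult=(1, 2, 4, 8),
    temporal_compression=(False, False, False, True),
    frames=16,
    height=256,
    width=256,
)

for stage, shape in enumerate(shapes, start=1):
    print(stage, shape)
```

Under these assumptions, only the deepest stage sees a shorter clip, which keeps temporal detail in the shallow layers while saving compute where the channel count is largest.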

Todo

  • wire up dalle2-pytorch unet with pseudo 3d convs + spatial temporal attention
  • give attention the best positional embeddings research has to offer
  • soup up the attention
  • offer a function, similar to MosaicML's approach, that automatically rigs a 2d-unet from dalle2-pytorch to be 3d
  • consider learned exponential moving average across time from https://github.com/lucidrains/Mega-pytorch

Citations

@misc{Singer2022,
    author  = {Uriel Singer},
    url     = {https://makeavideo.studio/Make-A-Video.pdf}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}