Implementation of Make-A-Video, the new SOTA text-to-video generator from Meta AI, in PyTorch. They combine pseudo-3D convolutions (axial convolutions) with temporal attention and show much better temporal fusion.
Pseudo-3D convolutions are not a new concept. They have been explored before in other contexts, for example in protein contact prediction as "dimensional hybrid residual networks".
The gist of the paper comes down to this: take a SOTA text-to-image model (here they use DALL-E 2, but the same learnings would apply just as well to Imagen), make a few minor modifications for attention across time and other ways to skimp on compute, do frame interpolation correctly, and you get a great video model.
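To make the idea concrete, a pseudo-3D (axial) convolution factorizes a full 3D convolution into a cheap 2D convolution over space followed by a 1D convolution over time. Below is a minimal sketch of that factorization in plain PyTorch; the class name and details are illustrative only, not the library's actual Pseudo3DConv implementation.

import torch
from torch import nn

class FactorizedSpaceTimeConv(nn.Module):
    # illustrative sketch of a pseudo-3d (axial) convolution over video
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        pad = kernel_size // 2
        # 2d convolution over (height, width)
        self.spatial_conv = nn.Conv2d(dim, dim, kernel_size, padding = pad)
        # 1d convolution over frames, initialized as an identity so an
        # image-pretrained network is preserved at the start of video training
        self.temporal_conv = nn.Conv1d(dim, dim, kernel_size, padding = pad)
        nn.init.dirac_(self.temporal_conv.weight.data)
        nn.init.zeros_(self.temporal_conv.bias.data)

    def forward(self, x):
        b, c, f, h, w = x.shape                  # (batch, features, frames, height, width)
        # fold frames into the batch and convolve spatially
        x = x.transpose(1, 2).reshape(b * f, c, h, w)
        x = self.spatial_conv(x)
        # fold spatial positions into the batch and convolve temporally
        x = x.reshape(b, f, c, h, w).permute(0, 3, 4, 2, 1).reshape(b * h * w, c, f)
        x = self.temporal_conv(x)
        # restore (batch, features, frames, height, width)
        return x.reshape(b, h, w, c, f).permute(0, 3, 4, 1, 2)

video = torch.randn(1, 256, 8, 16, 16)
assert FactorizedSpaceTimeConv(256)(video).shape == video.shape

Since the temporal convolution starts out as an identity, the module initially behaves exactly like a per-frame image model, which is what makes the image-pretraining route shown further below viable.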
$ pip install make-a-video-pytorch
Passing in video features
import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention
conv = Pseudo3DConv(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)
video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)
conv_out = conv(video) # (1, 256, 8, 16, 16)
attn_out = attn(video) # (1, 256, 8, 16, 16)
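For intuition, spatio-temporal attention over a tensor shaped like the one above is typically factorized: attend over the height × width positions within each frame, then over the frames at each spatial position. The sketch below illustrates that pattern in plain PyTorch; it is only an approximation of the idea, not the library's SpatioTemporalAttention implementation.

import torch
from torch import nn

class FactorizedSpaceTimeAttention(nn.Module):
    # illustrative sketch: spatial attention per frame, then temporal attention per position
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.spatial_attn = nn.MultiheadAttention(dim, heads, batch_first = True)
        self.temporal_attn = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, x):
        b, c, f, h, w = x.shape                  # (batch, features, frames, height, width)
        # spatial attention: tokens are the h * w positions within each frame
        tokens = x.permute(0, 2, 3, 4, 1).reshape(b * f, h * w, c)
        tokens = tokens + self.spatial_attn(tokens, tokens, tokens)[0]
        # temporal attention: tokens are the f frames at each spatial position
        tokens = tokens.reshape(b, f, h * w, c).transpose(1, 2).reshape(b * h * w, f, c)
        tokens = tokens + self.temporal_attn(tokens, tokens, tokens)[0]
        # back to (batch, features, frames, height, width)
        return tokens.reshape(b, h * w, f, c).permute(0, 3, 2, 1).reshape(b, c, f, h, w)

attn = FactorizedSpaceTimeAttention(256)
assert attn(torch.randn(1, 256, 8, 16, 16)).shape == (1, 256, 8, 16, 16)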
Passing in images (if one were to pretrain on images first, both the temporal convolution and the temporal attention are automatically skipped)
import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention
conv = Pseudo3DConv(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)
images = torch.randn(1, 256, 16, 16) # (batch, features, height, width)
conv_out = conv(images) # (1, 256, 16, 16)
attn_out = attn(images) # (1, 256, 16, 16)
You can also control the two modules so that, even when fed video (spatio-temporal) features, training only happens across the spatial dimensions
import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention
conv = Pseudo3DConv(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)
video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)
# below it will not train across time
conv_out = conv(video, convolve_across_time = False) # (1, 256, 8, 16, 16)
attn_out = attn(video, attend_across_time = False) # (1, 256, 8, 16, 16)
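As an example of how one might use these flags, here is a hypothetical training sketch reusing conv, attn, and video from above; the optimizer, loss, and switch point are placeholders, not anything prescribed by the paper or this library.

# hypothetical two-phase schedule: train spatially only at first,
# then unlock the temporal convolution and attention
optimizer = torch.optim.Adam([*conv.parameters(), *attn.parameters()], lr = 3e-4)

for step in range(20):
    train_time = step >= 10                               # placeholder switch point

    out = conv(video, convolve_across_time = train_time)
    out = attn(out, attend_across_time = train_time)

    loss = out.mean()                                     # placeholder loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()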
- wire up the dalle2-pytorch unet with pseudo-3d convs + spatio-temporal attention
- give attention the best positional embeddings research has to offer (see the sketch after this list for one candidate)
- soup up the attention
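One candidate for the positional-embedding item above is rotary embeddings on the frame axis. The sketch below is purely hypothetical: the class name and structure are assumptions for illustration, and pulling in the rotary-embedding-torch package is a choice this repository has not committed to.

import torch
from torch import nn
from rotary_embedding_torch import RotaryEmbedding   # assumed extra dependency

class TemporalAttentionWithRoPE(nn.Module):
    # hypothetical sketch of temporal attention with rotary positional embeddings
    def __init__(self, dim, dim_head = 64, heads = 8):
        super().__init__()
        inner = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner * 3, bias = False)
        self.to_out = nn.Linear(inner, dim, bias = False)
        self.rotary = RotaryEmbedding(dim = dim_head)

    def forward(self, tokens):                    # tokens: (batch, frames, dim)
        b, f, _ = tokens.shape
        q, k, v = self.to_qkv(tokens).chunk(3, dim = -1)
        # split heads -> (batch, heads, frames, dim_head)
        q, k, v = (t.reshape(b, f, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        # rotary embeddings encode relative frame positions directly in q and k
        q, k = self.rotary.rotate_queries_or_keys(q), self.rotary.rotate_queries_or_keys(k)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim = -1)
        out = (attn @ v).transpose(1, 2).reshape(b, f, -1)
        return self.to_out(out)

temporal_attn = TemporalAttentionWithRoPE(256)
frames = torch.randn(2, 8, 256)                   # (batch, frames, dim) at one spatial position
assert temporal_attn(frames).shape == (2, 8, 256)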
@misc{Singer2022,
    title  = {Make-A-Video: Text-to-Video Generation without Text-Video Data},
    author = {Uriel Singer},
    year   = {2022},
    url    = {https://makeavideo.studio/Make-A-Video.pdf}
}