Make-A-Video - Pytorch (wip)

Implementation of Make-A-Video, the new state-of-the-art (SOTA) text-to-video generator from Meta AI, in PyTorch. The authors combine pseudo-3D convolutions (axial convolutions) with temporal attention and show much better temporal fusion.
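To make the pseudo-3D factorization concrete, here is a minimal sketch. It assumes the common decomposition into a 2D convolution over (height, width) applied to every frame, followed by a 1D convolution over frames applied at every pixel. `Pseudo3DConvSketch` and `num_params` are hypothetical names; this is not the `Pseudo3DConv` class from this package, whose internals (initialization, for example) may differ.

```python
import torch
from torch import nn


class Pseudo3DConvSketch(nn.Module):
    # hypothetical illustration, not this repository's Pseudo3DConv
    def __init__(self, dim: int, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2  # keeps sizes unchanged for odd kernels
        self.spatial_conv = nn.Conv2d(dim, dim, kernel_size, padding=padding)
        self.temporal_conv = nn.Conv1d(dim, dim, kernel_size, padding=padding)

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        b, c, f, h, w = video.shape  # (batch, features, frames, height, width)

        # spatial pass: fold frames into the batch, 2D conv on every frame
        x = video.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
        x = self.spatial_conv(x).reshape(b, f, c, h, w)

        # temporal pass: fold pixels into the batch, 1D conv along the frames
        x = x.permute(0, 3, 4, 2, 1).reshape(b * h * w, c, f)
        x = self.temporal_conv(x)

        # (b * h * w, c, f) -> (b, c, f, h, w)
        return x.reshape(b, h, w, c, f).permute(0, 3, 4, 1, 2)


def num_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


conv = Pseudo3DConvSketch(dim=256, kernel_size=3)
out = conv(torch.randn(1, 256, 8, 16, 16))
print(out.shape)  # torch.Size([1, 256, 8, 16, 16])

# weights of the factorized conv vs. a full 3x3x3 Conv3d of the same width
print(num_params(conv))  # 786944
print(num_params(nn.Conv3d(256, 256, 3, padding=1)))  # 1769728
```

In this sketch the factorized layer has fewer than half the weights of a full 3D kernel, and its spatial half is an ordinary `Conv2d`, the same kind of layer a text-to-image U-Net already contains.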

Pseudo-3D convolution is not a new concept. It has been explored before in other contexts, for example in protein contact prediction as "dimensional hybrid residual networks".

The gist of the paper: take a SOTA text-to-image model (here they use DALL-E 2, but the same lessons would apply just as easily to Imagen), make a few minor modifications for attention across time along with other ways to cut the compute cost, do frame interpolation correctly, and you get a great video model.
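One way to see why attending across time separately saves compute is to count the query-key scores per attention head. This is a back-of-the-envelope sketch, assuming the alternative is full attention over all frames × height × width tokens, and that the factorized version runs spatial attention within each frame followed by temporal attention at each pixel. The function names are hypothetical, and the count ignores projections and feed-forward layers.

```python
def full_attention_scores(frames: int, height: int, width: int) -> int:
    # one joint sequence of frames * height * width tokens
    tokens = frames * height * width
    return tokens * tokens


def factorized_attention_scores(frames: int, height: int, width: int) -> int:
    pixels = height * width
    spatial = frames * pixels * pixels  # one (pixels x pixels) map per frame
    temporal = pixels * frames * frames  # one (frames x frames) map per pixel
    return spatial + temporal


full = full_attention_scores(8, 16, 16)
factorized = factorized_attention_scores(8, 16, 16)
print(full, factorized)  # 4194304 540672
print(round(full / factorized, 2))  # 7.76
```

The gap widens with resolution: full attention grows with (f·h·w)², while the factorized version grows with f·(h·w)² + h·w·f².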

AI Coffee Break explanation

Install

$ pip install make-a-video

Usage

Passing in video features

import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention

conv = Pseudo3DConv(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)

video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)

conv_out = conv(video) # (1, 256, 8, 16, 16)
attn_out = attn(video) # (1, 256, 8, 16, 16)

Passing in images (if you pretrain on images first, both the temporal convolution and the temporal attention are skipped automatically)

import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention

conv = Pseudo3DConv(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)

images = torch.randn(1, 256, 16, 16) # (batch, features, height, width)

conv_out = conv(images) # (1, 256, 16, 16)
attn_out = attn(images) # (1, 256, 16, 16)
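A minimal sketch of how this automatic skipping can work, assuming the module simply dispatches on the number of input dimensions: a 4D image tensor goes through the spatial convolution only, while a 5D video tensor also gets the temporal convolution. `RankAwareConvSketch` is a hypothetical name, not this package's implementation.

```python
import torch
from torch import nn


class RankAwareConvSketch(nn.Module):
    # hypothetical illustration, not this repository's Pseudo3DConv
    def __init__(self, dim: int, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2
        self.spatial_conv = nn.Conv2d(dim, dim, kernel_size, padding=padding)
        self.temporal_conv = nn.Conv1d(dim, dim, kernel_size, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim == 4:
            # images (batch, features, height, width): no time axis, spatial conv only
            return self.spatial_conv(x)

        b, c, f, h, w = x.shape
        # video: frames folded into the batch so the same 2D weights are reused
        x = self.spatial_conv(x.transpose(1, 2).reshape(b * f, c, h, w))
        # pixels folded into the batch for the 1D conv along the frames
        x = x.reshape(b, f, c, h, w).permute(0, 3, 4, 2, 1).reshape(b * h * w, c, f)
        x = self.temporal_conv(x)
        return x.reshape(b, h, w, c, f).permute(0, 3, 4, 1, 2)


conv = RankAwareConvSketch(dim=32)
video = torch.randn(2, 32, 8, 16, 16)
images = torch.randn(2, 32, 16, 16)

print(conv(video).shape)  # torch.Size([2, 32, 8, 16, 16])

image_out = conv(images)
print(image_out.shape)  # torch.Size([2, 32, 16, 16])

# an image batch never reaches the temporal conv, so its weights get no gradient
image_out.sum().backward()
print(conv.temporal_conv.weight.grad is None)  # True
```

Under this design, image pretraining only updates the spatial half; the temporal half can be trained later once video data is in the mix.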

You can also configure the two modules so that, when fed video (frames × height × width) features, they operate only spatially and skip the temporal convolution and attention

import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention

conv = Pseudo3DConv(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)

video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)

# with the flags below, neither module mixes information across frames

conv_out = conv(video, convolve_across_time = False) # (1, 256, 8, 16, 16)
attn_out = attn(video, attend_across_time = False) # (1, 256, 8, 16, 16)
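A minimal sketch of what `attend_across_time = False` could mean for the attention module, assuming factorized space-time attention: spatial attention among the pixels of each frame, then temporal attention among the frames at each pixel, with the flag switching off the second step. It uses `torch.nn.MultiheadAttention` in place of this package's attention (so `dim_head` is not modeled; each head gets `dim / heads` features), and `FactorizedAttentionSketch` is a hypothetical name. The check at the end perturbs frame 0 and compares frame 1's output with and without the temporal step.

```python
import torch
from torch import nn


class FactorizedAttentionSketch(nn.Module):
    # hypothetical illustration, not this repository's SpatioTemporalAttention
    def __init__(self, dim: int, heads: int = 8):
        super().__init__()
        self.spatial_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.temporal_attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, video: torch.Tensor, attend_across_time: bool = True) -> torch.Tensor:
        b, c, f, h, w = video.shape

        # spatial attention: each frame is a sequence of h * w pixel tokens
        x = video.permute(0, 2, 3, 4, 1).reshape(b * f, h * w, c)
        x, _ = self.spatial_attn(x, x, x, need_weights=False)
        x = x.reshape(b, f, h, w, c)

        if attend_across_time:
            # temporal attention: each pixel is a sequence of f frame tokens
            x = x.permute(0, 2, 3, 1, 4).reshape(b * h * w, f, c)
            x, _ = self.temporal_attn(x, x, x, need_weights=False)
            x = x.reshape(b, h, w, f, c).permute(0, 3, 1, 2, 4)  # (b, f, h, w, c)

        return x.permute(0, 4, 1, 2, 3)  # (batch, features, frames, height, width)


torch.manual_seed(0)
attn = FactorizedAttentionSketch(dim=64, heads=8)
video = torch.randn(1, 64, 4, 8, 8)
perturbed = video.clone()
perturbed[:, :, 0] += 1.0  # change frame 0 only

with torch.no_grad():
    for across_time in (False, True):
        before = attn(video, attend_across_time=across_time)[:, :, 1]
        after = attn(perturbed, attend_across_time=across_time)[:, :, 1]
        # frame 1 stays unchanged only when the temporal step is switched off
        print(across_time, torch.allclose(before, after, atol=1e-5))
```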

Todo

  • wire up dalle2-pytorch unet with pseudo 3d convs + spatial temporal attention
  • give attention the best positional embeddings research has to offer
  • soup up the attention

Citations

@misc{Singer2022,
    author  = {Uriel Singer},
    url     = {https://makeavideo.studio/Make-A-Video.pdf}
}