a full space time unet where time can be ignored
lucidrains committed Dec 10, 2022
1 parent c8792d0 commit a4c37ee
Showing 3 changed files with 127 additions and 21 deletions.
27 changes: 25 additions & 2 deletions README.md
@@ -93,8 +93,31 @@ video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)

# below it will not train across time

- conv_out = conv(video, convolve_across_time = False) # (1, 256, 8, 16, 16)
- attn_out = attn(video, attend_across_time = False) # (1, 256, 8, 16, 16)
+ conv_out = conv(video, enable_time = False) # (1, 256, 8, 16, 16)
+ attn_out = attn(video, enable_time = False) # (1, 256, 8, 16, 16)
```
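
With `enable_time = False`, each frame is processed independently, exactly as if the frames were a batch of images. A minimal sketch verifying that, reusing the `conv` and `video` defined above (the equivalence check is an illustration, not part of the library):

```python
from einops import rearrange

# folding time into batch and convolving without time should match
frames = rearrange(video, 'b c f h w -> (b f) c h w')

out_video = conv(video, enable_time = False)
out_frames = conv(frames)

assert torch.allclose(out_video, rearrange(out_frames, '(b f) c h w -> b c f h w', b = video.shape[0]), atol = 1e-6)
```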

The full `SpaceTimeUnet` is agnostic to training on images or video, and time can be ignored even when video is passed in:

```python
import torch
from make_a_video_pytorch import SpaceTimeUnet

unet = SpaceTimeUnet(
    dim = 64,
    channels = 3,
    dim_mult = (1, 2, 4, 8),
    temporal_compression = (False, False, False, True),
    self_attns = (False, False, True, True)
)

video = torch.randn(1, 3, 16, 256, 256) # (batch, channels, frames, height, width)
pred = unet(video)

assert video.shape == pred.shape

unet(video, enable_time = False) # treat all frames of video as images
```
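
Since a 4d tensor is treated as a batch of images, the same `unet` should accept plain images as well (a sketch under that assumption, reusing the `unet` above):

```python
images = torch.randn(1, 3, 256, 256) # (batch, channels, height, width)

pred = unet(images) # temporal layers are skipped automatically
assert pred.shape == images.shape
```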

## Todo
119 changes: 101 additions & 18 deletions make_a_video_pytorch/make_a_video.py
@@ -1,4 +1,6 @@
import torch
import functools
from operator import mul
from torch import nn, einsum

from einops import rearrange, pack, unpack
@@ -12,6 +14,14 @@ def exists(val):
def default(val, d):
    return val if exists(val) else d

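# product of all elements in a tuple, e.g. mul_reduce((2, 3, 4)) == 24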
def mul_reduce(tup):
    return functools.reduce(mul, tup)

def divisible_by(numer, denom):
    return (numer % denom) == 0

mlist = nn.ModuleList

# layernorm 3d

class LayerNorm(nn.Module):
@@ -106,12 +116,12 @@ def __init__(
    def forward(
        self,
        x,
-        convolve_across_time = True
+        enable_time = True
    ):
        b, c, *_, h, w = x.shape

        is_video = x.ndim == 5
-        convolve_across_time &= is_video
+        enable_time &= is_video

        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) c h w')

@@ -121,7 +131,7 @@ def forward(
        if is_video:
            x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

-        if not convolve_across_time or not exists(self.temporal_conv):
+        if not enable_time or not exists(self.temporal_conv):
            return x

        x = rearrange(x, 'b c f h w -> (b h w) c f')
@@ -150,11 +160,11 @@ def __init__(
    def forward(
        self,
        x,
-        attend_across_time = True
+        enable_time = True
    ):
        b, c, *_, h, w = x.shape
        is_video = x.ndim == 5
-        attend_across_time &= is_video
+        enable_time &= is_video

        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) (h w) c')
@@ -168,7 +178,7 @@ def forward(
        else:
            x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)

-        if not attend_across_time:
+        if not enable_time:
            return x

        x = rearrange(x, 'b c f h w -> (b h w) f c')
@@ -199,9 +209,9 @@ def forward(
        self,
        x,
        scale_shift = None,
-        convolve_across_time = False
+        enable_time = False
    ):
-        x = self.project(x, convolve_across_time = convolve_across_time)
+        x = self.project(x, enable_time = enable_time)
        x = self.norm(x)

        if exists(scale_shift):
@@ -237,7 +247,7 @@ def forward(
        self,
        x,
        time_emb = None,
-        convolve_across_time = True
+        enable_time = True
    ):

        scale_shift = None
@@ -246,9 +256,9 @@ def forward(
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

-        h = self.block1(x, scale_shift = scale_shift, convolve_across_time = convolve_across_time)
+        h = self.block1(x, scale_shift = scale_shift, enable_time = enable_time)

-        h = self.block2(h, convolve_across_time = convolve_across_time)
+        h = self.block2(h, enable_time = enable_time)

        return h + self.res_conv(x)

@@ -278,7 +288,7 @@ def __init__(
    def forward(
        self,
        x,
-        downsample_time = True
+        enable_time = True
    ):
        is_video = x.ndim == 5

@@ -293,7 +303,7 @@ def forward(
            x, = unpack(x, ps, '* c h w')
            x = rearrange(x, 'b f c h w -> b c f h w')

-        if not is_video or not exists(self.down_time):
+        if not is_video or not exists(self.down_time) or not enable_time:
            return x

        x = self.down_time(x)
@@ -323,7 +333,7 @@ def __init__(
    def forward(
        self,
        x,
-        upsample_time = True
+        enable_time = True
    ):
        is_video = x.ndim == 5

@@ -338,7 +348,7 @@ def forward(
            x, = unpack(x, ps, '* c h w')
            x = rearrange(x, 'b f c h w -> b c f h w')

-        if not is_video or not exists(self.up_time):
+        if not is_video or not exists(self.up_time) or not enable_time:
            return x

        x = self.up_time(x)
@@ -349,12 +359,85 @@

class SpaceTimeUnet(nn.Module):
    def __init__(
-        self
+        self,
+        *,
+        dim,
+        channels = 3,
+        dim_mult = (1, 2, 4, 8),
+        self_attns = (False, False, False, True),
+        temporal_compression = (False, True, True, True),
+        attn_dim_head = 64,
+        attn_heads = 8
    ):
        super().__init__()
        assert len(dim_mult) == len(self_attns) == len(temporal_compression)
        num_layers = len(dim_mult)

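        # channel dimensions at each resolution, e.g. dim = 64, dim_mult = (1, 2, 4, 8)
        # gives dims = [64, 64, 128, 256, 512] and in/out pairs (64, 64), (64, 128), (128, 256), (256, 512)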
        dims = [dim, *map(lambda mult: mult * dim, dim_mult)]
        dim_in_out = zip(dims[:-1], dims[1:])

        self.downs = mlist([])
        self.ups = mlist([])

        attn_kwargs = dict(
            dim_head = attn_dim_head,
            heads = attn_heads
        )

        mid_dim = dims[-1]

        self.mid_block1 = ResnetBlock(mid_dim, mid_dim)
        self.mid_attn = SpatioTemporalAttention(dim = mid_dim)
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim)

        for _, self_attend, (dim_in, dim_out), compress_time in zip(range(num_layers), self_attns, dim_in_out, temporal_compression):

            self.downs.append(mlist([
                ResnetBlock(dim_in, dim_out),
                ResnetBlock(dim_out, dim_out),
                SpatioTemporalAttention(dim = dim_out, **attn_kwargs) if self_attend else None,
                Downsample(dim_out, downsample_time = compress_time)
            ]))

            self.ups.append(mlist([
                ResnetBlock(dim_out, dim_in),
                ResnetBlock(dim_in, dim_in),
                SpatioTemporalAttention(dim = dim_in, **attn_kwargs) if self_attend else None,
                Upsample(dim_in, upsample_time = compress_time)
            ]))

        self.conv_in = PseudoConv3d(dim = channels, dim_out = dim, kernel_size = 7, temporal_kernel_size = 3)
        self.conv_out = PseudoConv3d(dim = dim, dim_out = channels, kernel_size = 3, temporal_kernel_size = 3)

    def forward(
        self,
-        x
+        x,
+        enable_time = True
    ):
-        raise NotImplementedError
+        x = self.conv_in(x, enable_time = enable_time)

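        # downsampling stages: resnet blocks, optional attention, then spatial (and maybe temporal) downsample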
        for block1, block2, maybe_attention, downsample in self.downs:
            x = block1(x, enable_time = enable_time)
            x = block2(x, enable_time = enable_time)

            if exists(maybe_attention):
                x = maybe_attention(x, enable_time = enable_time)

            x = downsample(x, enable_time = enable_time)

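        # middle resnet blocks with full spatio-temporal attention in between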
        x = self.mid_block1(x, enable_time = enable_time)
        x = self.mid_attn(x, enable_time = enable_time)
        x = self.mid_block2(x, enable_time = enable_time)

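        # upsampling stages, mirroring the downsampling path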
        for block1, block2, maybe_attention, upsample in reversed(self.ups):
            x = block1(x, enable_time = enable_time)
            x = block2(x, enable_time = enable_time)

            if exists(maybe_attention):
                x = maybe_attention(x, enable_time = enable_time)

            x = upsample(x, enable_time = enable_time)

        x = self.conv_out(x, enable_time = enable_time)
        return x
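
For reference, the space/time factorization that `PseudoConv3d` implements above can be sketched standalone: a 2d convolution over space with time folded into the batch, then a 1d convolution over time with space folded into the batch. An illustrative sketch, not the module itself:

```python
import torch
from torch import nn
from einops import rearrange

spatial_conv = nn.Conv2d(256, 256, 3, padding = 1)  # acts on (h, w)
temporal_conv = nn.Conv1d(256, 256, 3, padding = 1) # acts on f

video = torch.randn(1, 256, 8, 16, 16) # (batch, channels, frames, height, width)
b, *_, h, w = video.shape

x = rearrange(video, 'b c f h w -> (b f) c h w') # fold time into batch
x = spatial_conv(x)
x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

x = rearrange(x, 'b c f h w -> (b h w) c f')     # fold space into batch
x = temporal_conv(x)
x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)

assert x.shape == video.shape
```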
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'make-a-video-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.0.2',
+  version = '0.0.3',
  license='MIT',
  description = 'Make-A-Video - Pytorch',
  author = 'Phil Wang',