Skip to content

Commit

Permalink
it all clicked in my head
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 9, 2022
1 parent 577a2a6 commit 0af63d7
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 18 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ Passing in video features

```python
import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention

conv = Pseudo3DConv(
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
Expand All @@ -53,9 +53,9 @@ Passing in images (if one were to pretrain on images first), both temporal convo

```python
import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention

conv = Pseudo3DConv(
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
Expand All @@ -76,9 +76,9 @@ You can also control the two modules so that when fed 3-dimensional features, it

```python
import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention

conv = Pseudo3DConv(
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
Expand Down
5 changes: 4 additions & 1 deletion make_a_video_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from make_a_video_pytorch.make_a_video import Pseudo3DConv, SpatioTemporalAttention
from make_a_video_pytorch.make_a_video import PseudoConv3d, SpatioTemporalAttention

from make_a_video_pytorch.make_a_video import ResnetBlock, Downsample, Upsample
from make_a_video_pytorch.make_a_video import SpaceTimeUnet
208 changes: 199 additions & 9 deletions make_a_video_pytorch/make_a_video.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
from torch import nn, einsum
from einops import rearrange

from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange

# helper functions

Expand All @@ -21,7 +23,22 @@ def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * (var + eps).rsqrt() * self.g
return (x - mean) * var.clamp(min = eps).rsqrt() * self.g

# feedforward

class GEGLU(nn.Module):
    """GELU-gated linear unit.

    Splits the last dimension of the input into two equal halves and returns
    the first half multiplied by the GELU of the second half, so the output
    has half as many features as the input.
    """

    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        # `F` is not imported in this module; go through `nn.functional`,
        # which is reachable from the existing `from torch import nn`
        return x * nn.functional.gelu(gate)

def FeedForward(dim, mult = 4):
    """Build a bias-free GEGLU feedforward that maps `dim` features back to `dim`.

    Args:
        dim: number of input (and output) features on the last dimension.
        mult: expansion factor; the hidden width is `int(dim * mult * 2 / 3)`.

    Returns:
        An `nn.Sequential` of projection -> GEGLU -> projection.
    """
    inner_dim = int(dim * mult * 2 / 3)
    return nn.Sequential(
        # GEGLU splits its input in two (values and gates), so project to twice the hidden width
        nn.Linear(dim, inner_dim * 2, bias = False),
        GEGLU(),
        # nn.Linear requires both in_features and out_features; project back to `dim`
        nn.Linear(inner_dim, dim, bias = False)
    )

# helper classes

Expand Down Expand Up @@ -65,13 +82,13 @@ def forward(self, x):

# main contribution - pseudo 3d conv

class Pseudo3DConv(nn.Module):
class PseudoConv3d(nn.Module):
def __init__(
self,
dim,
*,
kernel_size,
dim_out = None,
kernel_size = 3,
*,
temporal_kernel_size = None,
**kwargs
):
Expand All @@ -80,10 +97,11 @@ def __init__(
temporal_kernel_size = default(temporal_kernel_size, kernel_size)

self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size, padding = temporal_kernel_size // 2)
self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size, padding = temporal_kernel_size // 2) if kernel_size > 1 else None

nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
nn.init.zeros_(self.temporal_conv.bias.data)
if exists(self.temporal_conv):
nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
nn.init.zeros_(self.temporal_conv.bias.data)

def forward(
self,
Expand All @@ -103,7 +121,7 @@ def forward(
if is_video:
x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

if not convolve_across_time:
if not convolve_across_time or not exists(self.temporal_conv):
return x

x = rearrange(x, 'b c f h w -> (b h w) c f')
Expand Down Expand Up @@ -160,3 +178,175 @@ def forward(
x = rearrange(x, '(b h w) f c -> b c f h w', w = w, h = h)

return x

# resnet block

class Block(nn.Module):
    """Pseudo-3d convolution, followed by group norm, optional scale / shift, and SiLU.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        kernel_size: spatial kernel size handed to `PseudoConv3d`.
        temporal_kernel_size: temporal kernel size handed to `PseudoConv3d`
            (`None` lets `PseudoConv3d` pick its own default).
        groups: number of groups for `nn.GroupNorm`.
    """

    def __init__(
        self,
        dim,
        dim_out,
        kernel_size = 3,
        temporal_kernel_size = None,
        groups = 8
    ):
        super().__init__()
        # honor the configured kernel sizes rather than a hard-coded 3
        # (defaults reproduce the previous behavior)
        self.project = PseudoConv3d(
            dim,
            dim_out,
            kernel_size,
            temporal_kernel_size = temporal_kernel_size
        )
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(
        self,
        x,
        scale_shift = None,
        convolve_across_time = False
    ):
        """Project, normalize, optionally modulate, then activate `x`.

        Args:
            x: input features, passed straight to the pseudo-3d convolution.
            scale_shift: optional `(scale, shift)` pair applied after the norm
                as `x * (scale + 1) + shift`; both must broadcast against `x`.
            convolve_across_time: forwarded to the pseudo-3d convolution.
        """
        x = self.project(x, convolve_across_time = convolve_across_time)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        return self.act(x)

class ResnetBlock(nn.Module):
    """Two `Block`s with a residual connection and optional time conditioning.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        time_cond_dim: size of the conditioning embedding; when given, an MLP
            maps it to a per-channel scale and shift for the first block.
        groups: number of groups for the blocks' group norms.
    """

    def __init__(
        self,
        dim,
        dim_out,
        *,
        time_cond_dim = None,
        groups = 8
    ):
        super().__init__()

        self.time_mlp = None

        if exists(time_cond_dim):
            # produces scale and shift concatenated along the channel dimension
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        # a 1x1 projection is only needed when the channel count changes
        self.res_conv = PseudoConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(
        self,
        x,
        time_emb = None,
        convolve_across_time = True
    ):
        """Run both blocks and add the (possibly projected) residual.

        Args:
            x: input features; 5 dimensions are treated as video
                (`b c f h w`), otherwise image features (`b c h w`).
            time_emb: optional conditioning of shape `(b, time_cond_dim)`;
                ignored when the block was built without `time_cond_dim`.
            convolve_across_time: forwarded to both blocks.
        """
        scale_shift = None

        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            # add one singleton axis per trailing feature axis so the scale / shift
            # broadcast over channels for video as well as image features
            pattern = 'b c -> b c 1 1 1' if x.ndim == 5 else 'b c -> b c 1 1'
            time_emb = rearrange(time_emb, pattern)
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift, convolve_across_time = convolve_across_time)

        h = self.block2(h, convolve_across_time = convolve_across_time)

        return h + self.res_conv(x)

# pixelshuffle upsamples and downsamples
# where time dimension can be configured

class Downsample(nn.Module):
    """Pixel-unshuffle downsampling over space, and optionally over time.

    Args:
        dim: number of channels (unchanged by the module).
        downsample_time: when true, also build the layer that halves the
            number of frames of video input.
    """

    def __init__(
        self,
        dim,
        downsample_time = False
    ):
        super().__init__()
        # fold each 2x2 spatial patch into channels, then mix back down to `dim`
        self.down_space = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
            nn.Conv2d(dim * 4, dim, 1, bias = False)
        )

        # fold each pair of consecutive frames into channels, then mix back down to `dim`
        self.down_time = nn.Sequential(
            Rearrange('b c (f p) h w -> b (c p) f h w', p = 2),
            nn.Conv3d(dim * 2, dim, 1, bias = False)
        ) if downsample_time else None

    def forward(
        self,
        x,
        downsample_time = True
    ):
        """Halve height and width, and the frame count when enabled.

        Args:
            x: image features `b c h w` or video features `b c f h w`.
            downsample_time: set to false to skip the temporal downsample for
                this call, even if the module was built with one.
        """
        is_video = x.ndim == 5

        if is_video:
            # merge batch and frames so the 2d layers see a 4d tensor
            x = rearrange(x, 'b c f h w -> b f c h w')
            x, ps = pack([x], '* c h w')

        x = self.down_space(x)

        if is_video:
            x, = unpack(x, ps, '* c h w')
            x = rearrange(x, 'b f c h w -> b c f h w')

        # temporal downsampling only applies to video, when the layer exists,
        # and when the caller has not switched it off
        if not is_video or not downsample_time or not exists(self.down_time):
            return x

        return self.down_time(x)

class Upsample(nn.Module):
    """Pixel-shuffle upsampling over space, and optionally over time.

    Args:
        dim: number of channels (unchanged by the module).
        upsample_time: when true, also build the layer that doubles the
            number of frames of video input.
    """

    def __init__(
        self,
        dim,
        upsample_time = False
    ):
        super().__init__()
        # expand channels 4x, then unfold them into a 2x2 spatial patch
        self.up_space = nn.Sequential(
            nn.Conv2d(dim, dim * 4, 1, bias = False),
            Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2)
        )

        # expand channels 2x, then unfold them into two consecutive frames
        self.up_time = nn.Sequential(
            nn.Conv3d(dim, dim * 2, 1, bias = False),
            Rearrange('b (c p) f h w -> b c (f p) h w', p = 2)
        ) if upsample_time else None

    def forward(
        self,
        x,
        upsample_time = True
    ):
        """Double height and width, and the frame count when enabled.

        Args:
            x: image features `b c h w` or video features `b c f h w`.
            upsample_time: set to false to skip the temporal upsample for
                this call, even if the module was built with one.
        """
        is_video = x.ndim == 5

        if is_video:
            # merge batch and frames so the 2d layers see a 4d tensor
            x = rearrange(x, 'b c f h w -> b f c h w')
            x, ps = pack([x], '* c h w')

        x = self.up_space(x)

        if is_video:
            x, = unpack(x, ps, '* c h w')
            x = rearrange(x, 'b f c h w -> b c f h w')

        # temporal upsampling only applies to video, when the layer exists,
        # and when the caller has not switched it off
        if not is_video or not upsample_time or not exists(self.up_time):
            return x

        return self.up_time(x)

# space time factorized 3d unet

class SpaceTimeUnet(nn.Module):
    """Placeholder for the space-time factorized 3d unet; it has no layers yet."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Not implemented yet; always raises `NotImplementedError`."""
        raise NotImplementedError
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'make-a-video-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.1',
version = '0.0.2',
license='MIT',
description = 'Make-A-Video - Pytorch',
author = 'Phil Wang',
Expand All @@ -19,7 +19,7 @@
],
install_requires=[
'dalle2-pytorch',
'einops>=0.4',
'einops>=0.6',
'torch>=1.6',
],
classifiers=[
Expand Down

0 comments on commit 0af63d7

Please sign in to comment.