2 skip connections per stage, and also make sure properly init upsample pixelshuffle
lucidrains committed Dec 11, 2022
1 parent 58f4172 commit e634856
Showing 2 changed files with 40 additions and 9 deletions.
47 changes: 39 additions & 8 deletions make_a_video_pytorch/make_a_video.py
@@ -5,7 +5,7 @@
 import torch
 from torch import nn, einsum
 
-from einops import rearrange, pack, unpack
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
 # helper functions
@@ -293,19 +293,22 @@ def __init__(
         self,
         dim,
         downsample_space = True,
-        downsample_time = False
+        downsample_time = False,
+        nonlin = False
     ):
         super().__init__()
         assert downsample_space or downsample_time
 
         self.down_space = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
-            nn.Conv2d(dim * 4, dim, 1, bias = False)
+            nn.Conv2d(dim * 4, dim, 1, bias = False),
+            nn.SiLU() if nonlin else nn.Identity()
         ) if downsample_space else None
 
         self.down_time = nn.Sequential(
             Rearrange('b c (f p) h w -> b (c p) f h w', p = 2),
-            nn.Conv3d(dim * 2, dim, 1, bias = False)
+            nn.Conv3d(dim * 2, dim, 1, bias = False),
+            nn.SiLU() if nonlin else nn.Identity()
         ) if downsample_time else None
 
     def forward(
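
For reference, the spatial Rearrange in down_space is the same space-to-depth operation as torch.nn.PixelUnshuffle(2), and the Rearrange in up_space in the next hunk is its inverse, torch.nn.PixelShuffle(2). A minimal standalone sanity-check sketch (not part of the commit; it assumes only torch and einops, with a made-up tensor shape):

import torch
from torch import nn
from einops import rearrange

x = torch.randn(1, 3, 8, 8)

# space-to-depth, as in Downsample.down_space
down = rearrange(x, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2)
assert torch.equal(down, nn.PixelUnshuffle(2)(x))

# depth-to-space, as in Upsample.up_space, recovers the input exactly
up = rearrange(down, 'b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2)
assert torch.equal(up, nn.PixelShuffle(2)(down))
assert torch.equal(up, x)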
@@ -338,21 +341,42 @@ def __init__(
         self,
         dim,
         upsample_space = True,
-        upsample_time = False
+        upsample_time = False,
+        nonlin = False
     ):
         super().__init__()
         assert upsample_space or upsample_time
 
         self.up_space = nn.Sequential(
-            nn.Conv2d(dim, dim * 4, 1, bias = False),
+            nn.Conv2d(dim, dim * 4, 1),
+            nn.SiLU() if nonlin else nn.Identity(),
             Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2)
         ) if upsample_space else None
 
         self.up_time = nn.Sequential(
-            nn.Conv3d(dim, dim * 2, 1, bias = False),
+            nn.Conv3d(dim, dim * 2, 1),
+            nn.SiLU() if nonlin else nn.Identity(),
             Rearrange('b (c p) f h w -> b c (f p) h w', p = 2)
         ) if upsample_time else None
 
+        self.init_()
+
+    def init_(self):
+        if exists(self.up_space):
+            self.init_conv_(self.up_space[0], 4)
+
+        if exists(self.up_time):
+            self.init_conv_(self.up_time[0], 2)
+
+    def init_conv_(self, conv, factor):
+        o, *remain_dims = conv.weight.shape
+        conv_weight = torch.empty(o // factor, *remain_dims)
+        nn.init.kaiming_uniform_(conv_weight)
+        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = factor)
+
+        conv.weight.data.copy_(conv_weight)
+        nn.init.zeros_(conv.bias.data)
+
     def forward(
         self,
         x,
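
On the "properly init upsample pixelshuffle" part of this commit: init_conv_ draws a kernel with 1/factor of the output channels and tiles it along the output dimension, so every group of sub-pixel channels starts out identical and the pixel-shuffle upsample initially behaves like a nearest-neighbour upsample (the ICNR-style trick against checkerboard artifacts). A minimal standalone sketch of that property, assuming only torch and einops and re-declaring the conv and rearrange from the diff for illustration:

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat

dim = 8
conv = nn.Conv2d(dim, dim * 4, 1)

# same scheme as init_conv_ above: draw a kernel with dim output channels,
# then tile it 4x along the output dimension and zero the bias
o, *remain_dims = conv.weight.shape
base = torch.empty(o // 4, *remain_dims)
nn.init.kaiming_uniform_(base)
conv.weight.data.copy_(repeat(base, 'o ... -> (o r) ...', r = 4))
nn.init.zeros_(conv.bias.data)

x = torch.randn(1, dim, 16, 16)
up = rearrange(conv(x), 'b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2)

# at initialization the pixel-shuffled output equals a nearest-neighbour
# upsampling of the un-tiled 1x1 convolution, i.e. no checkerboard pattern
ref = F.interpolate(F.conv2d(x, base), scale_factor = 2, mode = 'nearest')
assert torch.allclose(up, ref, atol = 1e-6)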
@@ -444,7 +468,7 @@ def __init__(
 
             self.ups.append(mlist([
                 ResnetBlock(dim_out * 2, dim_in, timestep_cond_dim = timestep_cond_dim),
-                ResnetBlock(dim_in, dim_in),
+                ResnetBlock(dim_in + dim_out, dim_in),
                 SpatioTemporalAttention(dim = dim_in, **attn_kwargs) if self_attend else None,
                 Upsample(dim_out, upsample_time = compress_time)
 
@@ -484,6 +508,9 @@ def forward(
 
         for block1, block2, maybe_attention, downsample in self.downs:
             x = block1(x, t, enable_time = enable_time)
+
+            hiddens.append(x.clone())
+
             x = block2(x, enable_time = enable_time)
 
             if exists(maybe_attention):
@@ -499,9 +526,13 @@
 
         for block1, block2, maybe_attention, upsample in reversed(self.ups):
             x = upsample(x, enable_time = enable_time)
 
             x = torch.cat((hiddens.pop() * self.skip_scale, x), dim = 1)
+
             x = block1(x, t, enable_time = enable_time)
+
+            x = torch.cat((hiddens.pop() * self.skip_scale, x), dim = 1)
+
             x = block2(x, enable_time = enable_time)
 
             if exists(maybe_attention):
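
On the "2 skip connections per stage" part: the down path now pushes a hidden right after block1 in addition to the one pushed later in the stage, and the up path pops once before block1 and once before block2. Both hiddens at a stage carry dim_out channels (as the new up-block widths imply: dim_out * 2 into block1 and dim_in + dim_out into block2), which is why block2 becomes ResnetBlock(dim_in + dim_out, dim_in). A toy sketch of the channel bookkeeping, with made-up stage widths and the down-path widths assumed from those implied values:

stages = [(64, 128), (128, 256)]   # (dim_in, dim_out) per resolution, illustrative only
hiddens = []

# down path: two skips pushed per stage, both with dim_out channels
for dim_in, dim_out in stages:
    hiddens.append(dim_out)        # after block1 (assumed to map dim_in -> dim_out)
    hiddens.append(dim_out)        # later in the stage (still dim_out channels)

# up path: pop in reverse (LIFO) order, one concat before each block
for dim_in, dim_out in reversed(stages):
    c = dim_out                    # Upsample(dim_out) keeps dim_out channels
    c = hiddens.pop() + c          # concat -> dim_out * 2, block1's input width
    c = dim_in                     # block1: ResnetBlock(dim_out * 2, dim_in)
    c = hiddens.pop() + c          # concat -> dim_in + dim_out, hence block2's new width
    c = dim_in                     # block2: ResnetBlock(dim_in + dim_out, dim_in)

assert not hiddens                 # every pushed skip is consumed exactly once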
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'make-a-video-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'Make-A-Video - Pytorch',
   author = 'Phil Wang',
