Skip to content

Commit

Permalink
skip connections
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 10, 2022
1 parent a4c37ee commit 67565c1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
15 changes: 11 additions & 4 deletions make_a_video_pytorch/make_a_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,13 +400,15 @@ def __init__(
]))

self.ups.append(mlist([
ResnetBlock(dim_out, dim_in),
ResnetBlock(dim_out * 2, dim_in),
ResnetBlock(dim_in, dim_in),
SpatioTemporalAttention(dim = dim_in, **attn_kwargs) if self_attend else None,
Upsample(dim_in, upsample_time = compress_time)
Upsample(dim_out, upsample_time = compress_time)

]))

self.skip_scale = 2 ** -0.5 # paper shows faster convergence

self.conv_in = PseudoConv3d(dim = channels, dim_out = dim, kernel_size = 7, temporal_kernel_size = 3)
self.conv_out = PseudoConv3d(dim = dim, dim_out = channels, kernel_size = 3, temporal_kernel_size = 3)

Expand All @@ -417,27 +419,32 @@ def forward(
):
x = self.conv_in(x, enable_time = enable_time)

hiddens = []

for block1, block2, maybe_attention, downsample in self.downs:
x = block1(x, enable_time = enable_time)
x = block2(x, enable_time = enable_time)

if exists(maybe_attention):
x = maybe_attention(x, enable_time = enable_time)

hiddens.append(x.clone())

x = downsample(x, enable_time = enable_time)

x = self.mid_block1(x, enable_time = enable_time)
x = self.mid_attn(x, enable_time = enable_time)
x = self.mid_block2(x, enable_time = enable_time)

for block1, block2, maybe_attention, upsample in reversed(self.ups):
x = upsample(x, enable_time = enable_time)
x = torch.cat((hiddens.pop() * self.skip_scale, x), dim = 1)

x = block1(x, enable_time = enable_time)
x = block2(x, enable_time = enable_time)

if exists(maybe_attention):
x = maybe_attention(x, enable_time = enable_time)

x = upsample(x, enable_time = enable_time)

x = self.conv_out(x, enable_time = enable_time)
return x
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'make-a-video-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.3',
version = '0.0.4',
license='MIT',
description = 'Make-A-Video - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 67565c1

Please sign in to comment.