From 67565c1d63da5d700407521c17779f3a48c56ae5 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 10 Dec 2022 10:02:56 -0800 Subject: [PATCH] skip connections --- make_a_video_pytorch/make_a_video.py | 15 +++++++++++---- setup.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/make_a_video_pytorch/make_a_video.py b/make_a_video_pytorch/make_a_video.py index 3cb8501..7cd0998 100644 --- a/make_a_video_pytorch/make_a_video.py +++ b/make_a_video_pytorch/make_a_video.py @@ -400,13 +400,15 @@ def __init__( ])) self.ups.append(mlist([ - ResnetBlock(dim_out, dim_in), + ResnetBlock(dim_out * 2, dim_in), ResnetBlock(dim_in, dim_in), SpatioTemporalAttention(dim = dim_in, **attn_kwargs) if self_attend else None, - Upsample(dim_in, upsample_time = compress_time) + Upsample(dim_out, upsample_time = compress_time) ])) + self.skip_scale = 2 ** -0.5 # paper shows faster convergence + self.conv_in = PseudoConv3d(dim = channels, dim_out = dim, kernel_size = 7, temporal_kernel_size = 3) self.conv_out = PseudoConv3d(dim = dim, dim_out = channels, kernel_size = 3, temporal_kernel_size = 3) @@ -417,6 +419,8 @@ def forward( ): x = self.conv_in(x, enable_time = enable_time) + hiddens = [] + for block1, block2, maybe_attention, downsample in self.downs: x = block1(x, enable_time = enable_time) x = block2(x, enable_time = enable_time) @@ -424,6 +428,8 @@ def forward( if exists(maybe_attention): x = maybe_attention(x, enable_time = enable_time) + hiddens.append(x.clone()) + x = downsample(x, enable_time = enable_time) x = self.mid_block1(x, enable_time = enable_time) @@ -431,13 +437,14 @@ def forward( x = self.mid_block2(x, enable_time = enable_time) for block1, block2, maybe_attention, upsample in reversed(self.ups): + x = upsample(x, enable_time = enable_time) + x = torch.cat((hiddens.pop() * self.skip_scale, x), dim = 1) + x = block1(x, enable_time = enable_time) x = block2(x, enable_time = enable_time) if exists(maybe_attention): x = maybe_attention(x, enable_time = enable_time) - x = upsample(x, enable_time = enable_time) - x = self.conv_out(x, enable_time = enable_time) return x diff --git a/setup.py b/setup.py index d924776..cb75ea8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'make-a-video-pytorch', packages = find_packages(exclude=[]), - version = '0.0.3', + version = '0.0.4', license='MIT', description = 'Make-A-Video - Pytorch', author = 'Phil Wang',