From c8792d08bc37381d35f3c986fdd74c6f565890fd Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 9 Dec 2022 14:49:23 -0800 Subject: [PATCH] up or down sampling in space can also be turned off --- make_a_video_pytorch/make_a_video.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/make_a_video_pytorch/make_a_video.py b/make_a_video_pytorch/make_a_video.py index 96d1cd6..d3226d1 100644 --- a/make_a_video_pytorch/make_a_video.py +++ b/make_a_video_pytorch/make_a_video.py @@ -259,13 +259,16 @@ class Downsample(nn.Module): def __init__( self, dim, + downsample_space = True, downsample_time = False ): super().__init__() + assert downsample_space or downsample_time + self.down_space = nn.Sequential( Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2), nn.Conv2d(dim * 4, dim, 1, bias = False) - ) + ) if downsample_space else None self.down_time = nn.Sequential( Rearrange('b c (f p) h w -> b (c p) f h w', p = 2), @@ -283,7 +286,8 @@ def forward( x = rearrange(x, 'b c f h w -> b f c h w') x, ps = pack([x], '* c h w') - x = self.down_space(x) + if exists(self.down_space): + x = self.down_space(x) if is_video: x, = unpack(x, ps, '* c h w') @@ -300,13 +304,16 @@ class Upsample(nn.Module): def __init__( self, dim, + upsample_space = True, upsample_time = False ): super().__init__() + assert upsample_space or upsample_time + self.up_space = nn.Sequential( nn.Conv2d(dim, dim * 4, 1, bias = False), Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2) - ) + ) if upsample_space else None self.up_time = nn.Sequential( nn.Conv3d(dim, dim * 2, 1, bias = False), @@ -324,7 +331,8 @@ def forward( x = rearrange(x, 'b c f h w -> b f c h w') x, ps = pack([x], '* c h w') - x = self.up_space(x) + if exists(self.up_space): + x = self.up_space(x) if is_video: x, = unpack(x, ps, '* c h w')