Skip to content

Commit

Permalink
up or down sampling in space can also be turned off
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 9, 2022
1 parent 3f4071d commit c8792d0
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions make_a_video_pytorch/make_a_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,16 @@ class Downsample(nn.Module):
def __init__(
self,
dim,
downsample_space = True,
downsample_time = False
):
super().__init__()
assert downsample_space or downsample_time

self.down_space = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
nn.Conv2d(dim * 4, dim, 1, bias = False)
)
) if downsample_space else None

self.down_time = nn.Sequential(
Rearrange('b c (f p) h w -> b (c p) f h w', p = 2),
Expand All @@ -283,7 +286,8 @@ def forward(
x = rearrange(x, 'b c f h w -> b f c h w')
x, ps = pack([x], '* c h w')

x = self.down_space(x)
if exists(self.down_space):
x = self.down_space(x)

if is_video:
x, = unpack(x, ps, '* c h w')
Expand All @@ -300,13 +304,16 @@ class Upsample(nn.Module):
def __init__(
self,
dim,
upsample_space = True,
upsample_time = False
):
super().__init__()
assert upsample_space or upsample_time

self.up_space = nn.Sequential(
nn.Conv2d(dim, dim * 4, 1, bias = False),
Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2)
)
) if upsample_space else None

self.up_time = nn.Sequential(
nn.Conv3d(dim, dim * 2, 1, bias = False),
Expand All @@ -324,7 +331,8 @@ def forward(
x = rearrange(x, 'b c f h w -> b f c h w')
x, ps = pack([x], '* c h w')

x = self.up_space(x)
if exists(self.up_space):
x = self.up_space(x)

if is_video:
x, = unpack(x, ps, '* c h w')
Expand Down

0 comments on commit c8792d0

Please sign in to comment.