Skip to content

Commit

Permalink
some asserts
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 10, 2022
1 parent 8b7be7f commit 58f4172
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
18 changes: 18 additions & 0 deletions make_a_video_pytorch/make_a_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,11 @@ def __init__(
dims = [dim, *map(lambda mult: mult * dim, dim_mult)]
dim_in_out = zip(dims[:-1], dims[1:])

# determine the valid multiples of the image size and frames of the video

self.frame_multiple = 2 ** sum(tuple(map(int, temporal_compression)))
self.image_size_multiple = 2 ** num_layers

# timestep conditioning for DDPM, not to be confused with the time dimension of the video

self.to_timestep_cond = None
Expand Down Expand Up @@ -456,7 +461,20 @@ def forward(
timestep = None,
enable_time = True
):

# some asserts

assert not (exists(self.to_timestep_cond) ^ exists(timestep))
is_video = x.ndim == 5

if enable_time and is_video:
frames = x.shape[2]
assert divisible_by(frames, self.frame_multiple), f'number of frames on the video ({frames}) must be divisible by the frame multiple ({self.frame_multiple})'

height, width = x.shape[-2:]
assert divisible_by(height, self.image_size_multiple) and divisible_by(width, self.image_size_multiple), f'height and width of the image or video must be a multiple of {self.image_size_multiple}'

# main logic

t = self.to_timestep_cond(rearrange(timestep, '... -> (...)')) if exists(timestep) else None

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'make-a-video-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.5',
version = '0.0.6',
license='MIT',
description = 'Make-A-Video - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 58f4172

Please sign in to comment.