Skip to content

Commit

Permalink
integrate the pseudo 3d convs into all resnet blocks within the unet3…
Browse files Browse the repository at this point in the history
…d - mashup of imagen video + make-a-video
  • Loading branch information
lucidrains committed Dec 12, 2022
1 parent 3046f1f commit 6c8236e
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 3 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,10 @@ videos.shape # (4, 3, 20, 32, 32)

```

You can also train on text-image pairs first. The `Unet3D` will automatically convert them to single-frame videos and learn without the temporal components, whether convolutions or attention.

This is the approach taken by all the big artificial intelligence labs (Brain, MetaAI, ByteDance).

## FAQ

- Why are my generated images not aligning well with the text?
Expand Down
6 changes: 5 additions & 1 deletion imagen_pytorch/elucidated_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def noise_distribution(self, P_mean, P_std, batch_size):

def forward(
self,
images,
images, # rename to images or video
unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
texts: List[str] = None,
text_embeds = None,
Expand All @@ -696,6 +696,10 @@ def forward(
cond_images = None,
**kwargs
):
if self.is_video and images.ndim == 4:
images = rearrange(images, 'b c h w -> b c 1 h w')
kwargs.update(ignore_time = True)

assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}'
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
Expand Down
6 changes: 5 additions & 1 deletion imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2459,7 +2459,7 @@ def p_losses(
@beartype
def forward(
self,
images,
images, # rename to images or video
unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
texts: List[str] = None,
text_embeds = None,
Expand All @@ -2468,6 +2468,10 @@ def forward(
cond_images = None,
**kwargs
):
if self.is_video and images.ndim == 4:
images = rearrange(images, 'b c h w -> b c 1 h w')
kwargs.update(ignore_time = True)

assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}'
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.18.1'
__version__ = '1.18.3'

0 comments on commit 6c8236e

Please sign in to comment.