diff --git a/README.md b/README.md index 53df659..96b50d6 100644 --- a/README.md +++ b/README.md @@ -616,6 +616,10 @@ videos.shape # (4, 3, 20, 32, 32) ``` +You can also train on text-image pairs first. The Unet3D will automatically convert them to single-frame videos and learn without the temporal components, whether convolutions or attention. + +This is the approach taken by all the big artificial intelligence labs (Brain, MetaAI, Bytedance) + ## FAQ - Why are my generated images not aligning well with the text? diff --git a/imagen_pytorch/elucidated_imagen.py b/imagen_pytorch/elucidated_imagen.py index 9d69e12..ece363b 100644 --- a/imagen_pytorch/elucidated_imagen.py +++ b/imagen_pytorch/elucidated_imagen.py @@ -687,7 +687,7 @@ def noise_distribution(self, P_mean, P_std, batch_size): def forward( self, - images, + images, # rename to images or video unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None, texts: List[str] = None, text_embeds = None, @@ -696,6 +696,10 @@ def forward( cond_images = None, **kwargs ): + if self.is_video and images.ndim == 4: + images = rearrange(images, 'b c h w -> b c 1 h w') + kwargs.update(ignore_time = True) + assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}' assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' unet_number = default(unet_number, 1) diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py index f69cb70..810f9ea 100644 --- a/imagen_pytorch/imagen_pytorch.py +++ b/imagen_pytorch/imagen_pytorch.py @@ -2459,7 +2459,7 @@ def p_losses( @beartype def forward( self, - images, + images, # rename to images or video unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None, texts: List[str] = None, text_embeds = None, @@ -2468,6 
+2468,10 @@ def forward( cond_images = None, **kwargs ): + if self.is_video and images.ndim == 4: + images = rearrange(images, 'b c h w -> b c 1 h w') + kwargs.update(ignore_time = True) + assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}' assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' unet_number = default(unet_number, 1) diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py index acc3cb8..5f7025b 100644 --- a/imagen_pytorch/version.py +++ b/imagen_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.18.1' +__version__ = '1.18.3'