diff --git a/README.md b/README.md
index 0182b33..a93e219 100644
--- a/README.md
+++ b/README.md
@@ -523,6 +523,24 @@ inpainted_images = trainer.sample(texts = [
inpainted_images # (4, 3, 512, 512)
```
+For video, similarly pass your videos to the `inpaint_images` keyword argument on `.sample`. For now, `inpaint_masks` still has to be a single mask shared across all frames.
+
+```python
+inpaint_videos = torch.randn(4, 3, 8, 512, 512).cuda() # (batch, channels, frames, height, width)
+inpaint_masks = torch.ones((4, 512, 512)).bool().cuda() # (batch, height, width)
+
+inpainted_videos = trainer.sample(texts = [
+ 'a whale breaching from afar',
+ 'young girl blowing out candles on her birthday cake',
+ 'fireworks with blue and green sparkles',
+ 'dust motes swirling in the morning sunshine on the windowsill'
+], inpaint_images = inpaint_videos, inpaint_masks = inpaint_masks, cond_scale = 5.)
+
+inpainted_videos # (4, 3, 8, 512, 512)
+```
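+
+Internally, a mask passed in as `(batch, height, width)` is given a singleton frame dimension and broadcast across all frames of the video; per-frame masks are noted in the todo below.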
+
## Experimental
Tero Karras of StyleGAN fame has written a new paper with results that have been corroborated by a number of independent researchers as well as on my own machine. I have decided to create a version of `Imagen`, the `ElucidatedImagen`, so that one can use the new elucidated DDPM for text-guided cascading generation.
@@ -705,7 +722,9 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
- [x] make sure one can customize all interpolation modes (some researchers are finding better results with trilinear)
- [x] imagen-video : allow for conditioning on preceding (and possibly future) frames of videos. ignore time should not be allowed in that scenario
- [x] make sure to automatically take care of temporal down/upsampling for conditioning video frames, but allow for an option to turn it off
+- [x] make sure inpainting works with video
+- [ ] make sure inpainting masks for video can be customized per frame
- [ ] reread cogvideo and figure out how frame rate conditioning could be used
- [ ] bring in attention expertise for self attention layers in unet3d
- [ ] consider bringing in NUWA's 3d convolutional attention
@@ -713,7 +732,6 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
- [ ] consider perceiver-ar approach to attending to past time
- [ ] frame dropouts during attention for achieving both regularizing effect as well as shortened training time
- [ ] investigate frank wood's claims https://github.com/lucidrains/flexible-diffusion-modeling-videos-pytorch and either add the hierarchical sampling technique, or let people know about its deficiencies
-- [ ] make sure inpainting works with video
- [ ] offer challenging moving mnist (with distractor objects) as a one-line trainable baseline for researchers to branch off of for text to video
- [ ] preencoding of text to memmapped embeddings
- [ ] be able to create dataloader iterators based on the old epoch style, also configure shuffling etc
diff --git a/imagen_pytorch/elucidated_imagen.py b/imagen_pytorch/elucidated_imagen.py
index 47d4363..cbb791a 100644
--- a/imagen_pytorch/elucidated_imagen.py
+++ b/imagen_pytorch/elucidated_imagen.py
@@ -619,6 +619,16 @@ def sample(
# handle video and frame dimension
+ if self.is_video and exists(inpaint_images):
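+            # video to be inpainted is shaped (batch, channels, frames, height, width)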
+ video_frames = inpaint_images.shape[2]
+
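+            # a mask passed in as (batch, height, width) gets a singleton frame dimension, to be broadcast across all frames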
+ if inpaint_masks.ndim == 3:
+ inpaint_masks = rearrange(inpaint_masks, 'b h w -> b 1 h w')
+
+            assert inpaint_masks.shape[1] == 1, 'for now, video inpainting only accepts a single mask shared across all frames'
+
assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'
# determine the frame dimensions, if needed
diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py
index dd1ebf8..c726420 100644
--- a/imagen_pytorch/imagen_pytorch.py
+++ b/imagen_pytorch/imagen_pytorch.py
@@ -2360,6 +2360,16 @@ def sample(
# add frame dimension for video
+ if self.is_video and exists(inpaint_images):
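+            # video to be inpainted is shaped (batch, channels, frames, height, width)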
+ video_frames = inpaint_images.shape[2]
+
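+            # a mask passed in as (batch, height, width) gets a singleton frame dimension, to be broadcast across all frames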
+ if inpaint_masks.ndim == 3:
+ inpaint_masks = rearrange(inpaint_masks, 'b h w -> b 1 h w')
+
+            assert inpaint_masks.shape[1] == 1, 'for now, video inpainting only accepts a single mask shared across all frames'
+
assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'
all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)
diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py
index fddb185..6376bae 100644
--- a/imagen_pytorch/version.py
+++ b/imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.23.3'
+__version__ = '1.24.0'