Skip to content

Commit

Permalink
vision realized
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 10, 2022
1 parent 67565c1 commit 1e7531c
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,26 @@ unet = SpaceTimeUnet(
channels = 3,
dim_mult = (1, 2, 4, 8),
temporal_compression = (False, False, False, True),
self_attns = (False, False, True, True)
)
self_attns = (False, False, False, True)
).cuda()

# train on images

images = torch.randn(1, 3, 128, 128).cuda()
images_out = unet(images)

assert images.shape == images_out.shape

# then train on videos

video = torch.randn(1, 3, 16, 128, 128).cuda()
video_out = unet(video)

video = torch.randn(1, 3, 16, 256, 256) # (batch, channels, frame)
pred = unet(video)
assert video_out.shape == video.shape

assert video.shape == pred.shape
# or even treat your videos as images

unet(video, enable_time = False) # treat all frames of video as images
video_as_images_out = unet(video, enable_time = False)
```

## Todo
Expand Down

0 comments on commit 1e7531c

Please sign in to comment.