diff --git a/README.md b/README.md
index b63c736..2e5ac7d 100644
--- a/README.md
+++ b/README.md
@@ -134,10 +134,10 @@ video_as_images_out = unet(video, enable_time = False)
 
 ## Todo
 
+- [x] give attention the best positional embeddings research has to offer
+
 - [ ] make sure dalle2-pytorch can accept `SpaceTimeUnet` for training
-- [ ] give attention the best positional embeddings research has to offer
 - [ ] soup up the attention
-- [ ] offer a function, similar to how MosaicML's approach, that automatically rigs a 2d-unet from dalle2-pytorch to be 3d
 - [ ] consider learned exponential moving average across time from https://github.com/lucidrains/Mega-pytorch
 
 ## Citations
diff --git a/make_a_video_pytorch/make_a_video.py b/make_a_video_pytorch/make_a_video.py
index 1266a8f..ffe6e45 100644
--- a/make_a_video_pytorch/make_a_video.py
+++ b/make_a_video_pytorch/make_a_video.py
@@ -70,6 +70,61 @@ def FeedForward(dim, mult = 4):
         nn.Linear(inner_dim, dim, bias = False)
     )
 
+# best relative positional encoding
+
+class ContinuousPositionBias(nn.Module):
+    """ from https://arxiv.org/abs/2111.09883 """
+
+    def __init__(
+        self,
+        *,
+        dim,
+        heads,
+        num_dims = 1,
+        layers = 2,
+        log_dist = True,
+        cache_rel_pos = False
+    ):
+        super().__init__()
+        self.num_dims = num_dims
+        self.log_dist = log_dist
+
+        self.net = nn.ModuleList([])
+        self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))
+
+        for _ in range(layers - 1):
+            self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))
+
+        self.net.append(nn.Linear(dim, heads))
+
+        self.cache_rel_pos = cache_rel_pos
+        self.register_buffer('rel_pos', None, persistent = False)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(self, *dimensions):
+        device = self.device
+
+        if not exists(self.rel_pos) or not self.cache_rel_pos:
+            positions = [torch.arange(d, device = device) for d in dimensions]
+            grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'))
+            grid = rearrange(grid, 'c ... -> (...) c')
+            rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
+
+            if self.log_dist:
+                rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
+
+            self.register_buffer('rel_pos', rel_pos, persistent = False)
+
+        rel_pos = self.rel_pos.float()
+
+        for layer in self.net:
+            rel_pos = layer(rel_pos)
+
+        return rearrange(rel_pos, 'i j h -> h i j')
+
 # helper classes
 
 class Attention(nn.Module):
@@ -92,7 +147,12 @@ def __init__(
         nn.init.zeros_(self.to_out.weight.data) # identity with skip connection
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        rel_pos_bias = None
+    ):
+
         x = self.norm(x)
 
         q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)
 
@@ -103,6 +163,9 @@ def forward(self, x):
 
         sim = einsum('b h i d, b h j d -> b h i j', q, k)
 
+        if exists(rel_pos_bias):
+            sim = sim + rel_pos_bias
+
         attn = sim.softmax(dim = -1)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
@@ -175,7 +238,10 @@ def __init__(
     ):
         super().__init__()
         self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
+        self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 2)
+
         self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
+        self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1)
 
     def forward(
         self,
@@ -191,7 +257,9 @@ def forward(
         else:
             x = rearrange(x, 'b c h w -> b (h w) c')
 
-        x = self.spatial_attn(x) + x
+        space_rel_pos_bias = self.spatial_rel_pos_bias(h, w)
+
+        x = self.spatial_attn(x, rel_pos_bias = space_rel_pos_bias) + x
 
         if is_video:
             x = rearrange(x, '(b f) (h w) c -> b c f h w', b = b, h = h, w = w)
@@ -203,7 +271,9 @@ def forward(
 
         x = rearrange(x, 'b c f h w -> (b h w) f c')
 
-        x = self.temporal_attn(x) + x
+        time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
+
+        x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias) + x
 
         x = rearrange(x, '(b h w) f c -> b c f h w', w = w, h = h)
 
diff --git a/setup.py b/setup.py
index 139c64d..94e95f3 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'make-a-video-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.7',
+  version = '0.0.8',
   license='MIT',
   description = 'Make-A-Video - Pytorch',
   author = 'Phil Wang',
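
For context on what the new `ContinuousPositionBias` module contributes, here is a minimal standalone sketch of the same idea, the continuous relative position bias from Swin Transformer V2 (https://arxiv.org/abs/2111.09883). It assumes only `torch` and `einops` are installed (PyTorch >= 1.10 for `meshgrid(..., indexing = 'ij')`); the function name `continuous_position_bias` and the MLP sizes are illustrative for this demo, not part of the library's API.

```python
import torch
from torch import nn
from einops import rearrange

def continuous_position_bias(mlp, *dimensions, log_dist = True):
    # coordinates of every position on a (possibly multi-dimensional) grid
    positions = [torch.arange(d) for d in dimensions]
    grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'))
    grid = rearrange(grid, 'c ... -> (...) c')             # (n, num_dims)

    # pairwise signed offsets between all positions
    rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')  # (n, n, num_dims)

    if log_dist:
        # log-spaced offsets, as in Swin v2
        rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)

    bias = mlp(rel_pos.float())                            # (n, n, heads)
    return rearrange(bias, 'i j h -> h i j')               # (heads, n, n)

# small MLP mapping a relative offset to one bias value per attention head
heads, num_dims, hidden = 8, 2, 32                         # hypothetical sizes for the demo
mlp = nn.Sequential(
    nn.Linear(num_dims, hidden), nn.SiLU(),
    nn.Linear(hidden, hidden), nn.SiLU(),
    nn.Linear(hidden, heads)
)

bias = continuous_position_bias(mlp, 4, 4)                 # spatial bias for a 4 x 4 feature map
sim  = torch.randn(1, heads, 16, 16)                       # attention logits (batch, heads, i, j)
attn = (sim + bias).softmax(dim = -1)                      # bias is added before the softmax
print(bias.shape, attn.shape)                              # (8, 16, 16) and (1, 8, 16, 16)
```

Because the bias depends only on the grid shape and not on the batch, the diff above computes it once per call (optionally caching `rel_pos` as a non-persistent buffer via `cache_rel_pos`) and adds it to the attention logits before the softmax, with `num_dims = 2` for spatial attention over (h, w) and `num_dims = 1` for temporal attention over frames.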