exercise attention expertise
lucidrains committed Dec 11, 2022
1 parent e634856 commit 5a93a28
Showing 3 changed files with 76 additions and 6 deletions.
README.md (4 changes: 2 additions & 2 deletions)
@@ -134,10 +134,10 @@ video_as_images_out = unet(video, enable_time = False)

## Todo

+- [x] give attention the best positional embeddings research has to offer
+
 - [ ] make sure dalle2-pytorch can accept `SpaceTimeUnet` for training
-- [ ] give attention the best positional embeddings research has to offer
 - [ ] soup up the attention
 - [ ] offer a function, similar to MosaicML's approach, that automatically rigs a 2d-unet from dalle2-pytorch to be 3d
 - [ ] consider learned exponential moving average across time from https://github.com/lucidrains/Mega-pytorch

## Citations
make_a_video_pytorch/make_a_video.py (76 changes: 73 additions & 3 deletions)
@@ -70,6 +70,61 @@ def FeedForward(dim, mult = 4):
        nn.Linear(inner_dim, dim, bias = False)
)

# best relative positional encoding

class ContinuousPositionBias(nn.Module):
""" from https://arxiv.org/abs/2111.09883 """

def __init__(
self,
*,
dim,
heads,
num_dims = 1,
layers = 2,
log_dist = True,
cache_rel_pos = False
):
super().__init__()
self.num_dims = num_dims
self.log_dist = log_dist

self.net = nn.ModuleList([])
self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))

for _ in range(layers - 1):
self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))

self.net.append(nn.Linear(dim, heads))

self.cache_rel_pos = cache_rel_pos
self.register_buffer('rel_pos', None, persistent = False)

@property
def device(self):
return next(self.parameters()).device

def forward(self, *dimensions):
device = self.device

if not exists(self.rel_pos) or not self.cache_rel_pos:
positions = [torch.arange(d, device = device) for d in dimensions]
grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'))
grid = rearrange(grid, 'c ... -> (...) c')
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')

if self.log_dist:
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)

self.register_buffer('rel_pos', rel_pos, persistent = False)

rel_pos = self.rel_pos.float()

for layer in self.net:
rel_pos = layer(rel_pos)

return rearrange(rel_pos, 'i j h -> h i j')
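
Since `ContinuousPositionBias` is the substance of this commit, a usage sketch may help: the class feeds the signed, log-scaled offset between every pair of positions through a small MLP and emits one bias value per attention head. The constructor arguments below mirror the ones used later in this diff; the import path is an assumption about how the module is exposed.

```python
import torch
from make_a_video_pytorch.make_a_video import ContinuousPositionBias  # assumed import path

# 2d relative position bias for 8-head spatial attention over a 16 x 16 feature map
cpb = ContinuousPositionBias(dim = 32, heads = 8, num_dims = 2)

bias = cpb(16, 16)     # one scalar per head for every pair of the 256 flattened positions
print(bias.shape)      # torch.Size([8, 256, 256]), i.e. (heads, i, j)

bias_32 = cpb(32, 32)  # same module, larger feature map: torch.Size([8, 1024, 1024])
```

Because the bias is computed from continuous offsets rather than looked up in a fixed-size table, the same module can be queried at a new resolution without retraining; setting `cache_rel_pos = True` skips recomputing the offset grid when the attended shape never changes.
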

# helper classes

class Attention(nn.Module):
@@ -92,7 +147,12 @@ def __init__(

nn.init.zeros_(self.to_out.weight.data) # identity with skip connection

-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        rel_pos_bias = None
+    ):

x = self.norm(x)

q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)
@@ -103,6 +163,9 @@ def forward(self, x):

sim = einsum('b h i d, b h j d -> b h i j', q, k)

if exists(rel_pos_bias):
sim = sim + rel_pos_bias

attn = sim.softmax(dim = -1)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
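
The bias is added straight onto the similarity logits, which works because a `(heads, i, j)` tensor broadcasts against the `(batch, heads, i, j)` logits. A standalone sanity check of that broadcast, independent of the model code:

```python
import torch

b, h, i, j = 2, 8, 256, 256

sim = torch.randn(b, h, i, j)        # attention logits: (batch, heads, query, key)
rel_pos_bias = torch.randn(h, i, j)  # per-head bias, e.g. from ContinuousPositionBias

biased = sim + rel_pos_bias          # (h, i, j) broadcasts over the batch dimension
assert biased.shape == (b, h, i, j)

attn = biased.softmax(dim = -1)      # bias shifts the logits before normalization
```
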
@@ -175,7 +238,10 @@ def __init__(
):
super().__init__()
self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 2)

self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1)

def forward(
self,
@@ -191,7 +257,9 @@ def forward(
else:
x = rearrange(x, 'b c h w -> b (h w) c')

-        x = self.spatial_attn(x) + x
+        space_rel_pos_bias = self.spatial_rel_pos_bias(h, w)
+
+        x = self.spatial_attn(x, rel_pos_bias = space_rel_pos_bias) + x

if is_video:
x = rearrange(x, '(b f) (h w) c -> b c f h w', b = b, h = h, w = w)
@@ -203,7 +271,9 @@

x = rearrange(x, 'b c f h w -> (b h w) f c')

-        x = self.temporal_attn(x) + x
+        time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
+
+        x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias) + x

x = rearrange(x, '(b h w) f c -> b c f h w', w = w, h = h)
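
Putting the factorized module together, here is a hedged end-to-end sketch. The constructor keywords come from the diff above, and it assumes, per the `is_video` branching shown earlier, that a 5d tensor takes the spatial-then-temporal path while a 4d image batch is handled by the spatial branch alone; the import path is likewise an assumption.

```python
import torch
from make_a_video_pytorch.make_a_video import SpatioTemporalAttention  # assumed import path

st_attn = SpatioTemporalAttention(dim = 64, dim_head = 32, heads = 8)

# video: spatial attention within each frame, then temporal attention at each
# spatial location, each branch with its own continuous relative position bias
video = torch.randn(1, 64, 4, 16, 16)  # (batch, channels, frames, height, width)
assert st_attn(video).shape == video.shape

# images: routed through the non-video branch shown above (an assumption)
images = torch.randn(1, 64, 16, 16)    # (batch, channels, height, width)
assert st_attn(images).shape == images.shape
```
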

setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'make-a-video-pytorch',
packages = find_packages(exclude=[]),
-  version = '0.0.7',
+  version = '0.0.8',
license='MIT',
description = 'Make-A-Video - Pytorch',
author = 'Phil Wang',