
Commit db6943a

add token shift along time for feedforwards in transformer blocks in unet3d
lucidrains committed Mar 18, 2023
1 parent 7a21a30 commit db6943a
Showing 3 changed files with 32 additions and 7 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -896,3 +896,12 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
year = {2023}
}
```

```bibtex
@article{Zhang2021TokenST,
title = {Token Shift Transformer for Video Classification},
author = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
journal = {Proceedings of the 29th ACM International Conference on Multimedia},
year = {2021}
}
```
28 changes: 22 additions & 6 deletions imagen_pytorch/imagen_video.py
@@ -110,6 +110,9 @@ def __init__(self, *args, **kwargs):
def forward(self, x, *args, **kwargs):
return x

def Sequential(*modules):
return nn.Sequential(*filter(exists, modules))

# tensor helpers

def log(t, eps: float = 1e-12):
@@ -1013,12 +1016,22 @@ def FeedForward(dim, mult = 2):
nn.Linear(hidden_dim, dim, bias = False)
)

def ChanFeedForward(dim, mult = 2): # in paper, it seems for self attention layers they did feedforwards with twice channel width
class TimeTokenShift(nn.Module):
def forward(self, x):
if x.ndim != 5:
return x

x, x_shift = x.chunk(2, dim = 1)
x_shift = F.pad(x_shift, (0, 0, 0, 0, 1, -1), value = 0.)
return torch.cat((x, x_shift), dim = 1)

def ChanFeedForward(dim, mult = 2, time_token_shift = True): # in paper, it seems for self attention layers they did feedforwards with twice channel width
hidden_dim = int(dim * mult)
return nn.Sequential(
return Sequential(
ChanLayerNorm(dim),
Conv2d(dim, hidden_dim, 1, bias = False),
nn.GELU(),
TimeTokenShift() if time_token_shift else None,
ChanLayerNorm(hidden_dim),
Conv2d(hidden_dim, dim, 1, bias = False)
)
@@ -1032,6 +1045,7 @@ def __init__(
heads = 8,
dim_head = 32,
ff_mult = 2,
ff_time_token_shift = True,
context_dim = None
):
super().__init__()
@@ -1040,7 +1054,7 @@ def __init__(
for _ in range(depth):
self.layers.append(nn.ModuleList([
EinopsToAndFrom('b c f h w', 'b (f h w) c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim)),
ChanFeedForward(dim = dim, mult = ff_mult)
ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
]))

def forward(self, x, context = None):
@@ -1058,6 +1072,7 @@ def __init__(
heads = 8,
dim_head = 32,
ff_mult = 2,
ff_time_token_shift = True,
context_dim = None,
**kwargs
):
@@ -1067,7 +1082,7 @@ def __init__(
for _ in range(depth):
self.layers.append(nn.ModuleList([
LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
ChanFeedForward(dim = dim, mult = ff_mult)
ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
]))

def forward(self, x, context = None):
@@ -1200,6 +1215,7 @@ def __init__(
attn_dim_head = 64,
attn_heads = 8,
ff_mult = 2.,
ff_time_token_shift = True, # this would do a token shift along time axis, at the hidden layer within feedforwards - from successful use in RWKV (Peng et al), and other token shift video transformer works
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
layer_attns = False,
layer_attns_depth = 1,
@@ -1452,7 +1468,7 @@ def __init__(
pre_downsample,
resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
temporal_peg(current_dim),
temporal_attn(current_dim),
TemporalDownsample(current_dim, stride = temporal_stride) if temporal_stride > 1 else None,
@@ -1490,7 +1506,7 @@ def __init__(
self.ups.append(nn.ModuleList([
resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
temporal_peg(dim_out),
temporal_attn(dim_out),
TemporalPixelShuffleUpsample(dim_out, stride = temporal_stride) if temporal_stride > 1 else None,
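For reference, here is a minimal standalone sketch (plain PyTorch, not part of the diff) of what the new `TimeTokenShift` does inside the feedforward hidden layer: half the channels pass through untouched, while the other half are shifted one frame forward in time, so each frame also sees the previous frame's hidden features. Inputs that are not 5-D are simply returned unchanged (the `ndim != 5` early return above).

```python
import torch
import torch.nn.functional as F

# toy video feature map: (batch, channels, frames, height, width)
x = torch.randn(2, 4, 3, 8, 8)

# split channels in half: the first half passes through untouched,
# the second half is shifted one step forward along the frame axis
x_keep, x_shift = x.chunk(2, dim = 1)

# pad one zero frame at the front and drop the last frame, so that
# frame t now carries the hidden features of frame t - 1
x_shift = F.pad(x_shift, (0, 0, 0, 0, 1, -1), value = 0.)

out = torch.cat((x_keep, x_shift), dim = 1)

assert out.shape == x.shape
assert torch.equal(out[:, 2:, 0], torch.zeros_like(out[:, 2:, 0]))  # first frame of the shifted half is zeroed
assert torch.equal(out[:, 2:, 1:], x[:, 2:, :-1])                   # the rest is x delayed by one frame
```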
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
@@ -1 +1 @@
__version__ = '1.23.0'
__version__ = '1.23.1'
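
On the usage side, the new `ff_time_token_shift` flag defaults to `True` in the video unet and its transformer blocks. A minimal sketch of opting out, assuming the usual `Unet3D` constructor arguments from the project README:

```python
from imagen_pytorch import Unet3D

# the temporal token shift in the feedforwards is on by default;
# set ff_time_token_shift = False to recover the previous behavior
unet = Unet3D(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    ff_time_token_shift = False
)
```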
