From db6943ae68cccb8c6fe0ab2219d1a1d0a9906a81 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 18 Mar 2023 09:30:21 -0700
Subject: [PATCH] add token shift along time for feedforwards in transformer
 blocks in unet3d

---
 README.md                      |  9 +++++++++
 imagen_pytorch/imagen_video.py | 28 ++++++++++++++++++++++------
 imagen_pytorch/version.py      |  2 +-
 3 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index d0468fc..44504b6 100644
--- a/README.md
+++ b/README.md
@@ -896,3 +896,12 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
     year    = {2023}
 }
 ```
+
+```bibtex
+@article{Zhang2021TokenST,
+    title   = {Token Shift Transformer for Video Classification},
+    author  = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
+    journal = {Proceedings of the 29th ACM International Conference on Multimedia},
+    year    = {2021}
+}
+```
diff --git a/imagen_pytorch/imagen_video.py b/imagen_pytorch/imagen_video.py
index 8474810..9a4b8ae 100644
--- a/imagen_pytorch/imagen_video.py
+++ b/imagen_pytorch/imagen_video.py
@@ -110,6 +110,9 @@ def __init__(self, *args, **kwargs):
     def forward(self, x, *args, **kwargs):
         return x
 
+def Sequential(*modules):
+    return nn.Sequential(*filter(exists, modules))
+
 # tensor helpers
 
 def log(t, eps: float = 1e-12):
@@ -1013,12 +1016,22 @@ def FeedForward(dim, mult = 2):
         nn.Linear(hidden_dim, dim, bias = False)
     )
 
-def ChanFeedForward(dim, mult = 2): # in paper, it seems for self attention layers they did feedforwards with twice channel width
+class TimeTokenShift(nn.Module):
+    def forward(self, x):
+        if x.ndim != 5:
+            return x
+
+        x, x_shift = x.chunk(2, dim = 1)
+        x_shift = F.pad(x_shift, (0, 0, 0, 0, 1, -1), value = 0.)
+        return torch.cat((x, x_shift), dim = 1)
+
+def ChanFeedForward(dim, mult = 2, time_token_shift = True): # in paper, it seems for self attention layers they did feedforwards with twice channel width
     hidden_dim = int(dim * mult)
-    return nn.Sequential(
+    return Sequential(
         ChanLayerNorm(dim),
         Conv2d(dim, hidden_dim, 1, bias = False),
         nn.GELU(),
+        TimeTokenShift() if time_token_shift else None,
         ChanLayerNorm(hidden_dim),
         Conv2d(hidden_dim, dim, 1, bias = False)
     )
@@ -1032,6 +1045,7 @@ def __init__(
         heads = 8,
         dim_head = 32,
         ff_mult = 2,
+        ff_time_token_shift = True,
         context_dim = None
     ):
         super().__init__()
@@ -1040,7 +1054,7 @@ def __init__(
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 EinopsToAndFrom('b c f h w', 'b (f h w) c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim)),
-                ChanFeedForward(dim = dim, mult = ff_mult)
+                ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
             ]))
 
     def forward(self, x, context = None):
@@ -1058,6 +1072,7 @@ def __init__(
         heads = 8,
         dim_head = 32,
         ff_mult = 2,
+        ff_time_token_shift = True,
         context_dim = None,
         **kwargs
     ):
@@ -1067,7 +1082,7 @@ def __init__(
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
-                ChanFeedForward(dim = dim, mult = ff_mult)
+                ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
             ]))
 
     def forward(self, x, context = None):
@@ -1200,6 +1215,7 @@ def __init__(
         attn_dim_head = 64,
         attn_heads = 8,
         ff_mult = 2.,
+        ff_time_token_shift = True, # this would do a token shift along time axis, at the hidden layer within feedforwards - from successful use in RWKV (Peng et al), and other token shift video transformer works
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         layer_attns = False,
         layer_attns_depth = 1,
@@ -1452,7 +1468,7 @@ def __init__(
                 pre_downsample,
                 resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                 nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
-                transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
+                transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
                 temporal_peg(current_dim),
                 temporal_attn(current_dim),
                 TemporalDownsample(current_dim, stride = temporal_stride) if temporal_stride > 1 else None,
@@ -1490,7 +1506,7 @@ def __init__(
             self.ups.append(nn.ModuleList([
                 resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                 nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
-                transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
+                transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
                 temporal_peg(dim_out),
                 temporal_attn(dim_out),
                 TemporalPixelShuffleUpsample(dim_out, stride = temporal_stride) if temporal_stride > 1 else None,
diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py
index 3335368..a039e2f 100644
--- a/imagen_pytorch/version.py
+++ b/imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.23.0'
+__version__ = '1.23.1'
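
Not part of the patch: a minimal standalone sketch of what the new `TimeTokenShift` module does when `ChanFeedForward` receives a 5-D video tensor (for 4-D image tensors it is a no-op). Half of the hidden channels are shifted forward by one frame along the time axis, so each frame's feedforward also sees the previous frame's features. The tensor shape below is made up for illustration; in `Unet3D` the shift sits at the hidden layer of the feedforwards and can be disabled with `ff_time_token_shift = False`.

```python
import torch
import torch.nn.functional as F

# arbitrary toy video tensor: (batch, channels, frames, height, width)
x = torch.randn(1, 4, 3, 8, 8)

# same operations as TimeTokenShift.forward in the patch
x_keep, x_shift = x.chunk(2, dim = 1)                      # split channels into two halves
x_shift = F.pad(x_shift, (0, 0, 0, 0, 1, -1), value = 0.)  # pad one zero frame at the front, drop the last frame
out = torch.cat((x_keep, x_shift), dim = 1)

assert torch.equal(out[:, :2], x[:, :2])                   # first half of channels is untouched
assert torch.equal(out[:, 2:, 1:], x[:, 2:, :-1])          # shifted half now holds the previous frame's features
assert torch.equal(out[:, 2:, 0], torch.zeros_like(out[:, 2:, 0]))  # frame 0 of the shifted half is zero-padded
```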