
Commit 8e0e97c
make sure temporal attention output is 0
lucidrains committed Dec 30, 2022
1 parent 748ce29 · commit 8e0e97c
Showing 2 changed files with 7 additions and 3 deletions.
imagen_pytorch/imagen_video.py (6 additions, 2 deletions)

@@ -411,7 +411,8 @@ def __init__(
         heads = 8,
         causal = False,
         context_dim = None,
-        cosine_sim_attn = False
+        cosine_sim_attn = False,
+        init_zero = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.

@@ -438,6 +439,9 @@ def __init__(
             LayerNorm(dim)
         )

+        if init_zero:
+            nn.init.zeros_(self.to_out[-1].g)
+
     def forward(
         self,
         x,

@@ -1291,7 +1295,7 @@ def __init__(
         temporal_peg_padding = (0, 0, 0, 0, 2, 0) if time_causal_attn else (0, 0, 0, 0, 1, 1)
         temporal_peg = lambda dim: Residual(nn.Sequential(Pad(temporal_peg_padding), nn.Conv3d(dim, dim, (3, 1, 1), groups = dim)))

-        temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', '(b h w) f c', Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn})))
+        temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', '(b h w) f c', Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn, 'init_zero': True})))

         # temporal attention relative positional encoding
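
Note on the change: when init_zero = True, the gain `g` of the final LayerNorm in the attention's to_out projection is zeroed, and the temporal attention blocks now pass init_zero = True. Because each temporal attention layer is wrapped in Residual, the temporal branch outputs exactly zero at initialization and the block starts as an identity, so the video model initially behaves like the underlying image model. Below is a minimal sketch of that idea; ToyAttention, ChanLayerNorm, and the toy Residual are illustrative stand-ins, not imagen_pytorch's actual classes.

# Illustrative sketch only -- hypothetical stand-in classes, not the library's API.
import torch
from torch import nn

class ChanLayerNorm(nn.Module):
    # layer norm with a learnable gain `g`, mirroring the `g` that the commit zeroes
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean = x.mean(dim = -1, keepdim = True)
        var = x.var(dim = -1, unbiased = False, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

class ToyAttention(nn.Module):
    def __init__(self, dim, init_zero = False):
        super().__init__()
        self.attend = nn.MultiheadAttention(dim, num_heads = 4, batch_first = True)
        self.to_out = nn.Sequential(nn.Linear(dim, dim), ChanLayerNorm(dim))

        if init_zero:
            # zero gain on the final norm -> the whole attention branch outputs 0
            nn.init.zeros_(self.to_out[-1].g)

    def forward(self, x):
        out, _ = self.attend(x, x, x)
        return self.to_out(out)

x = torch.randn(2, 16, 64)                         # (batch, frames, dim)
block = Residual(ToyAttention(64, init_zero = True))
assert torch.allclose(block(x), x)                 # identity at init: 0 + x == x

Only the norm's gain is zeroed, not the branch's weights, so the gain itself still receives a nonzero gradient and the temporal layers can learn a contribution as training proceeds.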
imagen_pytorch/version.py (1 addition, 1 deletion)

@@ -1 +1 @@
-__version__ = '1.18.6'
+__version__ = '1.18.7'

0 comments on commit 8e0e97c
