ability to do causal attention for time axis
lucidrains committed Jun 19, 2023
1 parent f2d214f commit 16561b3
Showing 3 changed files with 29 additions and 9 deletions.
20 changes: 18 additions & 2 deletions make_a_video_pytorch/attend.py
@@ -36,12 +36,15 @@ class Attend(nn.Module):
     def __init__(
         self,
         dropout = 0.,
-        flash = False
+        flash = False,
+        causal = False
     ):
         super().__init__()
         self.dropout = dropout
         self.attn_dropout = nn.Dropout(dropout)

+        self.causal = causal
+
         self.flash = flash
         assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

@@ -76,7 +79,8 @@ def flash_attn(self, q, k, v):
         with torch.backends.cuda.sdp_kernel(**config._asdict()):
             out = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p = self.dropout if self.training else 0.
+                dropout_p = self.dropout if self.training else 0.,
+                is_causal = self.causal
             )

         return out
@@ -102,6 +106,18 @@ def forward(self, q, k, v, bias = None):

         sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

+        # attn bias
+
+        if exists(bias):
+            sim = sim + bias
+
+        # causal
+
+        if self.causal:
+            i, j = sim.shape[-2:]
+            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
+            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
+
         # attention

         attn = sim.softmax(dim = -1)
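Note on the change above: the flash path delegates masking by passing is_causal = self.causal to F.scaled_dot_product_attention, while the eager path builds the mask explicitly with triu(j - i + 1), which right-aligns the queries against the keys so that causality still holds when there are more keys than queries. A minimal standalone sketch of that mask (not part of the commit; the lengths i and j below are arbitrary):

import torch

# hypothetical query / key lengths, chosen only for illustration
i, j = 3, 5

# True marks the entries to be masked out: keys strictly in the future
# of each query, with queries aligned to the end of the key sequence
causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

print(causal_mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])

sim = torch.randn(i, j)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)  # masked (future) keys receive ~zero attention weight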
16 changes: 10 additions & 6 deletions make_a_video_pytorch/make_a_video.py
@@ -184,14 +184,15 @@ def __init__(
         dim,
         dim_head = 64,
         heads = 8,
-        flash = False
+        flash = False,
+        causal = False
     ):
         super().__init__()
         self.heads = heads
         self.scale = dim_head ** -0.5
         inner_dim = dim_head * heads

-        self.attend = Attend(flash = flash)
+        self.attend = Attend(flash = flash, causal = causal)

         self.norm = RMSNorm(dim, dim = -1)

@@ -281,15 +282,16 @@ def __init__(
         add_feed_forward = True,
         ff_mult = 4,
         pos_bias = True,
-        flash = False
+        flash = False,
+        causal_time_attn = False
     ):
         super().__init__()
         assert not (flash and pos_bias), 'learned positional attention bias is not compatible with flash attention'

         self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash)
         self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 2) if pos_bias else None

-        self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash)
+        self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash, causal = causal_time_attn)
         self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1) if pos_bias else None

         self.has_feed_forward = add_feed_forward
@@ -545,7 +547,8 @@ def __init__(
         attn_heads = 8,
         condition_on_timestep = True,
         attn_pos_bias = True,
-        flash_attn = False
+        flash_attn = False,
+        causal_time_attn = False
     ):
         super().__init__()
         assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths)
@@ -580,7 +583,8 @@ def __init__(
             dim_head = attn_dim_head,
             heads = attn_heads,
             pos_bias = attn_pos_bias,
-            flash= flash_attn
+            flash = flash_attn,
+            causal_time_attn = causal_time_attn
         )

         mid_dim = dims[-1]
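For reference, a hedged usage sketch of the new flag at the model level. Only causal_time_attn comes from this commit; the SpaceTimeUnet class name and the remaining constructor arguments are assumed from the repository's README-style interface and may differ:

import torch
from make_a_video_pytorch import SpaceTimeUnet  # assumed import path

unet = SpaceTimeUnet(
    dim = 64,
    dim_mult = (1, 2, 4, 8),
    resnet_block_depths = (1, 1, 1, 2),
    temporal_compression = (False, False, False, True),
    self_attns = (False, False, False, True),
    condition_on_timestep = False,
    causal_time_attn = True  # new in this commit: temporal attention only looks at past frames
)

videos = torch.randn(1, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
preds = unet(videos)                      # output has the same shape as the input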
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'make-a-video-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.0',
+  version = '0.3.1',
   license='MIT',
   description = 'Make-A-Video - Pytorch',
   author = 'Phil Wang',
