
Commit d751bc1

Implement causal attention mask.
1 parent 73287c2

File tree

1 file changed (+7, -5 lines)

tranception_pytorch/tranception_pytorch.py

Lines changed: 7 additions & 5 deletions
@@ -92,6 +92,8 @@ def __init__(
         self.v_d_conv7 = DepthwiseConvolution(head_dim=self.head_dim, kernel_size=7)
 
     def forward(self, x, alibi):
+        seq_len = x.size(-2)
+
         q = self.to_q(x)
         k = self.to_k(x)
         v = self.to_v(x)
@@ -113,12 +115,12 @@ def forward(self, x, alibi):
 
         q, k, v = map(lambda t: rearrange(t, 'b k n2 l d -> b (k n2) l d', k=4), (q, k, v))
 
-        # Scaled dot product attention + ALiBi position encoding.
-        logit = einsum('b n i d, b n j d -> b n i j', q, k) * (self.head_dim ** -0.5) + alibi
+        # Scaled dot product attention + ALiBi position encoding + causal attention masking
+        causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * (-1e9)
+        causal_mask = causal_mask.view(1, 1, seq_len, seq_len)
+        causal_mask = causal_mask.to(x.device)
 
-        #
-        # TODO: causal masking in decoder
-        #
+        logit = einsum('b n i d, b n j d -> b n i j', q, k) * (self.head_dim ** -0.5) + alibi + causal_mask
 
         attn = logit.softmax(dim=-1)
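
For reference, here is a minimal, self-contained sketch of the masking scheme this commit introduces: an additive upper-triangular mask (large negative values above the diagonal) combined with an ALiBi-style bias before the softmax. The masked_attention helper, the tensor shapes, and the dummy alibi tensor are illustrative assumptions, not code from the repository.

import torch
from torch import einsum

def masked_attention(q, k, v, alibi):
    # Illustrative helper (not from the repository). Assumed shapes:
    # q, k, v: (batch, heads, seq_len, head_dim);
    # alibi: broadcastable to (batch, heads, seq_len, seq_len).
    head_dim, seq_len = q.size(-1), q.size(-2)

    # Positions j > i get a large negative bias, so softmax drives their
    # attention weights to (near) zero, enforcing causality.
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=q.device), diagonal=1) * (-1e9)
    causal_mask = causal_mask.view(1, 1, seq_len, seq_len)

    logit = einsum('b n i d, b n j d -> b n i j', q, k) * (head_dim ** -0.5) + alibi + causal_mask
    attn = logit.softmax(dim=-1)
    return einsum('b n i j, b n j d -> b n i d', attn, v)

# Quick check with dummy tensors.
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)
alibi = torch.zeros(1, 4, 8, 8)  # stand-in for the real ALiBi bias
out = masked_attention(q, k, v, alibi)
print(out.shape)  # torch.Size([2, 4, 8, 16])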