Support non-causal ALiBi
Using the symmetric solution from
ofirpress/attention_with_linear_biases#5.
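Concretely, the symmetric option from that issue penalizes attention by the absolute distance |i - j| between query position i and key position j, the same to the left and to the right, whereas the causal variant only needs a single row of key offsets broadcast over query positions (softmax is invariant to a constant shift per row, so the missing per-query term does not matter). A rough standalone PyTorch sketch of the two variants follows; it is an illustration under assumed names (num_heads, seq_len) and uses the power-of-two slope formula only, not the committed Megatron code:

import torch

num_heads, seq_len = 4, 8
# ALiBi slopes for a power-of-two head count: m_h = 2 ** (-8 * (h + 1) / num_heads).
slopes = torch.tensor([2 ** -(8 / num_heads * (h + 1)) for h in range(num_heads)])

# Causal variant: one bias row per head, broadcast over query positions.
# Shape (num_heads, 1, seq_len); per-row constant offsets cancel in softmax.
causal_bias = slopes[:, None, None] * torch.arange(seq_len)[None, None, :]

# Symmetric (bi-directional) variant from issue #5: bias_h[i, j] = -m_h * |i - j|.
# Shape (num_heads, seq_len, seq_len).
context_position = torch.arange(seq_len)[:, None]
memory_position = torch.arange(seq_len)[None, :]
symmetric_bias = slopes[:, None, None] * -(memory_position - context_position).abs()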
janEbert committed Aug 14, 2023
1 parent 4e99bb4 commit 50f9615
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions megatron/model/transformer.py
@@ -998,6 +998,7 @@ def __init__(self, config,
args.seq_length,
args.num_attention_heads,
args.micro_batch_size,
self_attn_mask_type,
).to(torch.cuda.current_device())
if args.params_dtype is torch.float16:
self.alibi = self.alibi.to(torch.float16)
@@ -1373,11 +1374,16 @@ def forward(self, hidden_states, attention_mask,
return output

@staticmethod
def _build_alibi_tensor(max_seq_len, num_attention_heads, batch_size):
# Copied from bigscience-workshop/Megatron-DeepSpeed
def _build_alibi_tensor(
max_seq_len, num_attention_heads, batch_size, attn_mask_type):
# Adjusted from bigscience-workshop/Megatron-DeepSpeed
# Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
# Also based on the first (symmetric) option from
# https://github.com/ofirpress/attention_with_linear_biases/issues/5.
"""Returns tensor shaped
(1, num_attention_heads_per_partition, 1, max_seq_len),
(1, num_attention_heads_per_partition, 1 or max_seq_len, max_seq_len),
where "1 or max_seq_len" is chosen depending on whether
attention is causal or bi-directional, respectively.
"""

def get_slopes(n):
@@ -1398,11 +1404,22 @@ def get_slopes_power_of_2(n):
)

slopes = torch.Tensor(get_slopes(num_attention_heads))
alibi = (
slopes.unsqueeze(1).unsqueeze(1)
* torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand(
num_attention_heads, -1, -1)
)
if attn_mask_type is AttnMaskType.causal:
alibi = (
slopes.unsqueeze(1).unsqueeze(1)
* torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand(
num_attention_heads, -1, -1)
)
else:
context_position = torch.arange(max_seq_len)[:, None]
memory_position = torch.arange(max_seq_len)[None, :]
relative_position = memory_position - context_position
relative_position = torch.abs(
relative_position,
).unsqueeze(0).expand(num_attention_heads, -1, -1)

alibi = torch.empty(num_attention_heads, max_seq_len, max_seq_len)
alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position

# Select the part of the tensor that corresponds to our tensor
# parallel index.
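The collapsed lines below the last hunk slice out the heads owned by the current tensor-parallel rank; only the introductory comment is visible here. As a rough, self-contained sketch of how a bias shaped this way is typically consumed in bi-directional attention (names such as tp_rank and heads_per_partition, the partitioning scheme, and all shapes are assumptions for illustration, not the hidden Megatron code):

import math
import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16
tp_rank, tp_world_size = 0, 2
heads_per_partition = num_heads // tp_world_size

# Rebuild the symmetric bias: bias_h[i, j] = -m_h * |i - j|
# (power-of-two head count assumed for the slope formula).
slopes = torch.tensor([2 ** -(8 / num_heads * (h + 1)) for h in range(num_heads)])
distance = (torch.arange(seq_len)[None, :] - torch.arange(seq_len)[:, None]).abs()
alibi = slopes[:, None, None] * -distance  # (num_heads, seq_len, seq_len)

# Keep only this rank's contiguous block of heads (assumed partitioning scheme).
alibi = alibi[tp_rank * heads_per_partition:(tp_rank + 1) * heads_per_partition]

# Add the bias to the raw attention scores before softmax; no causal mask
# is applied in the bi-directional case.
q = torch.randn(batch_size, heads_per_partition, seq_len, head_dim)
k = torch.randn(batch_size, heads_per_partition, seq_len, head_dim)
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
probs = torch.softmax(scores + alibi.unsqueeze(0), dim=-1)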
