Commit
* enable autoTP for MPT
* add model specific func to auto_tp_model_utils.py

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
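As context for what this change enables, here is a minimal usage sketch (not part of the diff) of running an MPT checkpoint through DeepSpeed AutoTP. The model name, dtype, and environment-variable handling are illustrative assumptions; the key point is `replace_with_kernel_inject=False`, which selects the AutoTP path.

# Illustrative sketch only: shard an MPT-style model with AutoTP and generate.
# Launch with the deepspeed launcher, e.g. `deepspeed --num_gpus 2 run_mpt.py`.
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mosaicml/mpt-7b"  # assumption: any MPT-style causal LM works the same way
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch.bfloat16,
                                             trust_remote_code=True)

# AutoTP: no kernel injection; tensor-parallel degree = number of launched ranks.
world_size = int(os.getenv("WORLD_SIZE", "1"))
engine = deepspeed.init_inference(model,
                                  mp_size=world_size,
                                  dtype=torch.bfloat16,
                                  replace_with_kernel_inject=False)

local_rank = int(os.getenv("LOCAL_RANK", "0"))
inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(f"cuda:{local_rank}")
outputs = engine.module.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))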
Showing 4 changed files with 106 additions and 56 deletions.
auto_tp_model_utils.py (new file)
@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed import comm as dist
import torch
from typing import Optional

def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    """
    Link to paper: https://arxiv.org/abs/2108.12409. The alibi tensor is not causal as the original paper mentions;
    it relies on a translation invariance of softmax for a quick implementation: with `l` being a tensor and `a` a
    fixed value, `softmax(l+a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

    Args:
        attention_mask (`torch.Tensor`):
            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
        num_heads (`int`, *required*):
            number of heads
        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
            dtype of the output tensor

    Returns:
        Tensor shaped (batch_size * num_heads, 1, max_seq_len)
    """
    import math
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2**math.floor(math.log2(num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        device=attention_mask.device,
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
                                  device=attention_mask.device,
                                  dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # => the query_length dimension will then be broadcasted correctly
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    if dist.is_initialized():
        # Under tensor parallelism, keep only this rank's contiguous slice of heads.
        num_heads_per_rank = int(num_heads / dist.get_world_size())
        offset = dist.get_rank() * num_heads_per_rank
        alibi = alibi.view(batch_size, num_heads, 1, seq_length)
        alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
        return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
    else:
        return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)

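A quick self-contained check of the helper above (a sketch using an assumed import path, not part of this commit): without a distributed backend initialized it returns all heads; under AutoTP with world size N, each rank would get only its num_heads / N slice.

import torch
# Assumed import path for the new helper; adjust to wherever it lands in the package.
from deepspeed.module_inject.auto_tp_model_utils import build_bloom_alibi_tensor

attention_mask = torch.ones(2, 16, dtype=torch.int64)  # (batch_size=2, seq_len=16)
alibi = build_bloom_alibi_tensor(attention_mask, num_heads=32, dtype=torch.bfloat16)

# With torch.distributed not initialized: (batch_size * num_heads, 1, seq_len).
print(alibi.shape)  # torch.Size([64, 1, 16])
# With a tensor-parallel world size of 4, each rank would instead see torch.Size([16, 1, 16]).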
def build_mpt_atten_bias_tensor(self,
                                device,
                                dtype,
                                attention_mask: Optional[torch.ByteTensor] = None,
                                prefix_mask: Optional[torch.ByteTensor] = None,
                                sequence_id: Optional[torch.LongTensor] = None):
    (attn_bias, attention_mask) = self._attn_bias_orig(device,
                                                       dtype,
                                                       attention_mask=attention_mask,
                                                       prefix_mask=prefix_mask,
                                                       sequence_id=sequence_id)
    if dist.is_initialized():
        # Slice the per-head attention bias so each rank keeps only its own heads.
        num_heads_per_rank = int(self.config.n_heads / dist.get_world_size())
        offset = dist.get_rank() * num_heads_per_rank
        attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
    return attn_bias, attention_mask
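How this helper is meant to be wired in (an assumption on my part, mirroring what the AutoTP injection is expected to do rather than code taken from this diff): the MPT model's original `_attn_bias` method is stashed as `_attn_bias_orig`, and the module-level function above is bound in its place, so the bias is sliced to the local heads on every call.

import types

def patch_mpt_attn_bias(mpt_model):
    """Hypothetical helper: rebind MPT's attention-bias builder for tensor parallelism."""
    # build_mpt_atten_bias_tensor calls self._attn_bias_orig(...) first,
    # then keeps only this rank's slice of attention heads.
    mpt_model._attn_bias_orig = mpt_model._attn_bias
    mpt_model._attn_bias = types.MethodType(build_mpt_atten_bias_tensor, mpt_model)
    return mpt_model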