Skip to content

Commit

Permalink
enable autoTP for MPT (#3861)
Browse files Browse the repository at this point in the history
* enable autoTP for MPT

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add model specific func to auto_tp_model_utils.py

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
  • Loading branch information
sywangyi and jeffra authored Jul 27, 2023
1 parent 76953a3 commit 0bafeac
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 56 deletions.
61 changes: 8 additions & 53 deletions deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch
import time
import os

from deepspeed import comm as dist
from deepspeed.utils.logging import log_dist

Expand All @@ -27,65 +26,14 @@
from ..module_inject.auto_tp import AutoTP

from ..module_inject.replace_policy import generic_policies
from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor

DS_INFERENCE_ENABLED = False
from torch import nn

INFERENCE_MODEL_TIMER = "model-forward-inference"


def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
attention_mask (`torch.Tensor`):
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
num_heads (`int`, *required*):
number of heads
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
"""
import math
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2**math.floor(math.log2(num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)

if closest_power_of_2 != num_heads:
extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
if dist.is_initialized():
num_heads_per_rank = int(num_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
else:
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)


class InferenceEngine(Module):
inference_mp_group = None
inference_ep_group = None
Expand Down Expand Up @@ -146,6 +94,7 @@ def __init__(self, model, config):
# This is a hack to redefine the alibi func due to TP
if config.tensor_parallel.tp_size > 1:
self.build_alibi_tensor()
self.build_attn_bias()

if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph:
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
Expand Down Expand Up @@ -239,6 +188,12 @@ def build_alibi_tensor(self):
if hasattr(self.module.transformer, 'build_alibi_tensor'):
self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor

def build_attn_bias(self):
if hasattr(self.module, 'transformer'):
if hasattr(self.module.transformer, '_attn_bias'):
self.module.transformer._attn_bias_orig = self.module.transformer._attn_bias
self.module.transformer.__class__._attn_bias = build_mpt_atten_bias_tensor

def _pre_forward_hook(self, module, *inputs, **kwargs):
if self.use_cuda_events:
self.timers(INFERENCE_MODEL_TIMER).start()
Expand Down
78 changes: 78 additions & 0 deletions deepspeed/module_inject/auto_tp_model_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed import comm as dist
import torch
from typing import Optional


def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
attention_mask (`torch.Tensor`):
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
num_heads (`int`, *required*):
number of heads
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
"""
import math
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2**math.floor(math.log2(num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)

if closest_power_of_2 != num_heads:
extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
if dist.is_initialized():
num_heads_per_rank = int(num_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
else:
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)


def build_mpt_atten_bias_tensor(self,
device,
dtype,
attention_mask: Optional[torch.ByteTensor] = None,
prefix_mask: Optional[torch.ByteTensor] = None,
sequence_id: Optional[torch.LongTensor] = None):
(attn_bias, attention_mask) = self._attn_bias_orig(device,
dtype,
attention_mask=attention_mask,
prefix_mask=prefix_mask,
sequence_id=sequence_id)
if dist.is_initialized():
num_heads_per_rank = int(self.config.n_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
return attn_bias, attention_mask
22 changes: 19 additions & 3 deletions deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,22 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
def _replace(child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
weight_shape = child.weight.shape
if name == 'attn.Wqkv' and module._get_name() == 'MPTBlock':
# MPT block qkv weight's allocation is different from other models, it's [3,num_head,head_dim,hidden_size]
# instead of [num_head,3,head_dim,hidden_size]
new_weight = torch.empty((
weight_shape[0] // mp_size,
weight_shape[1],
),
device=child.weight.device,
dtype=child.weight.dtype)
reversed_dim = True
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group, out_dim=0)
mp_replace.strided_copy(new_weight, child.weight.data, num_splits=3, int8=reversed_dim)
setattr(child, "replaced", True)
return LinearLayer(weight=new_weight.to(get_accelerator().current_device_name()), bias=None)
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
if name in all_reduce_linears:
new_weight = torch.empty((
weight_shape[1] if conv_linear_layer else weight_shape[0],
Expand Down Expand Up @@ -461,7 +475,8 @@ def _replace_module(r_module, prev_name='', prev_class_name=''):
else:
class_name = prev_class_name + '.' + prev_name
checking_key = prefix + '.' + class_name + '.' + name + '.' if class_name != "" else prefix + '.' + name + '.'
if child.__class__ in [nn.Linear, nn.Embedding, nn.LayerNorm] and state_dict is not None:
if (child.__class__ in [nn.Linear, nn.Embedding, nn.LayerNorm]
or child._get_name() in ["LPLayerNorm", "SharedEmbedding"]) and state_dict is not None:
if any(checking_key in item for item in state_dict):
load(child, state_dict, checking_key, mp_group)
else:
Expand Down Expand Up @@ -836,7 +851,8 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_di
layer_id += 1
else:
checking_key = prefix + name + '.'
if child.__class__ in load_layers and state_dict is not None:
if (child.__class__ in load_layers
or child._get_name() in ["LPLayerNorm", "SharedEmbedding"]) and state_dict is not None:
if any(checking_key in item for item in state_dict):
load(
child,
Expand Down
1 change: 1 addition & 0 deletions docs/_tutorials/automatic-tensor-parallelism.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ The following model families have been successfully tested with automatic tensor
- xlm_roberta
- yoso
- bloom
- mpt

# Unsupported Models

Expand Down

0 comments on commit 0bafeac

Please sign in to comment.