Commit
Fix AutoTP in DeepSpeed not working for BLOOM (huggingface#22196)
* Fix AutoTP in DeepSpeed not working for BLOOM

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Add a method to BloomModel to build the alibi tensor

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

sywangyi authored and raghavanone committed Apr 5, 2023
1 parent a6a030d commit 2c73d7b
Showing 1 changed file with 4 additions and 1 deletion.
src/transformers/models/bloom/modeling_bloom.py (4 additions, 1 deletion)
@@ -641,6 +641,9 @@ def __init__(self, config: BloomConfig):
         # Initialize weights and apply final processing
         self.post_init()

+    def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
+        return build_alibi_tensor(attention_mask, num_heads, dtype)
+
     def get_input_embeddings(self):
         return self.word_embeddings

@@ -750,7 +753,7 @@ def forward(
         else:
             attention_mask = attention_mask.to(hidden_states.device)

-        alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+        alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

         causal_mask = self._prepare_attn_mask(
             attention_mask,
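The point of the change is the indirection: by routing `forward()` through `self.build_alibi_tensor` instead of the module-level function, a runtime such as DeepSpeed AutoTP can replace the method on the model instance so each tensor-parallel rank builds the alibi bias only for its local slice of attention heads. The sketch below illustrates this pattern; it is a simplified re-implementation, not the `transformers` source, and `ToyBloomModel`, `ShardedBloomModel`, and the head-slicing override are hypothetical names invented for illustration.

```python
import torch


def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    # Simplified ALiBi bias: one slope per head (geometric sequence),
    # scaled by each token's position among the attended tokens.
    batch_size, seq_length = attention_mask.shape
    base = 2 ** (-8.0 / num_heads)
    slopes = torch.tensor([base ** (i + 1) for i in range(num_heads)])
    # Position index of each non-masked token.
    arange = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[None, :, None] * arange  # (batch, num_heads, seq_length)
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)


class ToyBloomModel:
    def __init__(self, num_heads: int):
        self.num_heads = num_heads

    # The thin wrapper added by this commit: forward() dispatches through
    # the instance, so subclasses or injected replacements can override it.
    def build_alibi_tensor(self, attention_mask, num_heads, dtype):
        return build_alibi_tensor(attention_mask, num_heads, dtype)

    def forward_bias(self, attention_mask, dtype=torch.float32):
        # Mirrors the changed forward() line: call via self, not the module.
        return self.build_alibi_tensor(attention_mask, self.num_heads, dtype)


class ShardedBloomModel(ToyBloomModel):
    # Hypothetical tensor-parallel override: each rank keeps only a
    # contiguous slice of the heads, so it slices the full bias tensor.
    def __init__(self, num_heads: int, rank: int, world_size: int):
        super().__init__(num_heads)
        self.rank, self.world_size = rank, world_size

    def build_alibi_tensor(self, attention_mask, num_heads, dtype):
        full = build_alibi_tensor(attention_mask, num_heads, dtype)
        per_rank = num_heads // self.world_size
        batch_size = attention_mask.shape[0]
        full = full.view(batch_size, num_heads, 1, -1)
        local = full[:, self.rank * per_rank : (self.rank + 1) * per_rank]
        return local.reshape(batch_size * per_rank, 1, -1)
```

With the old module-level call, `forward()` would always build the bias for all heads; with the method call, a sharded subclass (or a method patched onto the instance) transparently yields the per-rank shape.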
