141 changes: 0 additions & 141 deletions colossalai/shardformer/modeling/falcon.py
@@ -6,7 +6,6 @@
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
@@ -25,7 +24,6 @@
FalconForSequenceClassification,
FalconForTokenClassification,
FalconModel,
apply_rotary_pos_emb,
build_alibi_tensor,
)
from transformers.utils import logging
@@ -171,145 +169,6 @@ def forward(
return forward


def get_falcon_flash_attention_forward():
try:
pass
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.falcon.modeling_falcon import FalconAttention

def forward(
self: FalconAttention,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
**kwargs,
):
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

batch_size, query_length, _, _ = query_layer.shape

query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)

kv_seq_len = key_layer.shape[-2]
if layer_past is not None:
kv_seq_len += layer_past[0].shape[-2]
if alibi is None:
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)

if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size, self.num_heads, kv_length, head_dim]
# - value: [batch_size, self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=-2)
value_layer = torch.cat((past_value, value_layer), dim=-2)

kv_length = key_layer.shape[-2]
if use_cache:
present = (key_layer, value_layer)
else:
present = None

if alibi is None:
if self._use_sdpa and not output_attentions:
attn_output = F.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attention_mask,
0.0,
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
attention_scores = None
else:
attention_scores = query_layer @ key_layer.transpose(-1, -2)
attention_scores /= math.sqrt(self.head_dim)

attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
# It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
attn_output = attention_scores @ value_layer

attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

attn_output = self.dense(attn_output)

if output_attentions:
return attn_output, present, attention_scores
else:
return attn_output, present
else:
if self._use_sdpa and not output_attentions and head_mask is None:
attn_output = F.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.attention_dropout.p if self.training else 0.0,
is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

attn_output = self.dense(attn_output)
else:
matmul_result = query_layer @ key_layer.transpose(-1, -2)

# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)

# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
attention_scores = attention_scores.to(torch.float32)

attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
attention_logits *= self.inv_norm_factor
attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)

if head_mask is not None:
attention_probs = attention_probs * head_mask

# change view [batch_size, num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)

# matmul: [batch_size * num_heads, q_length, head_dim]
attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)

# change view [batch_size, q_length, num_heads * head_dim]
attn_output = self._merge_heads(attn_output)

attn_output = self.dense(attn_output)

if output_attentions:
return attn_output, present, attention_probs
else:
return attn_output, present

return forward


class FalconPipelineForwards:
"""
This class serves as a micro library for falcon pipeline forwards.
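A side note on the removed `get_falcon_flash_attention_forward`: its `try: pass` / bare `except:` guard could never raise, so the `ImportError` about xformers was unreachable. For reference only, a minimal sketch of the conventional optional-dependency guard that block appears to have intended (xformers is assumed to be the dependency, per the removed error message):

```python
# Sketch of a working optional-dependency guard; the removed `try: pass`
# variant could never reach its except branch.
try:
    import xformers  # noqa: F401  (assumed optional dependency)
except ImportError as exc:
    raise ImportError(
        "xformers is not installed. Please install it to use flash attention."
    ) from exc
```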
5 changes: 0 additions & 5 deletions colossalai/shardformer/modeling/llama.py
@@ -33,7 +33,6 @@
from colossalai.shardformer.shard import ShardConfig

from ..layer import ColoAttention, cross_entropy_1d
from ..layer._operation import gather_forward_split_backward


class LlamaPipelineForwards:
@@ -115,10 +114,6 @@ def llama_model_forward(
)
position_ids = position_ids.unsqueeze(0)

if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None

# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
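The deleted `_use_flash_attention_2` branch in `llama_model_forward` kept the 2D attention mask only when real padding was present and dropped it otherwise. A small standalone illustration of that rule (toy shapes):

```python
# Toy illustration of the removed 2D-mask pass-through rule: keep the mask
# only if it contains at least one padded (0) position, otherwise pass None.
import torch

no_padding = torch.ones(2, 5, dtype=torch.long)
with_padding = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]])

def keep_2d_mask(mask):
    return mask if (mask is not None and 0 in mask) else None

print(keep_2d_mask(no_padding))    # None -- attention can run without a mask
print(keep_2d_mask(with_padding))  # mask is kept because padding is present
```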
5 changes: 1 addition & 4 deletions colossalai/shardformer/modeling/mistral.py
@@ -215,9 +215,6 @@ def forward(

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
return attn_output, past_key_value

return forward
14 changes: 0 additions & 14 deletions colossalai/shardformer/modeling/opt.py
@@ -58,20 +58,6 @@ class OPTPipelineForwards:
under pipeline setting.
"""

@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len

expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

inverted_mask = 1.0 - expanded_mask

return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
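The `_expand_mask` staticmethod deleted above was a verbatim duplicate of the definition that remains. As a quick reference for what the retained helper computes, a small standalone sketch of the same expansion on a toy mask (assuming the kept definition matches the removed one line for line):

```python
# Toy walk-through of the _expand_mask logic shown above: a [bsz, src_len]
# padding mask becomes an additive [bsz, 1, tgt_len, src_len] mask where
# padded positions hold the dtype minimum and attended positions hold 0.
import torch

mask = torch.tensor([[1, 1, 0]])                                # bsz=1, src_len=3
dtype = torch.float32
expanded = mask[:, None, None, :].expand(1, 1, 3, 3).to(dtype)  # tgt_len = src_len
inverted = 1.0 - expanded
additive = inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

print(additive.shape)     # torch.Size([1, 1, 3, 3])
print(additive[0, 0, 0])  # 0 for attended positions, ~-3.4e38 for the padded one
```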
5 changes: 0 additions & 5 deletions colossalai/shardformer/policies/falcon.py
@@ -137,11 +137,6 @@ def module_policy(self):

if self.shard_config.enable_flash_attention:
warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.")
# self.append_or_create_method_replacement(
# description={"forward": get_falcon_flash_attention_forward()},
# policy=policy,
# target_key=FalconAttention,
# )
return policy

def postprocess(self):
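With the commented-out replacement gone, requesting flash attention for Falcon only emits the warning above and falls back to the transformers attention path. A hedged usage sketch (it assumes `ShardConfig` exposes `enable_flash_attention` as a plain constructor flag, which its use in the policies suggests but this diff does not show):

```python
# Hedged sketch: asking for flash attention on Falcon now warns and falls back.
from colossalai.shardformer.shard import ShardConfig

shard_config = ShardConfig(enable_flash_attention=True)
# Sharding a Falcon model with this config emits:
#   "Falcon doesn't support flash attention now, fallback to transformers attention."
```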