25 changes: 24 additions & 1 deletion src/transformers/modeling_utils.py
@@ -1289,7 +1289,8 @@ def _check_and_enable_flash_attn_2(

if torch_dtype is None:
logger.warning(
"You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
"You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour. Make sure to "
"pass `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16` when initialising the model."
)
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(
@@ -1319,6 +1320,8 @@ def _check_and_enable_flash_attn_2(
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
config._flash_attn_2_enabled = True

config._flash_attn_2_attention_dtype = torch_dtype if torch_dtype is not None else torch.float16
return config

def enable_input_require_grads(self):
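The new warning text asks callers to pass an explicit half-precision dtype at load time. As a minimal sketch of what that looks like (the checkpoint id is a placeholder, and `use_flash_attention_2` is assumed to be the loading flag in this version of transformers):

```python
import torch
from transformers import AutoModelForCausalLM

# Pass an explicit half-precision dtype so Flash Attention 2 never has to guess.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",      # placeholder checkpoint, any FA2-supported model works
    torch_dtype=torch.bfloat16,      # or torch.float16
    use_flash_attention_2=True,      # loading flag assumed for this version of transformers
    device_map="auto",
)
```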
@@ -2172,6 +2175,26 @@ def to(self, *args, **kwargs):
" model has already been set to the correct devices and casted to the correct `dtype`."
)
else:
# TODO: @younesbelkada find a better way to do this directly in `xxxFlashAttention` modules
# currently it is not possible to retrieve the original dtype for quantized models.
Comment on lines +2178 to +2179 (Member): Should this be investigated now?

if getattr(self.config, "_flash_attn_2_enabled", False):
target_dtype = None

if "dtype" not in kwargs:
for arg in args:
if isinstance(arg, torch.dtype):
target_dtype = arg
break
else:
target_dtype = kwargs["dtype"]

if target_dtype is not None and target_dtype == torch.float32:
Review comment (Collaborator): nice

raise ValueError(
"You cannot cast a model that has been loaded with Flash Attention 2 in `float32`"
)
Comment on lines +2192 to +2194 (Member): Maybe mention how to go around this?

elif target_dtype is not None:
self.config._flash_attn_2_attention_dtype = target_dtype

return super().to(*args, **kwargs)

def half(self, *args):
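To illustrate the behaviour the new `to()` check enforces, a short sketch continuing from the loading example above (`model` is the same hypothetical object, so this is not a standalone script):

```python
import torch

# Casting a Flash Attention 2 model to float32 is rejected by the new check...
try:
    model.to(torch.float32)
except ValueError as err:
    print(err)

# ...while moving between the supported half-precision dtypes is allowed and also
# updates the dtype the attention modules will cast hidden states back to.
model = model.to(torch.bfloat16)
print(model.config._flash_attn_2_attention_dtype)  # torch.bfloat16
```

This also addresses the "how to go around this" question above: keep the model in `float16` or `bfloat16`; only a cast to `float32` is refused.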
14 changes: 8 additions & 6 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -613,15 +613,17 @@ def forward(
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_layer.dtype
if input_dtype == torch.float32:
attention_dtype = self.config._flash_attn_2_attention_dtype

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {attention_dtype}. Make sure to pass the desired dtype when calling `from_pretrained` with `torch_dtype=your_dtype`"
)

query_layer = query_layer.to(torch.float16)
key_layer = key_layer.to(torch.float16)
value_layer = value_layer.to(torch.float16)
query_layer = query_layer.to(attention_dtype)
key_layer = key_layer.to(attention_dtype)
value_layer = value_layer.to(attention_dtype)

attn_output = self._flash_attention_forward(
query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout
14 changes: 8 additions & 6 deletions src/transformers/models/llama/modeling_llama.py
@@ -476,15 +476,17 @@ def forward(
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
attention_dtype = self.config._flash_attn_2_attention_dtype
Review comment (Contributor): Not all attention modules have a config variable. But guess that's ok, we can just pass it forward.

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {attention_dtype}. Make sure to pass the desired dtype when calling `from_pretrained` with `torch_dtype=your_dtype`"
)

query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
query_states = query_states.to(attention_dtype)
key_states = key_states.to(attention_dtype)
value_states = value_states.to(attention_dtype)

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
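Regarding the reviewer remark above that not every attention module carries a `config` attribute, one defensive pattern would be to read the new attribute with a fallback; this is purely illustrative and not part of the diff:

```python
import torch

def resolve_flash_attn_dtype(config, default=torch.float16):
    # Fall back to float16 (the previously hard-coded value) when the config
    # does not carry the attribute added in this PR.
    return getattr(config, "_flash_attn_2_attention_dtype", default)
```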
15 changes: 8 additions & 7 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -411,15 +411,16 @@ def forward(
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
attention_dtype = self.config._flash_attn_2_attention_dtype

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {attention_dtype}. Make sure to pass the desired dtype when calling `from_pretrained` with `torch_dtype=your_dtype`"

Review comment: Seems to me that this line is 124 characters long and should be split in two; the previous lines could be left without the 'f' prefix, since there are no values to be formatted there.

)

query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
query_states = query_states.to(attention_dtype)
key_states = key_states.to(attention_dtype)
value_states = value_states.to(attention_dtype)

# Reashape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
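For context, a hedged sketch of the situation these cast-backs handle: quantized or k-bit-prepared models keep layer norms (and often embeddings) in float32, so hidden states can reach the attention layers as float32 even though the model was requested in half precision. The checkpoint id and the peft helper are assumptions about a typical setup, not part of this diff:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

# 4-bit loading plus k-bit preparation upcasts layer norms to float32, which is how
# float32 hidden states end up in front of the flash-attention kernels.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",                                # placeholder checkpoint
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.float16,
    use_flash_attention_2=True,                                 # loading flag assumed as above
)
model = prepare_model_for_kbit_training(model)

# With this change, the attention modules cast q/k/v back to
# model.config._flash_attn_2_attention_dtype instead of a hard-coded float16.
```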