[FA2] Cast to correct dtype
#26560
Changes from all commits
fba4e7f
d251825
48126ea
a2c1e4a
53c5f30
b4777ca
5539b25
e03b218
```diff
@@ -1289,7 +1289,8 @@ def _check_and_enable_flash_attn_2(
         if torch_dtype is None:
             logger.warning(
-                "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
+                "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour. Make sure to "
+                "pass `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16` when initialising the model."
             )
         elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
             raise ValueError(
```
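For context, here is a minimal sketch of the call pattern the new warning message points users to. It is illustrative only: the checkpoint name is a placeholder and the `use_flash_attention_2` loading flag is assumed from how Flash Attention 2 was enabled in `from_pretrained` around this time, not something shown in the diff.

```python
# Illustrative only: load an FA2-enabled model with an explicit half-precision dtype
# so that the warning added above is never triggered.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # placeholder checkpoint, any FA2-compatible model works
    torch_dtype=torch.float16,    # or torch.bfloat16; FA2 kernels do not support float32
    use_flash_attention_2=True,   # takes the code path checked by _check_and_enable_flash_attn_2
)
```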
```diff
@@ -1319,6 +1320,8 @@ def _check_and_enable_flash_attn_2(
                     "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
                 )
         config._flash_attn_2_enabled = True
+
+        config._flash_attn_2_attention_dtype = torch_dtype if torch_dtype is not None else torch.float16
         return config

     def enable_input_require_grads(self):
```
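A minimal sketch of the fallback the added line encodes: when no `torch_dtype` is passed, the attention dtype recorded on the config defaults to `float16`. The helper name is purely illustrative.

```python
# Illustrative helper mirroring the one-liner added above.
import torch

def pick_attention_dtype(torch_dtype):
    # Mirrors: torch_dtype if torch_dtype is not None else torch.float16
    return torch_dtype if torch_dtype is not None else torch.float16

assert pick_attention_dtype(None) is torch.float16
assert pick_attention_dtype(torch.bfloat16) is torch.bfloat16
```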
```diff
@@ -2172,6 +2175,26 @@ def to(self, *args, **kwargs):
                 " model has already been set to the correct devices and casted to the correct `dtype`."
             )
         else:
+            # TODO: @younesbelkada find a better way to do this directly in `xxxFlashAttention` modules
+            # currently it is not possible to retrieve the original dtype for quantized models.
+            if getattr(self.config, "_flash_attn_2_enabled", False):
+                target_dtype = None
+
+                if "dtype" not in kwargs:
+                    for arg in args:
+                        if isinstance(arg, torch.dtype):
+                            target_dtype = arg
+                            break
+                else:
+                    target_dtype = kwargs["dtype"]
+
+                if target_dtype is not None and target_dtype == torch.float32:
+                    raise ValueError(
+                        "You cannot cast a model that has been loaded with Flash Attention 2 in `float32`"
+                    )
+                elif target_dtype is not None:
+                    self.config._flash_attn_2_attention_dtype = target_dtype
+
             return super().to(*args, **kwargs)

     def half(self, *args):
```

Collaborator (on the `target_dtype == torch.float32` check): nice

Member (comment on lines +2192 to +2194, the new `raise ValueError`): Maybe mention how to go around this?
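A hedged usage sketch of the new guard in `to()`. It assumes `model` is a `PreTrainedModel` that was loaded with Flash Attention 2 enabled, as in the loading sketch further above; the exact loading call is not part of this hunk.

```python
# `model` is assumed to be an FA2-enabled PreTrainedModel.
import torch

model = model.to(torch.bfloat16)   # allowed: also records bfloat16 in config._flash_attn_2_attention_dtype

try:
    model.to(torch.float32)        # rejected: FA2 kernels only run in fp16/bf16
except ValueError as err:
    print(err)                     # "You cannot cast a model that has been loaded with Flash Attention 2 in `float32`"
```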
```diff
@@ -476,15 +476,17 @@ def forward(
         # in fp32. (LlamaRMSNorm handles it correctly)
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            attention_dtype = self.config._flash_attn_2_attention_dtype
+
             logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to"
-                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                " float16."
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {attention_dtype}. Make sure to pass the desired dtype when calling `from_pretrained` with `torch_dtype=your_dtype`"
             )

-            query_states = query_states.to(torch.float16)
-            key_states = key_states.to(torch.float16)
-            value_states = value_states.to(torch.float16)
+            query_states = query_states.to(attention_dtype)
+            key_states = key_states.to(attention_dtype)
+            value_states = value_states.to(attention_dtype)

         attn_output = self._flash_attention_forward(
             query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
```

Contributor (on the new `attention_dtype = self.config._flash_attn_2_attention_dtype` line): Not all attention modules have a config variable, but I guess that's OK; we can just pass it forward.
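A standalone sketch of the cast-back pattern in the hunk above. Only the `_flash_attn_2_attention_dtype` attribute name comes from this PR; the helper function itself is hypothetical and just restates what the changed lines do inside the attention module.

```python
import torch

def cast_inputs_for_flash_attention(query_states, key_states, value_states, config):
    """Hypothetical helper mirroring the diff: cast fp32 activations back to the FA2 dtype."""
    if query_states.dtype == torch.float32:
        # Upcasted embeddings or layer norms can silently promote hidden states to float32,
        # but the Flash Attention 2 kernels only accept float16 / bfloat16 inputs.
        attention_dtype = config._flash_attn_2_attention_dtype
        query_states = query_states.to(attention_dtype)
        key_states = key_states.to(attention_dtype)
        value_states = value_states.to(attention_dtype)
    return query_states, key_states, value_states
```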
```diff
@@ -411,15 +411,16 @@ def forward(
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            attention_dtype = self.config._flash_attn_2_attention_dtype
             logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to"
-                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                " float16."
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {attention_dtype}. Make sure to pass the desired dtype when calling `from_pretrained` with `torch_dtype=your_dtype`"
            )

-            query_states = query_states.to(torch.float16)
-            key_states = key_states.to(torch.float16)
-            value_states = value_states.to(torch.float16)
+            query_states = query_states.to(attention_dtype)
+            key_states = key_states.to(attention_dtype)
+            value_states = value_states.to(attention_dtype)

         # Reashape to the expected shape for Flash Attention
         query_states = query_states.transpose(1, 2)
```

Review comment (on the long `f" {attention_dtype}. …"` line): Seems to me that this line is 124 characters long and should be split into two; the previous lines could be left without the `f` prefix, since there are no values to format there.

Review comment: Should this be investigated now?
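One possible way to address the line-length review comment above, sketched here rather than taken from the PR: keep the lines with no placeholders as plain strings and split the long interpolated line in two (the surrounding `logger` and `attention_dtype` names are those already in scope in the hunk).

```python
logger.warning_once(
    "The input hidden states seems to be silently casted in float32, this might be related to"
    " the fact you have upcasted embedding or layer norm layers in float32. We will cast back"
    f" the input in {attention_dtype}. Make sure to pass the desired dtype when calling"
    " `from_pretrained` with `torch_dtype=your_dtype`"
)
```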