Commit
Re-enable SDPA's FA2 path (#30070)
* tentatively re-enable FA2 + SDPA

* better comment

* _ignore_causal_mask_sdpa as staticmethod

* type hints

* use past_seen_tokens instead

* enable copied from for sdpa

* ruff

* llama simplifications on review

* remove unnecessary self.is_causal check

* fix copies

* cleaning

* precise message

* better doc

* add test

* simplify

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* style

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
fxmarty and ArthurZucker authored Apr 17, 2024
1 parent e4ea19b commit 05bdef1
Showing 5 changed files with 176 additions and 60 deletions.
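For context on what this change enables: PyTorch's `torch.nn.functional.scaled_dot_product_attention` can only select its Flash Attention backend when no explicit `attn_mask` tensor is passed, so a purely causal mask is better expressed through the `is_causal` argument. A minimal standalone sketch of that trade-off (shapes and values are illustrative and not taken from this diff):

```python
import torch
import torch.nn.functional as F

# Toy shapes: (batch, num_heads, seq_len, head_dim) -- purely illustrative.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Path 1: an explicit additive causal mask. Passing any `attn_mask` makes SDPA's
# Flash Attention backend ineligible.
min_dtype = torch.finfo(q.dtype).min
causal_mask = torch.triu(torch.full((16, 16), min_dtype, dtype=q.dtype), diagonal=1)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask[None, None])

# Path 2: no mask at all, is_causal=True. Same result, but the Flash Attention
# backend can now be selected (on supported devices/dtypes).
out_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

print((out_masked - out_causal).abs().max())  # expected to be ~0 (numerical noise only)
```

The diff below automates exactly this decision: drop the mask whenever it carries no information beyond causality, and keep it otherwise.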
91 changes: 59 additions & 32 deletions src/transformers/modeling_attn_mask_utils.py
@@ -234,6 +234,59 @@ def _unmask_unattended(

return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))

@staticmethod
def _ignore_causal_mask_sdpa(
attention_mask: Optional[torch.Tensor],
inputs_embeds: torch.Tensor,
past_key_values_length: int,
sliding_window: Optional[int] = None,
) -> bool:
"""
Detects whether the optional user-specified attention_mask and the automatically created causal mask can be ignored when PyTorch's SDPA is used, relying instead on SDPA's `is_causal` argument.
If no token is masked in the `attention_mask` argument and `query_length == 1` or
`key_value_length == query_length`, we instead rely on SDPA's `is_causal` argument to select a causal/non-causal mask,
which allows dispatching to the flash attention kernel (a path that cannot be taken when a custom `attn_mask` is passed).
"""

batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
key_value_length = query_length + past_key_values_length

is_tracing = (
torch.jit.is_tracing()
or isinstance(inputs_embeds, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)

ignore_causal_mask = False

if attention_mask is None:
# TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export`
# or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`, which is generally wrong (see https://github.com/pytorch/pytorch/issues/108108).
# Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
#
# Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
if sliding_window is None or key_value_length < sliding_window:
ignore_causal_mask = not is_tracing
elif sliding_window is None or key_value_length < sliding_window:
if len(attention_mask.shape) == 4:
expected_shape = (batch_size, 1, query_length, key_value_length)
if tuple(attention_mask.shape) != expected_shape:
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
elif not is_tracing and torch.all(attention_mask == 1):
if query_length == 1 or key_value_length == query_length:
# For query_length == 1, causal attention and bi-directional attention are the same.
ignore_causal_mask = True

# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108
# TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.

return ignore_causal_mask
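A usage sketch of the helper added above, assuming a `transformers` checkout that includes this commit. `_ignore_causal_mask_sdpa` is a private helper; it is called directly here only to illustrate its decision logic, and the shapes are made up:

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

inputs_embeds = torch.randn(1, 7, 32)  # (batch, query_length, hidden_size), illustrative

# No attention_mask at all and no tracing: the causal mask can be dropped and
# SDPA's `is_causal` argument used instead.
print(AttentionMaskConverter._ignore_causal_mask_sdpa(
    attention_mask=None, inputs_embeds=inputs_embeds, past_key_values_length=0
))  # True when running eagerly (not tracing/compiling)

# A fully-unmasked 2D attention_mask with query_length == key_value_length carries
# no information beyond causality, so it can be ignored as well.
attention_mask = torch.ones(1, 7, dtype=torch.long)
print(AttentionMaskConverter._ignore_causal_mask_sdpa(
    attention_mask=attention_mask, inputs_embeds=inputs_embeds, past_key_values_length=0
))  # True

# With padding (a 0 in the mask), the mask carries real information and is kept.
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1]])
print(AttentionMaskConverter._ignore_causal_mask_sdpa(
    attention_mask=attention_mask, inputs_embeds=inputs_embeds, past_key_values_length=0
))  # False
```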


def _prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
@@ -305,7 +358,6 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

key_value_length = input_shape[-1] + past_key_values_length
_, query_length = input_shape

# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
@@ -316,37 +368,12 @@
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)

ignore_causal_mask = False

if attention_mask is None:
if sliding_window is None or key_value_length < sliding_window:
ignore_causal_mask = not is_tracing
elif sliding_window is None or key_value_length < sliding_window:
# 4d mask is passed through
if len(attention_mask.shape) == 4:
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
if tuple(attention_mask.shape) != expected_shape:
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
else:
# if the 4D mask has correct shape - invert it and fill with negative infinity
inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
attention_mask = inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
)
return attention_mask

elif not is_tracing and torch.all(attention_mask == 1):
if query_length == 1:
# For query_length == 1, causal attention and bi-directional attention are the same.
ignore_causal_mask = True
elif key_value_length == query_length:
ignore_causal_mask = True

# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108
ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)

if ignore_causal_mask:
expanded_4d_mask = None
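The `is_tracing` guard kept in `_prepare_4d_causal_attention_mask_for_sdpa` above is what prevents the mask from being dropped under torch.jit.trace, torch.fx symbolic tracing, or TorchDynamo, where `is_causal` cannot depend on runtime values. A small sketch of the same check outside of Transformers; the wrapper function and module below are hypothetical, written only to show when each branch fires:

```python
import torch
import torch.fx

def _is_tracing(x) -> bool:
    # Same three checks as the guard in the hunk above: torch.jit tracing,
    # torch.fx symbolic tracing (inputs become Proxy objects), and TorchDynamo.
    return (
        torch.jit.is_tracing()
        or isinstance(x, torch.fx.Proxy)
        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    )

class ToyModule(torch.nn.Module):
    def forward(self, x):
        # Under fx symbolic tracing, x is a torch.fx.Proxy, so the guard returns
        # True; any branch taken on that value would be baked into the graph.
        print("tracing:", _is_tracing(x))
        return x * 2

print(_is_tracing(torch.randn(2, 3)))      # False in plain eager mode
gm = torch.fx.symbolic_trace(ToyModule())  # prints "tracing: True" during tracing
```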
36 changes: 27 additions & 9 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -590,12 +590,15 @@ def forward(
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, relying
# instead on the `is_causal` argument.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()
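The `is_causal=causal_mask is None and q_len > 1` pattern re-enabled in these attention forwards can be exercised in isolation; `sdpa_attention` below is a simplified stand-in, not the actual Cohere/Gemma/Llama attention layer:

```python
import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, causal_mask=None, dropout_p=0.0):
    # When `causal_mask` has been dropped (None), let SDPA apply causality itself
    # via `is_causal`, which keeps the Flash Attention backend eligible. For
    # q_len == 1 (single-token decoding), causal and bi-directional attention
    # coincide, so is_causal=False is correct without any mask.
    q_len = query.shape[-2]
    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout_p,
        is_causal=causal_mask is None and q_len > 1,
    )

q = torch.randn(1, 8, 5, 64)    # prefill: 5 query tokens
k = v = torch.randn(1, 8, 5, 64)
out_prefill = sdpa_attention(q, k, v)    # is_causal=True

q1 = torch.randn(1, 8, 1, 64)   # decode: one query token over 6 cached key/value positions
k6 = v6 = torch.randn(1, 8, 6, 64)
out_decode = sdpa_attention(q1, k6, v6)  # is_causal=False, no mask needed
```

Because `q_len > 1` is a Python-level condition, the mask is only dropped outside of tracing/compiling, as noted in the comments of the diff.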
@@ -908,9 +911,7 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds
@@ -974,24 +975,41 @@ def forward(
attentions=all_self_attns,
)

# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_seen_tokens: int,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

if self.config._attn_implementation == "sdpa":
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
# in order to dispatch on Flash Attention 2.
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
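The dense mask construction continues past the truncated hunk above; `torch.full(..., fill_value=min_dtype)` is only its first step. A simplified sketch of how such an additive causal mask can be completed from `cache_position`, written as an approximation for illustration rather than the exact continuation of the model code:

```python
import torch

def build_causal_mask(sequence_length, target_length, cache_position, dtype=torch.float32, device="cpu"):
    # Simplified sketch: start fully masked (min_dtype everywhere), then unmask
    # every key position that is not in the future of the corresponding query,
    # using `cache_position` as the absolute position of each query token.
    # The real code additionally folds padding information into this tensor.
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    key_positions = torch.arange(target_length, device=device)
    allowed = key_positions[None, :] <= cache_position[:, None]   # (seq_len, target_len)
    causal_mask = causal_mask.masked_fill(allowed, 0.0)
    return causal_mask[None, None, :, :]  # broadcastable to (batch, 1, seq_len, target_len)

# 3 new query tokens arriving after 4 already-cached tokens, buffer of length 7.
cache_position = torch.arange(4, 7)
mask = build_causal_mask(3, 7, cache_position)
print((mask == 0).squeeze())  # the pattern of attended (unmasked) positions
```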
36 changes: 27 additions & 9 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -570,12 +570,15 @@ def forward(
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, relying
# instead on the `is_causal` argument.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()
@@ -888,9 +891,7 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds
@@ -960,24 +961,41 @@ def forward(
attentions=all_self_attns,
)

# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_seen_tokens: int,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

if self.config._attn_implementation == "sdpa":
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
# in order to dispatch on Flash Attention 2.
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
37 changes: 27 additions & 10 deletions src/transformers/models/llama/modeling_llama.py
@@ -656,7 +656,6 @@ def forward(
value_states = repeat_kv(value_states, self.num_key_value_groups)

causal_mask = attention_mask
# if attention_mask is not None and cache_position is not None:
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

@@ -667,12 +666,15 @@
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, relying
# instead on the `is_causal` argument.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()
@@ -987,9 +989,7 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds
@@ -1053,24 +1053,41 @@ def forward(
attentions=all_self_attns,
)

# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_seen_tokens: int,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

if self.config._attn_implementation == "sdpa":
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
# in order to dispatch on Flash Attention 2.
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
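For completeness, the `flash_attention_2` branch at the top of each `_update_causal_mask` above can be read in isolation; this is a standalone restatement for clarity, not a call into the model classes:

```python
import torch

def fa2_attention_mask(attention_mask):
    # The 2D mask is only forwarded when it actually contains padding (a 0 entry);
    # otherwise None is returned so the flash-attention kernel runs without a mask.
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask
    return None

print(fa2_attention_mask(torch.ones(1, 5)))                      # all ones -> None
print(fa2_attention_mask(torch.tensor([[0., 1., 1., 1., 1.]])))  # padding -> mask kept
```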
