From 05bdef16b611df0946a6a602503f1ace604b6c80 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 17 Apr 2024 22:21:00 +0200 Subject: [PATCH] Re-enable SDPA's FA2 path (#30070) * tentatively re-enable FA2 + SDPA * better comment * _ignore_causal_mask_sdpa as staticmethod * type hints * use past_seen_tokens instead * enable copied from for sdpa * ruff * llama simplifications on review * remove unnecessary self.is_causal check * fix copies * cleaning * precise message * better doc * add test * simplify * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * style --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_attn_mask_utils.py | 91 ++++++++++++------- .../models/cohere/modeling_cohere.py | 36 ++++++-- .../models/gemma/modeling_gemma.py | 36 ++++++-- .../models/llama/modeling_llama.py | 37 ++++++-- tests/test_modeling_common.py | 36 ++++++++ 5 files changed, 176 insertions(+), 60 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 43da8917b23075..8ae9b57b6c43be 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -234,6 +234,59 @@ def _unmask_unattended( return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) + @staticmethod + def _ignore_causal_mask_sdpa( + attention_mask: Optional[torch.Tensor], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, + ) -> bool: + """ + Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. + + In case no token is masked in the `attention_mask` argument, if `query_length == 1` or + `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). + """ + + batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] + key_value_length = query_length + past_key_values_length + + is_tracing = ( + torch.jit.is_tracing() + or isinstance(inputs_embeds, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + + ignore_causal_mask = False + + if attention_mask is None: + # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or + # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). + # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag. + # + # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`). + if sliding_window is None or key_value_length < sliding_window: + ignore_causal_mask = not is_tracing + elif sliding_window is None or key_value_length < sliding_window: + if len(attention_mask.shape) == 4: + expected_shape = (batch_size, 1, query_length, key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + elif not is_tracing and torch.all(attention_mask == 1): + if query_length == 1 or key_value_length == query_length: + # For query_length == 1, causal attention and bi-directional attention are the same. + ignore_causal_mask = True + + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. + + return ignore_causal_mask + def _prepare_4d_causal_attention_mask( attention_mask: Optional[torch.Tensor], @@ -305,7 +358,6 @@ def _prepare_4d_causal_attention_mask_for_sdpa( attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) key_value_length = input_shape[-1] + past_key_values_length - _, query_length = input_shape # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. @@ -316,37 +368,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa( or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) ) - ignore_causal_mask = False - - if attention_mask is None: - if sliding_window is None or key_value_length < sliding_window: - ignore_causal_mask = not is_tracing - elif sliding_window is None or key_value_length < sliding_window: - # 4d mask is passed through - if len(attention_mask.shape) == 4: - expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) - if tuple(attention_mask.shape) != expected_shape: - raise ValueError( - f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." - ) - else: - # if the 4D mask has correct shape - invert it and fill with negative infinity - inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype) - attention_mask = inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min - ) - return attention_mask - - elif not is_tracing and torch.all(attention_mask == 1): - if query_length == 1: - # For query_length == 1, causal attention and bi-directional attention are the same. - ignore_causal_mask = True - elif key_value_length == query_length: - ignore_causal_mask = True - - # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation - # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. - # Reference: https://github.com/pytorch/pytorch/issues/108108 + ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + sliding_window=sliding_window, + ) if ignore_causal_mask: expanded_4d_mask = None diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 95a7d768273eeb..950d45ea867a30 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -590,12 +590,15 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather + # relying on the `is_causal` argument. attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -908,9 +911,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] - ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) # embed positions hidden_states = inputs_embeds @@ -974,16 +975,31 @@ def forward( attentions=all_self_attns, ) - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_seen_tokens: int, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "sdpa": + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, + # in order to dispatch on Flash Attention 2. + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -991,7 +1007,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, curr target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index c8b9b11c557972..6077259d0b0fac 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -570,12 +570,15 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather + # relying on the `is_causal` argument. attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -888,9 +891,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] - ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) # embed positions hidden_states = inputs_embeds @@ -960,16 +961,31 @@ def forward( attentions=all_self_attns, ) - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_seen_tokens: int, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "sdpa": + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, + # in order to dispatch on Flash Attention 2. + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -977,7 +993,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, curr target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e1afb61be0dfc6..2b8e8f6d0958dd 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -656,7 +656,6 @@ def forward( value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask - # if attention_mask is not None and cache_position is not None: if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] @@ -667,12 +666,15 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather + # relying on the `is_causal` argument. attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -987,9 +989,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] - ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) # embed positions hidden_states = inputs_embeds @@ -1053,16 +1053,31 @@ def forward( attentions=all_self_attns, ) - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_seen_tokens: int, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "sdpa": + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, + # in order to dispatch on Flash Attention 2. + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -1070,7 +1085,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, curr target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index fd23a3f5ee9ffa..5e6be8c04044e9 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3772,6 +3772,42 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + @require_torch_sdpa + @require_torch_gpu + @slow + def test_sdpa_can_dispatch_on_flash(self): + compute_capability = torch.cuda.get_device_capability() + major, _ = compute_capability + + if not torch.version.cuda or major < 8: + self.skipTest("This test requires an NVIDIA GPU with compute capability >= 8.0") + + for model_class in self.all_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if config.model_type in ["llava", "llava_next", "vipllava"]: + self.skipTest("Llava-like models currently (transformers==4.39.1) requires an attention_mask input") + if config.model_type in ["idefics"]: + self.skipTest("Idefics currently (transformers==4.39.1) requires an image_attention_mask input") + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa") + model.to(torch_device) + + inputs_dict.pop("attention_mask", None) + inputs_dict.pop("decoder_attention_mask", None) + + for name, inp in inputs_dict.items(): + if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]: + inputs_dict[name] = inp.to(torch.float16) + + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + _ = model(**inputs_dict) + @require_torch_sdpa @slow def test_eager_matches_sdpa_generate(self):