Flash-Attn: fix generation when no attention mask or no padding (#32241)
* fix

* fix prev test (half of failures)

* [run-slow] llama, gemma2

* [run-slow] llama, gemma2
zucchini-nlp authored Jul 26, 2024
1 parent 27c7f97 commit 81233c0
Showing 2 changed files with 18 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/transformers/modeling_flash_attention_utils.py
@@ -264,9 +264,11 @@ def _flash_attention_forward(
         )
         attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
-    # if position_ids is provided and check not all examples (row) contain only 1 sequence,
+    # if position_ids is provided and check not all examples (row) contain only 1 sequence, and is in pre-fill/training stage
     # then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all():
+    elif (
+        position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1
+    ):
         batch_size = query_states.size(0)
         query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
             query_states, key_states, value_states, position_ids
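A minimal standalone sketch of how the updated guard behaves (not part of the commit; the helper name `takes_padding_free_path` is invented for illustration):

```python
import torch

def takes_padding_free_path(position_ids: torch.Tensor, query_length: int) -> bool:
    # Mirrors the new condition in `_flash_attention_forward`: a row is a single,
    # unpadded sequence iff its last position id equals seq_len - 1.
    return (
        position_ids is not None
        and not (position_ids[:, -1] == position_ids.size(1) - 1).all()
        and query_length != 1
    )

plain_prefill = torch.arange(6).unsqueeze(0)         # [[0, 1, 2, 3, 4, 5]]: one sequence, no packing
packed_prefill = torch.tensor([[0, 1, 2, 0, 1, 2]])  # two sequences packed into one row
decode_step = torch.tensor([[6]])                    # a single new token during cached generation

print(takes_padding_free_path(plain_prefill, query_length=6))   # False -> regular flash_attn_func call
print(takes_padding_free_path(packed_prefill, query_length=6))  # True  -> flash_attn_varlen_func, no cross-example attention
print(takes_padding_free_path(decode_step, query_length=1))     # False -> the new query_length != 1 check keeps decoding on the regular path
```

During cached decoding the last (and only) position id is the absolute position of the new token, which does not equal `position_ids.size(1) - 1 == 0`, so without the added `query_length != 1` check such steps were misrouted onto the varlen path whenever no attention mask was passed.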
15 changes: 14 additions & 1 deletion tests/test_modeling_common.py
@@ -4270,6 +4270,18 @@ def test_flash_attn_2_generate_use_cache(self):
                     use_cache=True,
                 )
 
+                # Generate with one batch only to test generation when attention mask will be None
+                # when real inputs are used, because there is no padding. See issue #32237 for more
+                dummy_input = dummy_input[:1, ...]
+                dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
+                _ = model.generate(
+                    dummy_input,
+                    attention_mask=dummy_attention_mask,
+                    max_new_tokens=max_new_tokens,
+                    do_sample=False,
+                    use_cache=True,
+                )
+
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
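The regression this added generation call guards against can be reproduced end to end roughly as follows; a hedged sketch, not part of the commit, assuming a CUDA device with flash-attn installed and an arbitrary FA2-capable model id as a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # assumption: any causal LM that supports Flash Attention 2
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="cuda",
)

# A single prompt means no padding: the 2D attention mask is all ones, and the model
# ends up calling flash attention with attention_mask=None (see issue #32237).
inputs = tokenizer("The theory of relativity states that", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, use_cache=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

The batch-of-one `model.generate` call added to the test covers the same situation in miniature.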
@@ -4342,6 +4354,8 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
                 self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
 
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
+                self.skipTest("Model dummy inputs should contain padding in their attention mask")
 
             dummy_input = inputs_dict[model_class.main_input_name]
             if dummy_input.dtype in [torch.float32, torch.bfloat16]:
@@ -4356,7 +4370,6 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
 
-                assert 0 in inputs_dict["attention_mask"], "assert padding in testing inputs"
                 # ensure left padding, to adapt for some models
                 if 0 in inputs_dict["attention_mask"][:, -1]:
                     inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
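The new skip guard relies on Python's `in` performing a membership check on tensors, so `0 in attention_mask` is True exactly when the batch contains padding. A small illustration with invented values:

```python
import torch

padded_batch = torch.tensor([[0, 0, 1, 1, 1],
                             [1, 1, 1, 1, 1]])       # first row is left-padded
unpadded_batch = torch.ones(2, 5, dtype=torch.long)  # no padding anywhere

print(0 in padded_batch)    # True  -> the padding vs padding-free comparison can run
print(0 in unpadded_batch)  # False -> the test is skipped up front
```

Skipping up front replaces the inline `assert 0 in inputs_dict["attention_mask"]` removed just above, which raised an error instead of skipping when a model's dummy inputs carried no padding.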