
Commit 7892257

FA2 can continue generation from cache (#39843)

* add fa2 support to continue generation from cache
* update q-len
1 parent 9bfbdd2 commit 7892257
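
What this change enables, as a quick illustration before the diffs: with `flash_attention_2` and no attention mask, a model can be prefilled once and then continue decoding from the returned cache by passing position ids that start at the cached length. The sketch below is not part of the commit; it mirrors the new test added in tests/test_modeling_common.py, the checkpoint name is a placeholder, and it assumes flash-attn, a CUDA device, and bfloat16 support.

# Hypothetical usage sketch (placeholder checkpoint, not from the commit).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # placeholder; any FA2-capable decoder-only model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
).to("cuda").eval()

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids.to("cuda")
position_ids = torch.arange(input_ids.shape[1], device="cuda").unsqueeze(0)

# Prefill: run the prompt once and keep the cache. No attention mask is passed,
# so FA2 has to infer the varlen layout from position ids alone.
with torch.no_grad():
    out = model(input_ids, position_ids=position_ids, use_cache=True)

# Decode: feed only the newly chosen token, with position ids offset by the cached
# length. This is the path addressed here: the position ids no longer start at 0.
next_token = out.logits[:, -1, :].argmax(-1, keepdim=True)
past_len = input_ids.shape[1]
new_position_ids = torch.arange(past_len, past_len + 1, device="cuda").unsqueeze(0)
with torch.no_grad():
    out = model(
        input_ids=next_token,
        past_key_values=out.past_key_values,
        position_ids=new_position_ids,
        use_cache=True,
    )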

File tree: 3 files changed, +132 -40 lines

src/transformers/generation/utils.py

Lines changed: 0 additions & 18 deletions
@@ -677,24 +677,6 @@ def prepare_inputs_for_generation(
         if encoder_attention_mask is not None:
             model_inputs["attention_mask"] = encoder_attention_mask

-        if "flash" in self.config._attn_implementation and self._supports_attention_backend:
-            tensor_kws = {"dtype": torch.int32, "device": self.device}
-            pos = model_inputs["position_ids"][:, -1]
-
-            cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
-            max_length_k = int(pos.max()) + 1
-
-            bs, seq_len = input_ids.size()
-            q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
-            cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
-            max_length_q = int(q_len.max())
-
-            model_inputs.update(
-                cu_seq_lens_q=cu_seq_lens_q.to(self.device),
-                cu_seq_lens_k=cu_seq_lens_k.to(self.device),
-                max_length_q=max_length_q,
-                max_length_k=max_length_k,
-            )
         # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
         for key, value in kwargs.items():
             if key not in model_inputs:
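
This removes the generation-time computation of `cu_seq_lens_q` / `cu_seq_lens_k`; the same bookkeeping now happens inside `_prepare_from_posids` (next file). For intuition, a small worked example of that math, with made-up numbers and not taken from the commit:

# Toy example of the cu_seq_lens bookkeeping during cached decoding (illustrative values).
import torch

tensor_kws = {"dtype": torch.int32}
last_pos = torch.tensor([6, 9], **tensor_kws)   # newest position id of each of 2 sequences

# Keys/values: cache length + 1 new token per sequence -> lengths 7 and 10
cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), last_pos.add(1).cumsum(0).to(torch.int32)], 0)
max_length_k = int(last_pos.max()) + 1          # 10

# Queries: a single new token per sequence during decoding
q_len = torch.ones(2, **tensor_kws)
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0).to(torch.int32)], 0)
max_length_q = int(q_len.max())                 # 1

print(cu_seq_lens_k.tolist())  # [0, 7, 17] -> where each sequence's keys start in the flattened batch
print(cu_seq_lens_q.tolist())  # [0, 1, 2]  -> one query token per sequence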

src/transformers/modeling_flash_attention_utils.py

Lines changed: 45 additions & 22 deletions
@@ -190,7 +190,7 @@ def _upad_input(
     )


-def _prepare_from_posids(query, key, value, position_ids):
+def _prepare_from_posids(query, key, value, position_ids, query_length):
     """
     This function returns necessary arguments to call `flash_attn_varlen_func`.
     All three query, key, value states will be flattened.
@@ -205,43 +205,66 @@ def _prepare_from_posids(query, key, value, position_ids):
             Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
         position_ids (`torch.Tensor`):
             Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+        query_length (`int`):
+            Sequence length of the input queries.
     Return:
         query (`torch.Tensor`):
             Query state without padding. Shape: (total_target_length, num_heads, head_dim).
         key (`torch.Tensor`):
             Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
         value (`torch.Tensor`):
             Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
-        indices_q (`torch.Tensor`):
-            The indices of non-masked tokens from the flattened input target sequence.
         (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
             The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
         (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
             Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
     """
+    kv_length = key.shape[1]
     query = query.contiguous().view(-1, query.size(-2), query.size(-1))
     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
     value = value.contiguous().view(-1, value.size(-2), value.size(-1))

-    position_ids = position_ids.flatten()
-    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
+    # If the lengths are not equal, most probably we are in decoding stage with cache
+    # In that case the position ids will not always start with `0` and we need a better way to infer
+    # cumulative seq lengths.
+    if query_length != kv_length:
+        indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

-    cu_seq_lens = torch.cat(
-        (
-            indices_q[position_ids == 0],
-            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
+        tensor_kws = {"dtype": torch.int32, "device": position_ids.device}
+        last_position_ids = position_ids[:, -1]
+
+        cu_seq_lens_k = torch.cat(
+            [torch.zeros(1, **tensor_kws), last_position_ids.cumsum(0).add(1).to(torch.int32)], 0
         )
-    )
-    # NOTE: With torch compile, this will cause a graph break if you don't set
-    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
-    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
-    # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
-    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
-    # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
-    # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
-    # for some models (e.g. qwen2-vl).
-    max_length = cu_seq_lens.diff().max().item()
-    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
+        max_length_k = int(last_position_ids.max()) + 1
+
+        batch_size, seq_len = query.shape[:2]
+        q_len = torch.ones(batch_size, **tensor_kws) if query_length == 1 else last_position_ids.add(1)
+        cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0).to(torch.int32)], 0)
+        max_length_q = int(q_len.max())
+    else:
+        position_ids = position_ids.flatten()
+        indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
+
+        cu_seq_lens_q = torch.cat(
+            (
+                indices_q[position_ids == 0],
+                torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
+            )
+        )
+        cu_seq_lens_k = cu_seq_lens_q
+
+        # NOTE: With torch compile, this will cause a graph break if you don't set
+        # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
+        # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
+        # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
+        # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
+        # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
+        # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
+        # for some models (e.g. qwen2-vl).
+        max_length_q = cu_seq_lens_q.diff().max().item()
+        max_length_k = max_length_q
+    return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))


 def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
@@ -430,8 +453,8 @@ def _flash_attention_forward(
             raise ValueError(
                 "Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed."
             )
-        q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
-            query_states, key_states, value_states, position_ids
+        q, k, v, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
+            query_states, key_states, value_states, position_ids, query_length=query_length
         )
     else:
         q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
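
The tuple returned by the reworked `_prepare_from_posids` feeds the varlen kernel referenced in the NOTE above. The wrapper below is an illustrative sketch, not transformers code: it assumes the flash-attn package is installed and calls its `flash_attn_varlen_func` with the cumulative lengths and max lengths produced by the helper.

# Illustrative wrapper (assumption); `_prepare_from_posids` is a private transformers helper.
from flash_attn import flash_attn_varlen_func  # requires the flash-attn package
from transformers.modeling_flash_attention_utils import _prepare_from_posids

def varlen_attention(query, key, value, position_ids, query_length):
    # query/key/value: (batch, seq_len, heads, head_dim); position_ids: (batch, seq_len)
    q, k, v, (cu_q, cu_k), (max_q, max_k) = _prepare_from_posids(
        query, key, value, position_ids, query_length
    )
    # cu_q / cu_k are (batch_size + 1,) int32 offsets into the flattened tokens;
    # max_q / max_k are plain Python ints, as flash_attn_varlen_func requires.
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_q,
        cu_seqlens_k=cu_k,
        max_seqlen_q=max_q,
        max_seqlen_k=max_k,
        causal=True,
    )  # -> (total_query_tokens, heads, head_dim)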

tests/test_modeling_common.py

Lines changed: 87 additions & 0 deletions
@@ -4280,6 +4280,93 @@ def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa
             attn_implementation="flash_attention_3", fa_kwargs=True
         )

+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    def test_flash_attention_2_continue_generate_with_position_ids(self):
+        """
+        Tests that the given attention implementation can work with packed sequences and infers the mask
+        from position ids. This test requires the model to use new attention mask API which handles packing.
+        """
+
+        max_new_tokens = 2
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_flash_attn:
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention.")
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            if config.is_encoder_decoder:
+                self.skipTest("Model is an encoder-decoder")
+
+            if not hasattr(config.get_text_config(), "use_cache"):
+                self.skipTest(f"{model_class.__name__} doesn't support caching")
+
+            if "input_ids" not in inputs_dict or inputs_dict["input_ids"].ndim != 2:
+                self.skipTest("Model dummy inputs should contain text input ids")
+
+            # make sure that all models have enough positions for generation
+            dummy_input_ids = inputs_dict["input_ids"]
+            if hasattr(config, "max_position_embeddings"):
+                config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1
+
+            model = model_class(config)
+            if "position_ids" not in inspect.signature(model.forward).parameters:
+                self.skipTest("Model does not support position_ids")
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model = (
+                    model_class.from_pretrained(
+                        tmpdirname,
+                        torch_dtype=torch.bfloat16,
+                        attn_implementation="flash_attention_2",
+                    )
+                    .to(torch_device)
+                    .eval()
+                )
+
+                # Drop all keys except for `input_ids`. Hard to manipulate with multimodals/head_mask/etc
+                dummy_input_ids = inputs_dict["input_ids"]
+                dummy_position_ids = torch.arange(dummy_input_ids.shape[1], device=torch_device)
+                dummy_position_ids = dummy_position_ids.unsqueeze(0).repeat(dummy_input_ids.shape[0], 1)
+
+                # Store cache for the input prompt
+                output = model(dummy_input_ids, position_ids=dummy_position_ids, use_cache=True)
+                if "past_key_values" not in output:
+                    self.skipTest("This model doesn't return `past_key_values`")
+
+                # create new input_ids and position_ids to continue generation re-using the cache
+                new_input_ids = output.logits[:, -1, :].float().argmax(-1)[:, None]
+                past_length = dummy_input_ids.shape[1]
+                position_ids = torch.arange(past_length, past_length + new_input_ids.shape[1], device=torch_device)
+                position_ids = position_ids.unsqueeze(0).repeat(new_input_ids.shape[0], 1)
+
+                output = model(
+                    input_ids=new_input_ids,
+                    past_key_values=output.past_key_values,
+                    position_ids=position_ids,
+                    use_cache=True,
+                )
+                next_token_logits = output.logits[:, -1, :].float()
+
+                generate_kwargs = {
+                    "pad_token_id": -1,
+                    "eos_token_id": -1,
+                    "forced_eos_token_id": None,
+                    "use_cache": True,
+                    "do_sample": False,
+                    "return_dict_in_generate": True,
+                    "output_logits": True,
+                    "max_new_tokens": max_new_tokens,
+                }
+                generation_out = model.generate(dummy_input_ids, **generate_kwargs)
+                next_token_logits_from_generate = generation_out.logits[-1]
+
+                # acceptable numerical instability
+                # print(next_token_logits_from_generate, next_token_logits)
+                tol = torch.finfo(torch.bfloat16).eps
+                torch.testing.assert_close(next_token_logits_from_generate, next_token_logits, rtol=tol, atol=tol)
+
     def flash_attn_from_config(self, attn_implementation: str):
         r"""
         Tests if the model can be loaded with `attn_implementation` from the config and if the
