@@ -190,7 +190,7 @@ def _upad_input(
    )


-def _prepare_from_posids(query, key, value, position_ids):
+def _prepare_from_posids(query, key, value, position_ids, query_length):
    """
    This function returns necessary arguments to call `flash_attn_varlen_func`.
    All three query, key, value states will be flattened.
@@ -205,43 +205,66 @@ def _prepare_from_posids(query, key, value, position_ids):
        Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        position_ids (`torch.Tensor`):
            Int tensor of shape (batch_size, sequence_length) holding the position index of each token; positions restart from 0 at every packed sub-sequence boundary.
+        query_length (`int`):
+            Sequence length of the input queries.
    Return:
        query (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
-        indices_q (`torch.Tensor`):
-            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
+    kv_length = key.shape[1]
    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
    value = value.contiguous().view(-1, value.size(-2), value.size(-1))

-    position_ids = position_ids.flatten()
-    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
+    # If the lengths are not equal, we are most probably in the decoding stage with a cache.
+    # In that case the position ids will not always start at `0`, and we need a better way to
+    # infer the cumulative sequence lengths.
+    if query_length != kv_length:
+        indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

-    cu_seq_lens = torch.cat(
-        (
-            indices_q[position_ids == 0],
-            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
+        tensor_kws = {"dtype": torch.int32, "device": position_ids.device}
+        last_position_ids = position_ids[:, -1]
+
+        cu_seq_lens_k = torch.cat(
+            [torch.zeros(1, **tensor_kws), last_position_ids.cumsum(0).add(1).to(torch.int32)], 0
        )
-    )
-    # NOTE: With torch compile, this will cause a graph break if you don't set
-    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
-    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
-    # This is a limitation of the flash attention API, as the function `flash_attn_varlen_func`
-    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
-    # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
-    # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
-    # for some models (e.g. qwen2-vl).
-    max_length = cu_seq_lens.diff().max().item()
-    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
+        max_length_k = int(last_position_ids.max()) + 1
+
+        batch_size, seq_len = query.shape[:2]
+        q_len = torch.ones(batch_size, **tensor_kws) if query_length == 1 else last_position_ids.add(1)
+        cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0).to(torch.int32)], 0)
+        max_length_q = int(q_len.max())
+    else:
+        position_ids = position_ids.flatten()
+        indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
+
+        cu_seq_lens_q = torch.cat(
+            (
+                indices_q[position_ids == 0],
+                torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
+            )
+        )
+        cu_seq_lens_k = cu_seq_lens_q
+
+        # NOTE: With torch compile, this will cause a graph break if you don't set
+        # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
+        # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
+        # This is a limitation of the flash attention API, as the function `flash_attn_varlen_func`
+        # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
+        # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
+        # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
+        # for some models (e.g. qwen2-vl).
+        max_length_q = cu_seq_lens_q.diff().max().item()
+        max_length_k = max_length_q
+    return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))


def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
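For quick orientation, here is a minimal, hypothetical smoke test of the updated helper on CPU. It assumes the `_prepare_from_posids` defined above is in scope (the `transformers.modeling_flash_attention_utils` import path below is an assumption about where the patched file lives), and it exercises only the packed-sequence branch, where `query_length` equals the key length; the tensor shapes are made up for illustration.

```python
import torch
# from transformers.modeling_flash_attention_utils import _prepare_from_posids  # assumed import path

# Hypothetical shapes: one row packing two sequences of lengths 3 and 2.
batch, seq_len, n_heads, n_kv_heads, head_dim = 1, 5, 4, 2, 8
query = torch.randn(batch, seq_len, n_heads, head_dim)
key = torch.randn(batch, seq_len, n_kv_heads, head_dim)
value = torch.randn(batch, seq_len, n_kv_heads, head_dim)
position_ids = torch.tensor([[0, 1, 2, 0, 1]])  # positions restart at 0 at the packed boundary

# New signature: query_length is now required; indices_q is no longer returned.
q, k, v, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
    query, key, value, position_ids, query_length=seq_len
)
print(q.shape)  # torch.Size([5, 4, 8]) -- flattened over batch * seq_len
print(cu_q)     # tensor([0, 3, 5], dtype=torch.int32); cu_k is identical in this branch
print(mq, mk)   # 3 3 -- from cu_seq_lens_q.diff().max().item()
```

The returned tensors and scalars map directly onto the `q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k` arguments of `flash_attn_varlen_func`; when `query_length` differs from the key length, the new decoding branch above derives the cumulative lengths from the last position ids instead of the `position_ids == 0` boundaries.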
@@ -430,8 +453,8 @@ def _flash_attention_forward(
            raise ValueError(
                "Position ids should be passed if the attention mask is not passed and the cu_seq_lens are not passed."
            )
-        q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
-            query_states, key_states, value_states, position_ids
+        q, k, v, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
+            query_states, key_states, value_states, position_ids, query_length=query_length
        )
    else:
        q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
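For users compiling models that route through this call site, the NOTE in the hunk above still applies, since `max_length_q`/`max_length_k` are read out as Python ints via `.item()`. A minimal way to avoid the resulting graph break, taken directly from that comment:

```python
import torch

# Allow torch.compile to trace the `.item()` call used for max_length_q / max_length_k
# without a graph break (alternatively: export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1).
torch._dynamo.config.capture_scalar_outputs = True
```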