Commit ce42aa7

Update modeling_flash_attention_utils.py
1 parent d874b2d commit ce42aa7

File tree

1 file changed (+1, -1)

src/transformers/modeling_flash_attention_utils.py

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ def _upad_input(
     indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

     # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage
-    # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions
+    # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores
     if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
         key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]
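For context, here is a minimal standalone sketch (not code from the repository) of the situation the patched comment describes: with a static cache, key/value states are pre-allocated to the cache's maximum length, so their sequence dimension can exceed the attention mask's length, and the guard slices them back down before unpadding. The tensor shapes, batch size, and head counts below are illustrative assumptions.

import torch

# Hypothetical sizes: a static cache pre-allocated to 32 positions,
# while the batch only contains 10 real tokens.
batch, max_cache_len, num_heads, head_dim = 2, 32, 4, 8
real_seq_len = 10

# Stand-ins for key_layer / value_layer with shape (batch, cache_len, heads, head_dim);
# everything past real_seq_len is "garbage" the cache has not been written with yet.
key_layer = torch.randn(batch, max_cache_len, num_heads, head_dim)
value_layer = torch.randn(batch, max_cache_len, num_heads, head_dim)
attention_mask = torch.ones(batch, real_seq_len, dtype=torch.long)

# Same guard as in the diff: when the pre-allocated k/v states are longer than
# the mask, slice them to the mask length so attention is only computed over
# valid positions.
if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
    key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]

print(key_layer.shape)  # torch.Size([2, 10, 4, 8])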
