Commit ccb2e0e

Fix GPT2 with cross attention (#39754)
* fix
* use new mask API
* style
* fix copies and attention tests
* fix head pruning tests
1 parent dfd616e commit ccb2e0e
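For context, this is the code path being fixed: GPT-2 acting as a cross-attending decoder inside an encoder-decoder model. A minimal sketch of that setup (the checkpoint names, token-id choices, and generation arguments are illustrative, not taken from the commit):

```python
# Illustrative only: GPT-2 as the decoder of an encoder-decoder model, which routes every
# generation step through the GPT-2 cross-attention layers changed below.
from transformers import AutoTokenizer, EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")
enc_tok = AutoTokenizer.from_pretrained("bert-base-uncased")
dec_tok = AutoTokenizer.from_pretrained("gpt2")

# Assumption: reuse GPT-2's BOS/EOS ids as decoder start and pad tokens for generation.
model.config.decoder_start_token_id = model.config.decoder.bos_token_id
model.config.pad_token_id = model.config.decoder.eos_token_id

inputs = enc_tok("An example input sentence.", return_tensors="pt")
# Each decoding step runs GPT2Attention with encoder_hidden_states and a growing
# EncoderDecoderCache; that is the key/value re-use logic patched in this commit.
generated_ids = model.generate(
    input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=16
)
print(dec_tok.batch_decode(generated_ids, skip_special_tokens=True))
```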

File tree: 4 files changed (+89, -178 lines)


src/transformers/models/decision_transformer/modeling_decision_transformer.py

Lines changed: 35 additions & 26 deletions
@@ -268,53 +268,62 @@ def forward(
         **kwargs,
     ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
         is_cross_attention = encoder_hidden_states is not None
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
+                    curr_past_key_value = past_key_value.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value
+
         if is_cross_attention:
             if not hasattr(self, "q_attn"):
                 raise ValueError(
                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
                     "Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
                 )
-
             query_states = self.q_attn(hidden_states)
-            key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
             attention_mask = encoder_attention_mask
+
+            # Try to get key/value states from cache if possible
+            if past_key_value is not None and is_updated:
+                key_states = curr_past_key_value.layers[self.layer_idx].keys
+                value_states = curr_past_key_value.layers[self.layer_idx].values
+            else:
+                key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+                shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+                key_states = key_states.view(shape_kv).transpose(1, 2)
+                value_states = value_states.view(shape_kv).transpose(1, 2)
         else:
             query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+            shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+            key_states = key_states.view(shape_kv).transpose(1, 2)
+            value_states = value_states.view(shape_kv).transpose(1, 2)
 
         shape_q = (*query_states.shape[:-1], -1, self.head_dim)
-        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
-
         query_states = query_states.view(shape_q).transpose(1, 2)
-        key_states = key_states.view(shape_kv).transpose(1, 2)
-        value_states = value_states.view(shape_kv).transpose(1, 2)
 
-        if past_key_value is not None:
-            if isinstance(past_key_value, EncoderDecoderCache):
-                if is_cross_attention:
-                    past_key_value = past_key_value.cross_attention_cache
-                else:
-                    past_key_value = past_key_value.self_attention_cache
-            cache_kwargs = {"cache_position": cache_position}
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
+        if (past_key_value is not None and not is_cross_attention) or (
+            past_key_value is not None and is_cross_attention and not is_updated
+        ):
+            # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+            cache_position = cache_position if not is_cross_attention else None
+            key_states, value_states = curr_past_key_value.update(
+                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
             )
+            # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+            if is_cross_attention:
+                past_key_value.is_updated[self.layer_idx] = True
 
         is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
 
         using_eager = self.config._attn_implementation == "eager"
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
-                using_eager = True
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
-                # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
-                # not necessarily to eager (if mentioned options are provided).
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         if using_eager and self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(
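The rewritten control flow relies on `EncoderDecoderCache` holding a self-attention cache and a cross-attention cache side by side, plus a per-layer `is_updated` flag. A condensed sketch of the intended re-use pattern, assuming only the `cache_utils` names that appear in the diff above:

```python
# Sketch, not code from the commit: how the per-layer cross-attention re-use flag behaves
# with the cache classes referenced in the diff.
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
layer_idx = 0

# First decoding step: nothing is cached for this layer yet, so the attention layer projects
# encoder_hidden_states through c_attn, writes the result via cross_attention_cache.update(...),
# and then flips the flag (mirrored here by hand).
assert not past_key_values.is_updated.get(layer_idx)
past_key_values.is_updated[layer_idx] = True

# Later steps skip the projection and read the stored tensors directly, e.g.
#   keys   = past_key_values.cross_attention_cache.layers[layer_idx].keys
#   values = past_key_values.cross_attention_cache.layers[layer_idx].values
# while self-attention keys/values keep being appended every step with a real cache_position.
```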

src/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 46 additions & 150 deletions
@@ -27,9 +27,10 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN, get_activation
-from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,

@@ -278,53 +279,62 @@ def forward(
         **kwargs,
     ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
         is_cross_attention = encoder_hidden_states is not None
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
+                    curr_past_key_value = past_key_value.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value
+
         if is_cross_attention:
             if not hasattr(self, "q_attn"):
                 raise ValueError(
                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
                     "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                 )
-
             query_states = self.q_attn(hidden_states)
-            key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
             attention_mask = encoder_attention_mask
+
+            # Try to get key/value states from cache if possible
+            if past_key_value is not None and is_updated:
+                key_states = curr_past_key_value.layers[self.layer_idx].keys
+                value_states = curr_past_key_value.layers[self.layer_idx].values
+            else:
+                key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+                shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+                key_states = key_states.view(shape_kv).transpose(1, 2)
+                value_states = value_states.view(shape_kv).transpose(1, 2)
         else:
             query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+            shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+            key_states = key_states.view(shape_kv).transpose(1, 2)
+            value_states = value_states.view(shape_kv).transpose(1, 2)
 
         shape_q = (*query_states.shape[:-1], -1, self.head_dim)
-        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
-
         query_states = query_states.view(shape_q).transpose(1, 2)
-        key_states = key_states.view(shape_kv).transpose(1, 2)
-        value_states = value_states.view(shape_kv).transpose(1, 2)
 
-        if past_key_value is not None:
-            if isinstance(past_key_value, EncoderDecoderCache):
-                if is_cross_attention:
-                    past_key_value = past_key_value.cross_attention_cache
-                else:
-                    past_key_value = past_key_value.self_attention_cache
-            cache_kwargs = {"cache_position": cache_position}
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
+        if (past_key_value is not None and not is_cross_attention) or (
+            past_key_value is not None and is_cross_attention and not is_updated
+        ):
+            # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+            cache_position = cache_position if not is_cross_attention else None
+            key_states, value_states = curr_past_key_value.update(
+                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
             )
+            # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+            if is_cross_attention:
+                past_key_value.is_updated[self.layer_idx] = True
 
         is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
 
         using_eager = self.config._attn_implementation == "eager"
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
-                using_eager = True
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
-                # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
-                # not necessarily to eager (if mentioned options are provided).
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         if using_eager and self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(
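The combined condition above decides when the cache is written at all; restated as a tiny hypothetical helper (not code from the commit) to make the two cases explicit:

```python
def should_write_cache(has_cache: bool, is_cross_attention: bool, is_updated: bool) -> bool:
    """Mirror of the diff's condition: self-attention appends to the cache on every step,
    cross-attention writes only once and is then re-used from the cache."""
    if not has_cache:
        return False
    if not is_cross_attention:
        return True  # self-attention: append this step's keys/values
    return not is_updated  # cross-attention: write on the first pass only


# Cross-attention, first decoding step -> project encoder states and store them.
assert should_write_cache(has_cache=True, is_cross_attention=True, is_updated=False)
# Cross-attention, later steps -> no recomputation and no cache update.
assert not should_write_cache(has_cache=True, is_cross_attention=True, is_updated=True)
```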
@@ -861,8 +871,14 @@ def forward(
         # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
         if attention_mask is not None and attention_mask.ndim < 4:
             attention_mask = attention_mask.view(batch_size, -1)
-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
         )
 
         # If a 2D or 3D attention mask is provided for the cross-attention
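`create_causal_mask` comes from `masking_utils` (imported in the first hunk of this file) and replaces the model-local mask plumbing removed further down. A standalone sketch, assuming the keyword arguments shown above; the config values and tensor shapes are illustrative:

```python
# Sketch of the new mask API used above; shapes and config values are made up for illustration.
import torch
from transformers import GPT2Config
from transformers.masking_utils import create_causal_mask

config = GPT2Config(attn_implementation="eager")  # assumption: eager keeps an explicit mask
batch_size, seq_len = 2, 5
inputs_embeds = torch.zeros(batch_size, seq_len, config.hidden_size)
cache_position = torch.arange(seq_len)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

causal_mask = create_causal_mask(
    config=config,
    input_embeds=inputs_embeds,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=None,
    position_ids=None,
)
# Under the eager path this is expected to be a (batch, 1, query_len, kv_len) additive mask;
# sdpa/flash implementations may instead return a boolean mask or None when it can be skipped.
print(None if causal_mask is None else causal_mask.shape)
```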
@@ -903,9 +919,6 @@ def forward(
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
                 if isinstance(head_mask, torch.Tensor):
                     head_mask = head_mask.to(hidden_states.device)
             if output_hidden_states:

@@ -966,123 +979,6 @@ def forward(
             cross_attentions=all_cross_attentions,
         )
 
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_static_cache = isinstance(past_key_values, StaticCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype = input_tensor.dtype
-        sequence_length = input_tensor.shape[1]
-        if using_static_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask
-
-    @staticmethod
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        """
-        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-        Args:
-            attention_mask (`torch.Tensor`):
-                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                `(batch_size, 1, query_length, key_value_length)`.
-            sequence_length (`int`):
-                The sequence length being processed.
-            target_length (`int`):
-                The target length: when generating with static cache, the mask should be as long as the static cache,
-                to account for the 0 padding, the part of the cache that is not filled yet.
-            dtype (`torch.dtype`):
-                The dtype to use for the 4D attention mask.
-            cache_position (`torch.Tensor`):
-                Indices depicting the position of the input sequence tokens in the sequence.
-            batch_size (`torch.Tensor`):
-                Batch size.
-        """
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-
-        return causal_mask
-
 
 @auto_docstring(
     custom_intro="""
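For readers tracking the deletion, the core of the removed `_prepare_4d_causal_attention_mask_with_cache_position` can be reproduced in a few lines; a toy run with `sequence_length = target_length = 3` and no padding mask:

```python
import torch

# Toy reproduction of the deleted mask construction (no padding, no batch/head dims):
# the dtype's minimum value above the diagonal, 0 elsewhere.
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 3, 3
cache_position = torch.arange(sequence_length)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
print(causal_mask)
# [[0, min, min],
#  [0,   0, min],
#  [0,   0,   0]]   with min = torch.finfo(torch.float32).min
```

`create_causal_mask` now handles this construction (and the sdpa/flash special cases) centrally, which is why both helpers could be dropped from the model file.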

src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py

Lines changed: 2 additions & 0 deletions
@@ -449,6 +449,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
@@ -561,6 +562,7 @@ def forward(
             use_cache=use_cache,
             past_key_values=past_key_values,
             return_dict=return_dict,
+            cache_position=cache_position,
             **kwargs_decoder,
         )
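The two added lines simply thread `cache_position` from the wrapper's `forward` into the decoder call, so a GPT-2 decoder can place new keys/values at the right cache offsets during generation. A sketch of a call path that exercises it (the checkpoint name and dummy input are illustrative):

```python
# Illustrative sketch: image captioning with a ViT encoder and GPT-2 decoder. Every
# generation step goes through the VisionEncoderDecoderModel.forward shown above,
# which now forwards cache_position to the decoder.
import torch
from transformers import VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)
# Assumption: reuse GPT-2's BOS/EOS ids for decoder start and padding.
model.config.decoder_start_token_id = model.config.decoder.bos_token_id
model.config.pad_token_id = model.config.decoder.eos_token_id

pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image
generated_ids = model.generate(pixel_values=pixel_values, max_new_tokens=8)
```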
