27 | 27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28 | 28 |
29 | 29 | from ...activations import ACT2FN, get_activation
30 |    | -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
   | 30 | +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
31 | 31 | from ...generation import GenerationMixin
32 |    | -from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa
   | 32 | +from ...masking_utils import create_causal_mask
   | 33 | +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
33 | 34 | from ...modeling_layers import GradientCheckpointingLayer
34 | 35 | from ...modeling_outputs import (
35 | 36 |     BaseModelOutputWithPastAndCrossAttentions,
@@ -278,53 +279,62 @@ def forward(
278 | 279 |         **kwargs,
279 | 280 |     ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
280 | 281 |         is_cross_attention = encoder_hidden_states is not None
    | 282 | +        if past_key_value is not None:
    | 283 | +            if isinstance(past_key_value, EncoderDecoderCache):
    | 284 | +                is_updated = past_key_value.is_updated.get(self.layer_idx)
    | 285 | +                if is_cross_attention:
    | 286 | +                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
    | 287 | +                    curr_past_key_value = past_key_value.cross_attention_cache
    | 288 | +                else:
    | 289 | +                    curr_past_key_value = past_key_value.self_attention_cache
    | 290 | +            else:
    | 291 | +                curr_past_key_value = past_key_value
    | 292 | +
281 | 293 |         if is_cross_attention:
282 | 294 |             if not hasattr(self, "q_attn"):
283 | 295 |                 raise ValueError(
284 | 296 |                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
285 | 297 |                     "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
286 | 298 |                 )
287 |     | -
288 | 299 |             query_states = self.q_attn(hidden_states)
289 |     | -            key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
290 | 300 |             attention_mask = encoder_attention_mask
    | 301 | +
    | 302 | +            # Try to get key/value states from cache if possible
    | 303 | +            if past_key_value is not None and is_updated:
    | 304 | +                key_states = curr_past_key_value.layers[self.layer_idx].keys
    | 305 | +                value_states = curr_past_key_value.layers[self.layer_idx].values
    | 306 | +            else:
    | 307 | +                key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
    | 308 | +                shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
    | 309 | +                key_states = key_states.view(shape_kv).transpose(1, 2)
    | 310 | +                value_states = value_states.view(shape_kv).transpose(1, 2)
291 | 311 |         else:
292 | 312 |             query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
    | 313 | +            shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
    | 314 | +            key_states = key_states.view(shape_kv).transpose(1, 2)
    | 315 | +            value_states = value_states.view(shape_kv).transpose(1, 2)
293 | 316 |
294 | 317 |         shape_q = (*query_states.shape[:-1], -1, self.head_dim)
295 |     | -        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
296 |     | -
297 | 318 |         query_states = query_states.view(shape_q).transpose(1, 2)
298 |     | -        key_states = key_states.view(shape_kv).transpose(1, 2)
299 |     | -        value_states = value_states.view(shape_kv).transpose(1, 2)
300 | 319 |
301 |     | -        if past_key_value is not None:
302 |     | -            if isinstance(past_key_value, EncoderDecoderCache):
303 |     | -                if is_cross_attention:
304 |     | -                    past_key_value = past_key_value.cross_attention_cache
305 |     | -                else:
306 |     | -                    past_key_value = past_key_value.self_attention_cache
307 |     | -            cache_kwargs = {"cache_position": cache_position}
308 |     | -            key_states, value_states = past_key_value.update(
309 |     | -                key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
    | 320 | +        if (past_key_value is not None and not is_cross_attention) or (
    | 321 | +            past_key_value is not None and is_cross_attention and not is_updated
    | 322 | +        ):
    | 323 | +            # save all key/value_layer to cache to be re-used for fast auto-regressive generation
    | 324 | +            cache_position = cache_position if not is_cross_attention else None
    | 325 | +            key_states, value_states = curr_past_key_value.update(
    | 326 | +                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
310 | 327 |             )
    | 328 | +            # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
    | 329 | +            if is_cross_attention:
    | 330 | +                past_key_value.is_updated[self.layer_idx] = True
311 | 331 |
312 | 332 |         is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
313 | 333 |
314 | 334 |         using_eager = self.config._attn_implementation == "eager"
315 | 335 |         attention_interface: Callable = eager_attention_forward
316 | 336 |         if self.config._attn_implementation != "eager":
317 |     | -            if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
318 |     | -                using_eager = True
319 |     | -                logger.warning_once(
320 |     | -                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
321 |     | -                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
322 |     | -                )
323 |     | -            else:
324 |     | -                # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
325 |     | -                # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
326 |     | -                # not necessarily to eager (if mentioned options are provided).
327 |     | -                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
    | 337 | +            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
328 | 338 |
329 | 339 |         if using_eager and self.reorder_and_upcast_attn:
330 | 340 |             attn_output, attn_weights = self._upcast_and_reordered_attn(
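For context on the cache flow introduced above: with an `EncoderDecoderCache`, cross-attention key/value states are projected from `encoder_hidden_states` once, stored, and then re-read on later decoding steps via the per-layer `is_updated` flag. Below is a minimal sketch of that behavior outside the model; the shapes are made up for illustration, and it assumes a transformers version exposing the layered cache API used in this diff (`cache.layers[idx].keys`).

```python
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

layer_idx = 0
# Hypothetical cross-attention K/V already split into heads: (batch, heads, enc_len, head_dim)
keys = torch.randn(2, 4, 7, 16)
values = torch.randn(2, 4, 7, 16)

# An EncoderDecoderCache wraps one cache for self-attention and one for cross-attention.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# First decoding step: store the encoder-derived K/V once and mark the layer as updated.
if not cache.is_updated.get(layer_idx, False):
    cache.cross_attention_cache.update(keys, values, layer_idx)
    cache.is_updated[layer_idx] = True

# Later steps: re-use the stored K/V instead of re-projecting encoder_hidden_states,
# mirroring `curr_past_key_value.layers[self.layer_idx].keys` in the new forward.
cached_keys = cache.cross_attention_cache.layers[layer_idx].keys
print(torch.equal(cached_keys, keys))  # True
```

Self-attention states, by contrast, go through `cache.self_attention_cache.update(...)` with a `cache_position`, which appends each new token's K/V to what is already stored.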
@@ -861,8 +871,14 @@ def forward(
861 | 871 |         # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
862 | 872 |         if attention_mask is not None and attention_mask.ndim < 4:
863 | 873 |             attention_mask = attention_mask.view(batch_size, -1)
864 |     | -        causal_mask = self._update_causal_mask(
865 |     | -            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    | 874 | +
    | 875 | +        causal_mask = create_causal_mask(
    | 876 | +            config=self.config,
    | 877 | +            input_embeds=inputs_embeds,
    | 878 | +            attention_mask=attention_mask,
    | 879 | +            cache_position=cache_position,
    | 880 | +            past_key_values=past_key_values,
    | 881 | +            position_ids=position_ids,
866 | 882 |         )
867 | 883 |
868 | 884 |         # If a 2D or 3D attention mask is provided for the cross-attention
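The `create_causal_mask` helper from `masking_utils` now builds (or skips) the 4D mask that `_update_causal_mask` used to assemble by hand, choosing the layout based on `config._attn_implementation`. A rough standalone sketch of the call the new code makes is shown below; the config values and shapes are arbitrary, and the exact return value (a 4D mask or `None` when the backend can infer causality itself) depends on the installed transformers version and attention backend.

```python
import torch
from transformers import GPT2Config
from transformers.masking_utils import create_causal_mask

config = GPT2Config(n_embd=64, n_head=4, n_layer=2)  # arbitrary small config

batch_size, seq_len = 2, 5
inputs_embeds = torch.zeros(batch_size, seq_len, config.n_embd)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask[1, :2] = 0  # left-padding on the second sequence
cache_position = torch.arange(seq_len)  # prefill, no past tokens

causal_mask = create_causal_mask(
    config=config,
    input_embeds=inputs_embeds,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=None,  # no cache in this sketch
    position_ids=None,
)
# Either None or a (batch, 1, query_len, kv_len) mask, depending on the backend.
print(type(causal_mask), getattr(causal_mask, "shape", None))
```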
@@ -903,9 +919,6 @@ def forward(
903 | 919 |             # Model parallel
904 | 920 |             if self.model_parallel:
905 | 921 |                 torch.cuda.set_device(hidden_states.device)
906 |     | -                # Ensure that attention_mask is always on the same device as hidden_states
907 |     | -                if attention_mask is not None:
908 |     | -                    attention_mask = attention_mask.to(hidden_states.device)
909 | 922 |                 if isinstance(head_mask, torch.Tensor):
910 | 923 |                     head_mask = head_mask.to(hidden_states.device)
911 | 924 |             if output_hidden_states:
@@ -966,123 +979,6 @@ def forward(
966 | 979 |             cross_attentions=all_cross_attentions,
967 | 980 |         )
968 | 981 |
969 |     | -    def _update_causal_mask(
970 |     | -        self,
971 |     | -        attention_mask: torch.Tensor,
972 |     | -        input_tensor: torch.Tensor,
973 |     | -        cache_position: torch.Tensor,
974 |     | -        past_key_values: Cache,
975 |     | -        output_attentions: bool,
976 |     | -    ):
977 |     | -        if self.config._attn_implementation == "flash_attention_2":
978 |     | -            if attention_mask is not None and 0.0 in attention_mask:
979 |     | -                return attention_mask
980 |     | -            return None
981 |     | -
982 |     | -        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
983 |     | -        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
984 |     | -        # to infer the attention mask.
985 |     | -        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
986 |     | -        using_static_cache = isinstance(past_key_values, StaticCache)
987 |     | -
988 |     | -        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
989 |     | -        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
990 |     | -            if AttentionMaskConverter._ignore_causal_mask_sdpa(
991 |     | -                attention_mask,
992 |     | -                inputs_embeds=input_tensor,
993 |     | -                past_key_values_length=past_seen_tokens,
994 |     | -                is_training=self.training,
995 |     | -            ):
996 |     | -                return None
997 |     | -
998 |     | -        dtype = input_tensor.dtype
999 |     | -        sequence_length = input_tensor.shape[1]
1000 |     | -        if using_static_cache:
1001 |     | -            target_length = past_key_values.get_max_cache_shape()
1002 |     | -        else:
1003 |     | -            target_length = (
1004 |     | -                attention_mask.shape[-1]
1005 |     | -                if isinstance(attention_mask, torch.Tensor)
1006 |     | -                else past_seen_tokens + sequence_length + 1
1007 |     | -            )
1008 |     | -
1009 |     | -        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1010 |     | -        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1011 |     | -            attention_mask,
1012 |     | -            sequence_length=sequence_length,
1013 |     | -            target_length=target_length,
1014 |     | -            dtype=dtype,
1015 |     | -            cache_position=cache_position,
1016 |     | -            batch_size=input_tensor.shape[0],
1017 |     | -        )
1018 |     | -
1019 |     | -        if (
1020 |     | -            self.config._attn_implementation == "sdpa"
1021 |     | -            and attention_mask is not None
1022 |     | -            and attention_mask.device.type == "cuda"
1023 |     | -            and not output_attentions
1024 |     | -        ):
1025 |     | -            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1026 |     | -            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1027 |     | -            # Details: https://github.com/pytorch/pytorch/issues/110213
1028 |     | -            min_dtype = torch.finfo(dtype).min
1029 |     | -            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1030 |     | -
1031 |     | -        return causal_mask
1032 |     | -
1033 |     | -    @staticmethod
1034 |     | -    def _prepare_4d_causal_attention_mask_with_cache_position(
1035 |     | -        attention_mask: torch.Tensor,
1036 |     | -        sequence_length: int,
1037 |     | -        target_length: int,
1038 |     | -        dtype: torch.dtype,
1039 |     | -        cache_position: torch.Tensor,
1040 |     | -        batch_size: int,
1041 |     | -        **kwargs,
1042 |     | -    ):
1043 |     | -        """
1044 |     | -        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1045 |     | -        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1046 |     | -
1047 |     | -        Args:
1048 |     | -            attention_mask (`torch.Tensor`):
1049 |     | -                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1050 |     | -                `(batch_size, 1, query_length, key_value_length)`.
1051 |     | -            sequence_length (`int`):
1052 |     | -                The sequence length being processed.
1053 |     | -            target_length (`int`):
1054 |     | -                The target length: when generating with static cache, the mask should be as long as the static cache,
1055 |     | -                to account for the 0 padding, the part of the cache that is not filled yet.
1056 |     | -            dtype (`torch.dtype`):
1057 |     | -                The dtype to use for the 4D attention mask.
1058 |     | -            cache_position (`torch.Tensor`):
1059 |     | -                Indices depicting the position of the input sequence tokens in the sequence.
1060 |     | -            batch_size (`torch.Tensor`):
1061 |     | -                Batch size.
1062 |     | -        """
1063 |     | -        if attention_mask is not None and attention_mask.dim() == 4:
1064 |     | -            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1065 |     | -            causal_mask = attention_mask
1066 |     | -        else:
1067 |     | -            min_dtype = torch.finfo(dtype).min
1068 |     | -            causal_mask = torch.full(
1069 |     | -                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
1070 |     | -            )
1071 |     | -            if sequence_length != 1:
1072 |     | -                causal_mask = torch.triu(causal_mask, diagonal=1)
1073 |     | -            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
1074 |     | -            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1075 |     | -            if attention_mask is not None:
1076 |     | -                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
1077 |     | -                mask_length = attention_mask.shape[-1]
1078 |     | -                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
1079 |     | -                padding_mask = padding_mask == 0
1080 |     | -                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1081 |     | -                    padding_mask, min_dtype
1082 |     | -                )
1083 |     | -
1084 |     | -        return causal_mask
1085 |     | -
1086 | 982 |
1087 | 983 | @auto_docstring(
1088 | 984 |     custom_intro="""
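What the deleted `_prepare_4d_causal_attention_mask_with_cache_position` helper computed is still useful to see concretely, since `create_causal_mask` now produces an equivalent mask internally. The sketch below repeats its core steps on a tiny, made-up case: an additive `(batch, 1, query_length, key_value_length)` mask that is dtype-min where attention is forbidden and 0 where it is allowed, combining the causal structure given by `cache_position` with a 2D padding mask.

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

sequence_length, target_length = 3, 5             # 3 new queries, 5 total key/value positions
cache_position = torch.arange(2, 5)               # the new tokens sit at positions 2, 3, 4
attention_mask = torch.tensor([[0, 1, 1, 1, 1]])  # 2D mask, first key position is padding

# Causal part: forbid attending to key positions after each query's cache position.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(1, 1, -1, -1).clone()

# Padding part: also forbid the key positions the 2D attention_mask marks as 0.
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask == 0, min_dtype
)
print(causal_mask[0, 0])  # 0.0 where attention is allowed, dtype-min where it is masked
```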