Skip to content

Commit 0788481

Browse files
committed
[BC 4.37 -> 4.38] for Llama family, memory and speed (#29753)
* attempt to fix * the actual fix that works with compilation! * this? * temporary update * nit? * dispatcg to memory efficient? * update both models that have static cache support * fix copies fix compile * make sure fix * fix cohere and gemma * fix beams? * nit * slipped through the cracks * nit * nits * update * fix-copies * skip failing tests * nits
1 parent 74f2900 commit 0788481

File tree

5 files changed

+72
-92
lines changed

5 files changed

+72
-92
lines changed

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,7 @@ def forward(
274274
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
275275

276276
if attention_mask is not None: # no matter the length, we just slice it
277-
causal_mask = attention_mask
278-
if cache_position is not None:
279-
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
277+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
280278
attn_weights = attn_weights + causal_mask
281279

282280
# upcast attention to fp32
@@ -559,8 +557,9 @@ def forward(
559557
value_states = repeat_kv(value_states, self.num_key_value_groups)
560558

561559
causal_mask = attention_mask
562-
if attention_mask is not None and cache_position is not None:
563-
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
560+
# if attention_mask is not None and cache_position is not None:
561+
if attention_mask is not None:
562+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
564563

565564
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
566565
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -692,7 +691,7 @@ class CoherePreTrainedModel(PreTrainedModel):
692691
base_model_prefix = "model"
693692
supports_gradient_checkpointing = True
694693
_no_split_modules = ["CohereDecoderLayer"]
695-
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
694+
_skip_keys_device_placement = ["past_key_values"]
696695
_supports_flash_attn_2 = True
697696
_supports_sdpa = True
698697
_supports_cache_class = True
@@ -715,12 +714,6 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
715714
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
716715
)
717716

718-
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
719-
causal_mask = torch.full(
720-
(max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
721-
)
722-
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
723-
724717
for layer in self.model.layers:
725718
device = layer.input_layernorm.weight.device
726719
if hasattr(self.config, "_pre_quantization_dtype"):
@@ -899,7 +892,7 @@ def forward(
899892
if position_ids is None:
900893
position_ids = cache_position.unsqueeze(0)
901894

902-
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)
895+
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
903896

904897
# embed positions
905898
hidden_states = inputs_embeds
@@ -967,25 +960,27 @@ def forward(
967960
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
968961
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
969962
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
970-
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
963+
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
971964
if self.config._attn_implementation == "flash_attention_2":
972965
if attention_mask is not None and 0.0 in attention_mask:
973966
return attention_mask
974967
return None
975968

976-
batch_size, seq_length = input_tensor.shape[:2]
977-
dtype = input_tensor.dtype
978-
device = input_tensor.device
979-
980-
# support going beyond cached `max_position_embedding`
981-
if seq_length > self.causal_mask.shape[-1]:
982-
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
983-
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
984-
985-
# We use the current dtype to avoid any overflows
969+
dtype, device = input_tensor.dtype, input_tensor.device
986970
min_dtype = torch.finfo(dtype).min
987-
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
988-
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
971+
sequence_length = input_tensor.shape[1]
972+
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
973+
target_length = self.config.max_position_embeddings
974+
else: # dynamic cache
975+
target_length = (
976+
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
977+
)
978+
979+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
980+
if sequence_length != 1:
981+
causal_mask = torch.triu(causal_mask, diagonal=1)
982+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
983+
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
989984
if attention_mask is not None:
990985
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
991986
if attention_mask.dim() == 2:
@@ -995,8 +990,8 @@ def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
995990
elif attention_mask.dim() == 4:
996991
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
997992
# cache. In that case, the 4D attention mask attends to the newest tokens only.
998-
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
999-
offset = past_seen_tokens
993+
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
994+
offset = cache_position[0]
1000995
else:
1001996
offset = 0
1002997
mask_shape = attention_mask.shape

src/transformers/models/gemma/modeling_gemma.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,7 @@ def forward(
279279
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
280280

281281
if attention_mask is not None: # no matter the length, we just slice it
282-
if cache_position is not None:
283-
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
284-
else:
285-
causal_mask = attention_mask
282+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
286283
attn_weights = attn_weights + causal_mask
287284

288285
# upcast attention to fp32
@@ -563,8 +560,8 @@ def forward(
563560
value_states = repeat_kv(value_states, self.num_key_value_groups)
564561

565562
causal_mask = attention_mask
566-
if attention_mask is not None and cache_position is not None:
567-
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
563+
if attention_mask is not None:
564+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
568565

569566
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
570567
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -836,12 +833,6 @@ def __init__(self, config: GemmaConfig):
836833
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
837834
self.gradient_checkpointing = False
838835

839-
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
840-
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
841-
causal_mask = torch.full(
842-
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
843-
)
844-
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
845836
# Initialize weights and apply final processing
846837
self.post_init()
847838

@@ -901,7 +892,7 @@ def forward(
901892
if position_ids is None:
902893
position_ids = cache_position.unsqueeze(0)
903894

904-
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)
895+
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
905896

906897
# embed positions
907898
hidden_states = inputs_embeds
@@ -975,26 +966,27 @@ def forward(
975966
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
976967
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
977968
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
978-
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
969+
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
979970
if self.config._attn_implementation == "flash_attention_2":
980971
if attention_mask is not None and 0.0 in attention_mask:
981972
return attention_mask
982973
return None
983974

984-
batch_size, seq_length = input_tensor.shape[:2]
985-
dtype = input_tensor.dtype
986-
device = input_tensor.device
987-
988-
# support going beyond cached `max_position_embedding`
989-
if seq_length > self.causal_mask.shape[-1]:
990-
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
991-
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
992-
993-
# We use the current dtype to avoid any overflows
975+
dtype, device = input_tensor.dtype, input_tensor.device
994976
min_dtype = torch.finfo(dtype).min
977+
sequence_length = input_tensor.shape[1]
978+
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
979+
target_length = self.config.max_position_embeddings
980+
else: # dynamic cache
981+
target_length = (
982+
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
983+
)
995984

996-
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
997-
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
985+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
986+
if sequence_length != 1:
987+
causal_mask = torch.triu(causal_mask, diagonal=1)
988+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
989+
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
998990
if attention_mask is not None:
999991
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1000992
if attention_mask.dim() == 2:
@@ -1004,8 +996,8 @@ def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
1004996
elif attention_mask.dim() == 4:
1005997
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1006998
# cache. In that case, the 4D attention mask attends to the newest tokens only.
1007-
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
1008-
offset = past_seen_tokens
999+
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
1000+
offset = cache_position[0]
10091001
else:
10101002
offset = 0
10111003
mask_shape = attention_mask.shape

src/transformers/models/llama/modeling_llama.py

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,7 @@ def forward(
371371
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
372372

373373
if attention_mask is not None: # no matter the length, we just slice it
374-
causal_mask = attention_mask
375-
if cache_position is not None:
376-
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
374+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
377375
attn_weights = attn_weights + causal_mask
378376

379377
# upcast attention to fp32
@@ -658,8 +656,9 @@ def forward(
658656
value_states = repeat_kv(value_states, self.num_key_value_groups)
659657

660658
causal_mask = attention_mask
661-
if attention_mask is not None and cache_position is not None:
662-
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
659+
# if attention_mask is not None and cache_position is not None:
660+
if attention_mask is not None:
661+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
663662

664663
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
665664
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -792,7 +791,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
792791
base_model_prefix = "model"
793792
supports_gradient_checkpointing = True
794793
_no_split_modules = ["LlamaDecoderLayer"]
795-
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
794+
_skip_keys_device_placement = ["past_key_values"]
796795
_supports_flash_attn_2 = True
797796
_supports_sdpa = True
798797
_supports_cache_class = True
@@ -815,12 +814,6 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
815814
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
816815
)
817816

818-
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
819-
causal_mask = torch.full(
820-
(max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
821-
)
822-
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
823-
824817
for layer in self.model.layers:
825818
device = layer.input_layernorm.weight.device
826819
if hasattr(self.config, "_pre_quantization_dtype"):
@@ -934,12 +927,6 @@ def __init__(self, config: LlamaConfig):
934927
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
935928
self.gradient_checkpointing = False
936929

937-
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
938-
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
939-
causal_mask = torch.full(
940-
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
941-
)
942-
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
943930
# Initialize weights and apply final processing
944931
self.post_init()
945932

@@ -1000,7 +987,7 @@ def forward(
1000987
if position_ids is None:
1001988
position_ids = cache_position.unsqueeze(0)
1002989

1003-
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)
990+
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
1004991

1005992
# embed positions
1006993
hidden_states = inputs_embeds
@@ -1068,25 +1055,27 @@ def forward(
10681055
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
10691056
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
10701057
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1071-
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
1058+
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
10721059
if self.config._attn_implementation == "flash_attention_2":
10731060
if attention_mask is not None and 0.0 in attention_mask:
10741061
return attention_mask
10751062
return None
10761063

1077-
batch_size, seq_length = input_tensor.shape[:2]
1078-
dtype = input_tensor.dtype
1079-
device = input_tensor.device
1080-
1081-
# support going beyond cached `max_position_embedding`
1082-
if seq_length > self.causal_mask.shape[-1]:
1083-
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
1084-
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
1085-
1086-
# We use the current dtype to avoid any overflows
1064+
dtype, device = input_tensor.dtype, input_tensor.device
10871065
min_dtype = torch.finfo(dtype).min
1088-
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
1089-
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
1066+
sequence_length = input_tensor.shape[1]
1067+
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
1068+
target_length = self.config.max_position_embeddings
1069+
else: # dynamic cache
1070+
target_length = (
1071+
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
1072+
)
1073+
1074+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1075+
if sequence_length != 1:
1076+
causal_mask = torch.triu(causal_mask, diagonal=1)
1077+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1078+
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
10901079
if attention_mask is not None:
10911080
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
10921081
if attention_mask.dim() == 2:
@@ -1096,8 +1085,8 @@ def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
10961085
elif attention_mask.dim() == 4:
10971086
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
10981087
# cache. In that case, the 4D attention mask attends to the newest tokens only.
1099-
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
1100-
offset = past_seen_tokens
1088+
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
1089+
offset = cache_position[0]
11011090
else:
11021091
offset = 0
11031092
mask_shape = attention_mask.shape

tests/models/cohere/test_modeling_cohere.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,9 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
283283
)
284284
test_headmasking = False
285285
test_pruning = False
286-
fx_compatible = True
286+
fx_compatible = (
287+
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
288+
)
287289

288290
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
289291
# This is because we are hitting edge cases with the causal_mask buffer

tests/models/llama/test_modeling_llama.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
300300
)
301301
test_headmasking = False
302302
test_pruning = False
303-
fx_compatible = True
303+
fx_compatible = (
304+
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
305+
)
304306

305307
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
306308
# This is because we are hitting edge cases with the causal_mask buffer

0 commit comments

Comments
 (0)