Commit f7a1e06

test moved to mixin
1 parent 57503c0 commit f7a1e06

7 files changed: +245 −245 lines

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 3 additions & 3 deletions
@@ -1008,9 +1008,9 @@ def _update_causal_mask(
             # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape
             # as the causal mask (i.e. [..., seq_len, full_len])
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            if attention_mask.shape[-2] == cache_position[0] + sequence_length:
-                offset = cache_position[0]
-                mask_slice = mask_slice[..., offset : offset + sequence_length, :]
+            offset = cache_position[0]
+            if attention_mask.shape[-2] == offset + sequence_length:
+                mask_slice = mask_slice[..., offset:, :]
             causal_mask = mask_slice
         else:
             if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache

src/transformers/models/gemma/modeling_gemma.py

Lines changed: 3 additions & 3 deletions
@@ -994,9 +994,9 @@ def _update_causal_mask(
             # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape
             # as the causal mask (i.e. [..., seq_len, full_len])
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            if attention_mask.shape[-2] == cache_position[0] + sequence_length:
-                offset = cache_position[0]
-                mask_slice = mask_slice[..., offset : offset + sequence_length, :]
+            offset = cache_position[0]
+            if attention_mask.shape[-2] == offset + sequence_length:
+                mask_slice = mask_slice[..., offset:, :]
             causal_mask = mask_slice
         else:
             if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache

src/transformers/models/llama/modeling_llama.py

Lines changed: 3 additions & 3 deletions
@@ -1086,9 +1086,9 @@ def _update_causal_mask(
             # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape
             # as the causal mask (i.e. [..., seq_len, full_len])
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            if attention_mask.shape[-2] == cache_position[0] + sequence_length:
-                offset = cache_position[0]
-                mask_slice = mask_slice[..., offset : offset + sequence_length, :]
+            offset = cache_position[0]
+            if attention_mask.shape[-2] == offset + sequence_length:
+                mask_slice = mask_slice[..., offset:, :]
             causal_mask = mask_slice
         else:
             if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
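
The three hunks above apply the same change to the Cohere, Gemma and Llama _update_causal_mask methods: offset is now computed once from cache_position[0], and a user-supplied 4D mask whose second-to-last dimension covers the full sequence is sliced down to the rows of the tokens being processed in the current forward pass. A minimal standalone sketch of just that slicing arithmetic (the tensor sizes below are illustrative assumptions, not values taken from the models):

import torch

full_len = 6  # tokens already in the cache plus the new tokens (illustrative)
sequence_length = 2  # new tokens in this forward pass (illustrative)
cache_position = torch.tensor([4, 5])  # positions of the new tokens
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# user-supplied 4D mask covering the full sequence: [batch, heads, full_len, full_len]
attention_mask = torch.tril(torch.ones(1, 1, full_len, full_len))

mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
offset = cache_position[0]
if attention_mask.shape[-2] == offset + sequence_length:
    # full-height mask: keep only the rows of the tokens processed in this pass
    mask_slice = mask_slice[..., offset:, :]

print(mask_slice.shape)  # torch.Size([1, 1, 2, 6]), i.e. [..., seq_len, full_len]

If the caller instead passes a mask that is already shaped [..., seq_len, full_len], the equality check fails and the mask is used as-is; the partial-mask test added below exercises both shapes (Case 1 and Case 2).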

tests/models/llama/test_modeling_llama.py

Lines changed: 135 additions & 0 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 """ Testing suite for the PyTorch LLaMA model. """

+import gc
 import tempfile
 import unittest

@@ -821,3 +822,137 @@ def test_model_7b_logits(self):
         ]
         infilling = tokenizer.batch_decode(generated_ids)
         self.assertEqual(infilling, EXPECTED_INFILLING)
+
+
+@require_torch_gpu
+class Mask4DTestHard(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def setUp(self):
+        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+        self.model_dtype = torch.float32
+        self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
+        self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
+
+    def get_test_data(self):
+        template = "my favorite {}"
+        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item
+
+        batch_separate = [template.format(x) for x in items]  # 3 separate lines
+        batch_shared_prefix = template.format(" ".join(items))  # 1 line with options concatenated
+
+        input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
+        input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
+
+        mask_shared_prefix = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
+                    ]
+                ]
+            ],
+            device=torch_device,
+            dtype=torch.int64,
+        )
+
+        position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
+        # equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
+        position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)  # same but nicer
+
+        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
+
+    def test_stacked_causal_mask(self):
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # single forward run with 4D custom mask
+        logits_shared_prefix = self.model.forward(
+            input_ids_shared_prefix, attention_mask=mask_shared_prefix.bool(), position_ids=position_ids_shared_prefix
+        ).logits
+        logits_shared_prefix_last = logits_shared_prefix[
+            0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
+        ]  # last three tokens
+        decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
+
+        self.assertEqual(decoded, decoded_shared_prefix)
+
+    def test_partial_stacked_causal_mask(self):
+        # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention
+        # masks
+
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # 2 forward runs with custom 4D masks
+        part_a = 3  # split point
+
+        input_1a = input_ids_shared_prefix[:, :part_a]
+        position_ids_1a = position_ids_shared_prefix[:, :part_a]
+        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
+
+        outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a)
+        past_key_values_a = outs_1a["past_key_values"]
+
+        # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len])
+        input_1b = input_ids_shared_prefix[:, part_a:]
+        position_ids_1b = position_ids_shared_prefix[:, part_a:]
+        mask_1b = mask_shared_prefix[:, :, part_a:, :]
+        outs_1b = self.model.forward(
+            input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a
+        )
+        decoded_1b = [
+            self.tokenizer.decode(t)
+            for t in outs_1b.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1b)
+
+        # Case 2: we pass a 4D attention mask regarding the full sequence length (i.e. [..., full_len, full_len])
+        input_1c = input_ids_shared_prefix[:, part_a:]
+        position_ids_1c = position_ids_shared_prefix[:, part_a:]
+        mask_1c = mask_shared_prefix
+        outs_1c = self.model.forward(
+            input_1c, attention_mask=mask_1c.bool(), position_ids=position_ids_1c, past_key_values=past_key_values_a
+        )
+        decoded_1c = [
+            self.tokenizer.decode(t)
+            for t in outs_1c.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1c)
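
The 12x12 mask_shared_prefix hard-coded in get_test_data follows a single pattern: a 3-token shared prefix that every position may attend to, followed by three 3-token continuations that are causal within themselves but hidden from one another. For readers adapting the test to other lengths, a hypothetical helper (not part of this commit; the 3/3/3 sizes are simply read off the hard-coded tensor) can build the same mask and the matching position_ids:

import torch

def build_shared_prefix_mask(prefix_len=3, item_len=3, n_items=3):
    total_len = prefix_len + item_len * n_items
    # start from a standard causal (lower-triangular) mask over the whole sequence
    mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.int64))
    for i in range(n_items):
        start = prefix_len + i * item_len
        # hide all earlier continuations from this one; the shared prefix stays visible
        mask[start : start + item_len, prefix_len:start] = 0
    return mask[None, None]  # shape [1, 1, total_len, total_len]

mask = build_shared_prefix_mask()  # equals the hard-coded mask_shared_prefix above
position_ids = (mask.sum(dim=-1) - 1).reshape(1, -1)  # [[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]

Passed as attention_mask=mask.bool() together with these position_ids, this reproduces the shared-prefix inputs that test_stacked_causal_mask compares against the regular three-line batch.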

tests/models/mixtral/test_modeling_mixtral.py

Lines changed: 6 additions & 0 deletions
@@ -505,6 +505,12 @@ def test_load_balancing_loss(self):
         # This is to mimic torch.testing.assert_not_close
         self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item())

+    # TODO: fix me
+    @unittest.skip("Test is failing on Mixtral, needs to be fixed")
+    # Ignore copy
+    def test_custom_4d_attention_mask_logits(self):
+        pass
+

 @require_torch
 class MixtralIntegrationTest(unittest.TestCase):

tests/test_modeling_common.py

Lines changed: 95 additions & 0 deletions
@@ -4132,6 +4132,101 @@ def test_flash_attn_2_from_config(self):

         self.assertFalse(fa2_correctly_converted)

+    def _get_custom_4d_mask_test_data(self):
+        # Sequence in which all but the last token is the same
+        input_ids = torch.tensor(
+            [[10, 11, 12, 13], [10, 11, 12, 14], [10, 11, 12, 15]], device=torch_device, dtype=torch.int64
+        )
+        position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
+
+        # Combining common prefix with the unique ending tokens:
+        input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0)
+
+        # Creating a 4D mask where each of the last 3 tokens do not attend to each other.
+        mask_shared_prefix = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0],
+                        [1, 1, 1, 0, 1, 0],
+                        [1, 1, 1, 0, 0, 1],
+                    ]
+                ]
+            ],
+            device=torch_device,
+            dtype=torch.int64,
+        )
+
+        # Creating a position_ids tensor. note the repeating figures in the end.
+        position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
+
+        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
+
+    def test_custom_4d_attention_mask(self):
+        if len(self.all_generative_model_classes) == 0:
+            self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_cache_class:
+                self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config).to(device=torch_device, dtype=torch.float32)
+
+            (
+                input_ids,
+                position_ids,
+                input_ids_shared_prefix,
+                mask_shared_prefix,
+                position_ids_shared_prefix,
+            ) = self._get_custom_4d_mask_test_data()
+            causal_mask_shared_prefix = (1 - mask_shared_prefix).to(model.dtype) * torch.finfo(model.dtype).min
+
+            input_embeds = model.model.embed_tokens(input_ids)
+            model_output = model.model.layers[0].self_attn.forward(input_embeds, position_ids=position_ids)[0]
+            # model_output.shape == torch.Size([3, 4, ...])
+
+            input_embeds_shared_prefix = model.model.embed_tokens(input_ids_shared_prefix)
+            model_output_shared_prefix = model.model.layers[0].self_attn.forward(
+                input_embeds_shared_prefix,
+                attention_mask=causal_mask_shared_prefix,
+                position_ids=position_ids_shared_prefix,
+            )[0]
+            # model_output_shared_prefix.shape == torch.Size([1, 6, ...])
+
+            out_last_tokens = model_output[:, -1, :]  # last tokens in each batch line
+            out_shared_prefix_last_tokens = model_output_shared_prefix[0, -3:, :]  # last three tokens
+            torch.testing.assert_close(out_last_tokens, out_shared_prefix_last_tokens)
+
+    def test_custom_4d_attention_mask_logits(self):
+        if len(self.all_generative_model_classes) == 0:
+            self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_cache_class:
+                self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config).to(device=torch_device, dtype=torch.float32)
+
+            (
+                input_ids,
+                position_ids,
+                input_ids_shared_prefix,
+                mask_shared_prefix,
+                position_ids_shared_prefix,
+            ) = self._get_custom_4d_mask_test_data()
+
+            logits = model.forward(input_ids, position_ids=position_ids).logits
+            logits_shared_prefix = model.forward(
+                input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
+            ).logits
+
+            logits_last_tokens = logits[:, -1, :]  # last tokens in each batch line
+            logits_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens
+            torch.testing.assert_close(logits_last_tokens, logits_shared_prefix_last_tokens)
+

 global_rng = random.Random()
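
One detail worth noting in test_custom_4d_attention_mask above: when the attention layer is called directly, the 0/1 mask is first converted into an additive float mask, since the per-layer attention adds this bias to the attention scores before the softmax (0 where a position may be attended to, the dtype minimum where it must be ignored). In isolation, with a toy 2x2 mask (values here are illustrative, not taken from the test data):

import torch

dtype = torch.float32
mask = torch.tensor([[1, 0], [1, 1]])  # 1 = attend, 0 = ignore
additive = (1 - mask).to(dtype) * torch.finfo(dtype).min
print(additive)  # approximately [[0.0, -3.4028e+38], [0.0, 0.0]]: zeros pass through, the float minimum blocks attention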
