 # limitations under the License.
 """ Testing suite for the PyTorch LLaMA model. """

+import gc
 import tempfile
 import unittest

@@ -821,3 +822,137 @@ def test_model_7b_logits(self):
         ]
         infilling = tokenizer.batch_decode(generated_ids)
         self.assertEqual(infilling, EXPECTED_INFILLING)
+
+
+@require_torch_gpu
+class Mask4DTestHard(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def setUp(self):
+        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+        self.model_dtype = torch.float32
+        self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
+        self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
+
+    def get_test_data(self):
+        template = "my favorite {}"
+        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item
+
+        batch_separate = [template.format(x) for x in items]  # 3 separate lines
+        batch_shared_prefix = template.format(" ".join(items))  # 1 line with options concatenated
+
+        input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
+        input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
+
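+        # custom 4D attention mask for the packed sequence: the first 3 tokens form a
+        # shared prefix visible to all later tokens, and each of the three packed
+        # continuations attends causally only to that prefix and to itself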
+        mask_shared_prefix = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
+                    ]
+                ]
+            ],
+            device=torch_device,
+            dtype=torch.int64,
+        )
+
+        position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
+        # equivalent: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(torch_device)
+        position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)  # same but nicer
+
+        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
+
+    def test_stacked_causal_mask(self):
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # single forward run with 4D custom mask
+        logits_shared_prefix = self.model.forward(
+            input_ids_shared_prefix, attention_mask=mask_shared_prefix.bool(), position_ids=position_ids_shared_prefix
+        ).logits
+        logits_shared_prefix_last = logits_shared_prefix[
+            0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
+        ]  # last three tokens
+        decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
+
+        self.assertEqual(decoded, decoded_shared_prefix)
+
+    def test_partial_stacked_causal_mask(self):
+        # Same as the test above, but the input is passed in two groups. It tests that we can pass
+        # partial 4D attention masks.
+
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # 2 forward runs with custom 4D masks
+        part_a = 3  # split point: the first three tokens are the shared prefix
+
+        input_1a = input_ids_shared_prefix[:, :part_a]
+        position_ids_1a = position_ids_shared_prefix[:, :part_a]
+        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
+
+        outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a)
+        past_key_values_a = outs_1a["past_key_values"]
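+        # past_key_values_a now caches keys/values for the first `part_a` tokens (the shared
+        # prefix), so both continuation passes below can reuse it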
+
+        # Case 1: pass a 4D attention mask covering only the new tokens (i.e. [..., seq_len, full_len])
+        input_1b = input_ids_shared_prefix[:, part_a:]
+        position_ids_1b = position_ids_shared_prefix[:, part_a:]
+        mask_1b = mask_shared_prefix[:, :, part_a:, :]
+        outs_1b = self.model.forward(
+            input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a
+        )
+        decoded_1b = [
+            self.tokenizer.decode(t)
+            for t in outs_1b.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1b)
+
+        # Case 2: pass a 4D attention mask covering the full sequence length (i.e. [..., full_len, full_len])
+        input_1c = input_ids_shared_prefix[:, part_a:]
+        position_ids_1c = position_ids_shared_prefix[:, part_a:]
+        mask_1c = mask_shared_prefix
+        outs_1c = self.model.forward(
+            input_1c, attention_mask=mask_1c.bool(), position_ids=position_ids_1c, past_key_values=past_key_values_a
+        )
+        decoded_1c = [
+            self.tokenizer.decode(t)
+            for t in outs_1c.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1c)
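
The hard-coded mask_shared_prefix above can also be derived programmatically from per-token segment ids. The sketch below is only illustrative and is not part of the commit: build_shared_prefix_mask is a hypothetical helper, but under the test's layout (a 3-token shared prefix followed by three 3-token continuations) it reproduces the same [1, 1, 12, 12] block-causal mask.

import torch

def build_shared_prefix_mask(segment_ids):
    # segment id 0 marks the shared prefix; 1..N mark each packed continuation.
    # token i may attend to token j iff j <= i (causal) and j lies either in the
    # shared prefix or in the same segment as i.
    seg = torch.tensor(segment_ids)
    n = seg.shape[0]
    query = torch.arange(n).unsqueeze(1)  # query positions i, shape [n, 1]
    key = torch.arange(n).unsqueeze(0)  # key positions j, shape [1, n]
    causal = key <= query
    same_or_prefix = (seg.unsqueeze(0) == seg.unsqueeze(1)) | (seg.unsqueeze(0) == 0)
    return (causal & same_or_prefix).long().unsqueeze(0).unsqueeze(0)  # shape [1, 1, n, n]

# 3-token shared prefix followed by three 3-token continuations, as in get_test_data
mask = build_shared_prefix_mask([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])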