From 34620e8f0a974761debf52093968107c14f41315 Mon Sep 17 00:00:00 2001
From: Guang Yang <42389959+guangy10@users.noreply.github.com>
Date: Tue, 29 Oct 2024 08:14:31 -0700
Subject: [PATCH] MobileBERT is ExecuTorch compatible (#34473)

Co-authored-by: Guang Yang
---
 .../mobilebert/test_modeling_mobilebert.py   | 43 ++++++++++++++++++-
 1 file changed, 42 insertions(+), 1 deletion(-)

diff --git a/tests/models/mobilebert/test_modeling_mobilebert.py b/tests/models/mobilebert/test_modeling_mobilebert.py
index d7a409427c9c51..d2bc11d09f1797 100644
--- a/tests/models/mobilebert/test_modeling_mobilebert.py
+++ b/tests/models/mobilebert/test_modeling_mobilebert.py
@@ -16,7 +16,9 @@
 
 import unittest
 
-from transformers import MobileBertConfig, is_torch_available
+from packaging import version
+
+from transformers import AutoTokenizer, MobileBertConfig, MobileBertForMaskedLM, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
 
@@ -384,3 +386,42 @@ def test_inference_no_head(self):
         upper_bound = torch.all((expected_slice / output[..., :3, :3]) <= 1 + TOLERANCE)
 
         self.assertTrue(lower_bound and upper_bound)
+
+    @slow
+    def test_export(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        mobilebert_model = "google/mobilebert-uncased"
+        device = "cpu"
+        attn_implementation = "eager"
+        max_length = 512
+
+        tokenizer = AutoTokenizer.from_pretrained(mobilebert_model)
+        inputs = tokenizer(
+            f"the man worked as a {tokenizer.mask_token}.",
+            return_tensors="pt",
+            padding="max_length",
+            max_length=max_length,
+        )
+
+        model = MobileBertForMaskedLM.from_pretrained(
+            mobilebert_model,
+            device_map=device,
+            attn_implementation=attn_implementation,
+        )
+
+        logits = model(**inputs).logits
+        eg_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
+        self.assertEqual(eg_predicted_mask.split(), ["carpenter", "waiter", "mechanic", "teacher", "clerk"])
+
+        exported_program = torch.export.export(
+            model,
+            args=(inputs["input_ids"],),
+            kwargs={"attention_mask": inputs["attention_mask"]},
+            strict=True,
+        )
+
+        result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+        ep_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
+        self.assertEqual(eg_predicted_mask, ep_predicted_mask)
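
Note (not part of the patch): the test above only verifies that torch.export
round-trips the model with matching top-5 mask predictions. Below is a minimal
sketch of the follow-on step implied by the subject line, lowering the exported
program to a serialized ExecuTorch .pte artifact. It assumes the `executorch`
pip package is installed; `to_edge` and `to_executorch` come from executorch's
exir module, and their exact signatures may vary across executorch releases.

    import torch
    from executorch.exir import to_edge
    from transformers import AutoTokenizer, MobileBertForMaskedLM

    model_id = "google/mobilebert-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = MobileBertForMaskedLM.from_pretrained(model_id, attn_implementation="eager")

    inputs = tokenizer(
        f"the man worked as a {tokenizer.mask_token}.",
        return_tensors="pt",
        padding="max_length",
        max_length=512,
    )

    # Same torch.export call as in test_export above.
    exported_program = torch.export.export(
        model,
        args=(inputs["input_ids"],),
        kwargs={"attention_mask": inputs["attention_mask"]},
        strict=True,
    )

    # Lower the ExportedProgram to an ExecuTorch program and serialize it
    # for on-device execution with the ExecuTorch runtime.
    et_program = to_edge(exported_program).to_executorch()
    with open("mobilebert.pte", "wb") as f:
        f.write(et_program.buffer)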