Skip to content

Commit

Permalink
MobileBERT is ExecuTorch compatible (huggingface#34473)
Browse files Browse the repository at this point in the history
Co-authored-by: Guang Yang <guangyang@fb.com>
  • Loading branch information
guangy10 and Guang Yang authored Oct 29, 2024
1 parent 56c45d5 commit 34620e8
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion tests/models/mobilebert/test_modeling_mobilebert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

import unittest

from transformers import MobileBertConfig, is_torch_available
from packaging import version

from transformers import AutoTokenizer, MobileBertConfig, MobileBertForMaskedLM, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

Expand Down Expand Up @@ -384,3 +386,42 @@ def test_inference_no_head(self):
upper_bound = torch.all((expected_slice / output[..., :3, :3]) <= 1 + TOLERANCE)

self.assertTrue(lower_bound and upper_bound)

@slow
def test_export(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

mobilebert_model = "google/mobilebert-uncased"
device = "cpu"
attn_implementation = "eager"
max_length = 512

tokenizer = AutoTokenizer.from_pretrained(mobilebert_model)
inputs = tokenizer(
f"the man worked as a {tokenizer.mask_token}.",
return_tensors="pt",
padding="max_length",
max_length=max_length,
)

model = MobileBertForMaskedLM.from_pretrained(
mobilebert_model,
device_map=device,
attn_implementation=attn_implementation,
)

logits = model(**inputs).logits
eg_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
self.assertEqual(eg_predicted_mask.split(), ["carpenter", "waiter", "mechanic", "teacher", "clerk"])

exported_program = torch.export.export(
model,
args=(inputs["input_ids"],),
kwargs={"attention_mask": inputs["attention_mask"]},
strict=True,
)

result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
ep_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
self.assertEqual(eg_predicted_mask, ep_predicted_mask)

0 comments on commit 34620e8

Please sign in to comment.