Skip to content

Commit 9f6c79e

Browse files
committed
Improve optimum coverage in ET (more models, xnnpack on mac)
1 parent 1976647 commit 9f6c79e

File tree

2 files changed

+222
-79
lines changed

2 files changed

+222
-79
lines changed

.ci/scripts/test_huggingface_optimum_model.py

Lines changed: 104 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import argparse
2+
import gc
3+
import logging
4+
import math
25
import subprocess
36
import tempfile
47
from pathlib import Path
8+
from typing import List
59

610
import torch
711
from datasets import load_dataset
@@ -15,6 +19,7 @@
1519
)
1620
from transformers import (
1721
AutoConfig,
22+
AutoModelForCausalLM,
1823
AutoModelForImageClassification,
1924
AutoProcessor,
2025
AutoTokenizer,
@@ -37,6 +42,56 @@ def cli_export(command, model_dir):
3742
print(f"Export failed with error: {e}")
3843

3944

45+
def check_causal_lm_output_quality(
46+
model_id: str, generated_tokens: List[int], max_perplexity_threshold: float = 100.0
47+
):
48+
"""
49+
Evaluates the quality of text generated by a causal language model by calculating its perplexity.
50+
51+
Args:
52+
model_id: HuggingFace model identifier (e.g., "google/gemma2-2b")
53+
generated_tokens: The tokens generated by the exported model to evaluate
54+
max_perplexity_threshold: Maximum acceptable perplexity (lower is better)
55+
56+
Returns:
57+
tuple: (is_quality_ok, reason) with boolean result and explanation
58+
"""
59+
logging.info(f"Starting perplexity check with model '{model_id}' ...")
60+
# Load model
61+
model = AutoModelForCausalLM.from_pretrained(
62+
model_id,
63+
low_cpu_mem_usage=True,
64+
use_cache=False,
65+
torch_dtype=torch.bfloat16,
66+
)
67+
68+
with torch.no_grad():
69+
outputs = model(input_ids=generated_tokens, labels=generated_tokens)
70+
71+
# Get the loss (negative log-likelihood)
72+
loss = outputs.loss.item()
73+
74+
# Calculate perplexity (exp of the average negative log-likelihood)
75+
perplexity = math.exp(loss)
76+
77+
is_quality_ok = perplexity <= max_perplexity_threshold
78+
if is_quality_ok:
79+
logging.info(
80+
f"✓ Perplexity check passed: {perplexity:.2f} <= {max_perplexity_threshold}"
81+
)
82+
else:
83+
logging.warning(
84+
f"✗ Perplexity check failed: {perplexity:.2f} > {max_perplexity_threshold}"
85+
)
86+
87+
# Clean up immediately
88+
del model
89+
del outputs
90+
gc.collect()
91+
92+
return is_quality_ok
93+
94+
4095
def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only=False):
4196
command = [
4297
"optimum-cli",
@@ -51,7 +106,15 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
51106
"--output_dir",
52107
model_dir,
53108
]
54-
if "coreml" in recipe:
109+
if "xnnpack" in recipe:
110+
if quantize:
111+
command += [
112+
"--qlinear",
113+
"8da4w",
114+
"--qembedding",
115+
"8w",
116+
]
117+
elif "coreml" in recipe:
55118
command += [
56119
"--disable_dynamic_shapes",
57120
]
@@ -63,7 +126,9 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
63126
"8w",
64127
]
65128
else:
66-
assert not quantize, "Quantization is not supported for non-CoreML recipes yet"
129+
assert (
130+
not quantize
131+
), "Quantization is only supported for XnnPack and CoreML recipes at the moment."
67132

68133
if not run_only:
69134
cli_export(command, model_dir)
@@ -77,6 +142,14 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
77142
max_seq_len=64,
78143
)
79144
print(f"\nGenerated text:\n\t{generated_text}")
145+
generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
146+
147+
# Free memory before loading eager for quality check
148+
del model
149+
del tokenizer
150+
gc.collect()
151+
152+
assert check_causal_lm_output_quality(model_id, generated_tokens) is True
80153

81154

82155
def test_fill_mask(model_id, model_dir, recipe, *, quantize=True, run_only=False):
@@ -278,23 +351,39 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
278351
)
279352
args = parser.parse_args()
280353

281-
model_to_model_id_and_test_function = {
282-
"smollm": ("HuggingFaceTB/SmolLM2-135M", test_text_generation), # works
283-
"qwen3": ("Qwen/Qwen3-0.6B", test_text_generation), # works
284-
"olmo": ("allenai/OLMo-1B-hf", test_text_generation), # works
285-
"gemma3": ("unsloth/gemma-3-1b-it", test_text_generation), # does not export
286-
"phi4": (
354+
_text_generation_mapping = {
355+
"llama3.2-1b": ("NousResearch/Llama-3.2-1B", test_text_generation),
356+
"qwen3-0.6b": ("Qwen/Qwen3-0.6B", test_text_generation),
357+
"qwen3-1.7b": ("Qwen/Qwen3-1.7B", test_text_generation),
358+
"gemma3-1b": (
359+
"unsloth/gemma-3-1b-it",
360+
test_text_generation,
361+
), # does not export for CoreML
362+
"phi4-mini": (
287363
"microsoft/Phi-4-mini-instruct",
288364
test_text_generation,
289-
), # fails to lower
290-
"llama3": ("NousResearch/Llama-3.2-1B", test_text_generation), # works
291-
"bert": ("google-bert/bert-base-uncased", test_fill_mask), # works
292-
"roberta": ("FacebookAI/xlmcl-roberta-base", test_fill_mask), # works
293-
"distilbert": ("distilbert/distilbert-base-uncased", test_fill_mask), # works
294-
"whisper": ("openai/whisper-tiny", test_whisper), # works
365+
), # fails to lower for CoreML
366+
"smollm2-135m": ("HuggingFaceTB/SmolLM2-135M", test_text_generation),
367+
"smollm3-3b": ("HuggingFaceTB/SmolLM3-3B", test_text_generation),
368+
"olmo": ("allenai/OLMo-1B-hf", test_text_generation),
369+
}
370+
371+
_mask_fill_mapping = {
372+
"bert": ("google-bert/bert-base-uncased", test_fill_mask),
373+
"roberta": ("FacebookAI/xlmcl-roberta-base", test_fill_mask),
374+
"distilbert": ("distilbert/distilbert-base-uncased", test_fill_mask),
375+
}
376+
377+
_misc_model_mapping = {
378+
"whisper": ("openai/whisper-tiny", test_whisper),
295379
"t5": ("google-t5/t5-small", test_t5), # CoreML runime failure
296-
"vit": ("google/vit-base-patch16-224", test_vit), # works
380+
"vit": ("google/vit-base-patch16-224", test_vit),
297381
}
382+
383+
model_to_model_id_and_test_function = (
384+
_text_generation_mapping + _mask_fill_mapping + _misc_model_mapping
385+
)
386+
298387
if args.model not in model_to_model_id_and_test_function:
299388
raise ValueError(
300389
f"Unknown model name: {args.model}. Available models: {model_to_model_id_and_test_function.keys()}"

0 commit comments

Comments
 (0)