 import argparse
+import gc
+import logging
+import math
 import subprocess
 import tempfile
 from pathlib import Path
+from typing import List

 import torch
 from datasets import load_dataset
 )
 from transformers import (
     AutoConfig,
+    AutoModelForCausalLM,
     AutoModelForImageClassification,
     AutoProcessor,
     AutoTokenizer,
@@ -37,6 +42,56 @@ def cli_export(command, model_dir):
         print(f"Export failed with error: {e}")


+def check_causal_lm_output_quality(
+    model_id: str, generated_tokens: List[int], max_perplexity_threshold: float = 100.0
+):
+    """
+    Evaluates the quality of text generated by a causal language model by calculating its perplexity.
+
+    Args:
+        model_id: HuggingFace model identifier (e.g., "google/gemma-2-2b")
+        generated_tokens: The tokens generated by the exported model to evaluate
+        max_perplexity_threshold: Maximum acceptable perplexity (lower is better)
+
+    Returns:
+        bool: True if the perplexity of the generated tokens is at or below the threshold
+    """
+    logging.info(f"Starting perplexity check with model '{model_id}' ...")
+    # Load model
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        low_cpu_mem_usage=True,
+        use_cache=False,
+        torch_dtype=torch.bfloat16,
+    )
+
+    with torch.no_grad():
+        outputs = model(input_ids=generated_tokens, labels=generated_tokens)
+
+    # Get the loss (negative log-likelihood)
+    loss = outputs.loss.item()
+
+    # Calculate perplexity (exp of the average negative log-likelihood)
+    perplexity = math.exp(loss)
+
+    is_quality_ok = perplexity <= max_perplexity_threshold
+    if is_quality_ok:
+        logging.info(
+            f"✓ Perplexity check passed: {perplexity:.2f} <= {max_perplexity_threshold}"
+        )
+    else:
+        logging.warning(
+            f"✗ Perplexity check failed: {perplexity:.2f} > {max_perplexity_threshold}"
+        )
+
+    # Clean up immediately
+    del model
+    del outputs
+    gc.collect()
+
+    return is_quality_ok
+
+
 def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only=False):
     command = [
         "optimum-cli",
@@ -51,7 +106,15 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
         "--output_dir",
         model_dir,
     ]
-    if "coreml" in recipe:
+    if "xnnpack" in recipe:
+        if quantize:
+            command += [
+                "--qlinear",
+                "8da4w",
+                "--qembedding",
+                "8w",
+            ]
+    elif "coreml" in recipe:
         command += [
             "--disable_dynamic_shapes",
         ]
@@ -63,7 +126,9 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
                 "8w",
             ]
     else:
-        assert not quantize, "Quantization is not supported for non-CoreML recipes yet"
+        assert (
+            not quantize
+        ), "Quantization is only supported for XNNPACK and CoreML recipes at the moment."

     if not run_only:
         cli_export(command, model_dir)
@@ -77,6 +142,14 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
         max_seq_len=64,
     )
     print(f"\nGenerated text:\n\t{generated_text}")
+    generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+    # Free memory before loading the eager model for the quality check
+    del model
+    del tokenizer
+    gc.collect()
+
+    assert check_causal_lm_output_quality(model_id, generated_tokens) is True


 def test_fill_mask(model_id, model_dir, recipe, *, quantize=True, run_only=False):
@@ -278,23 +351,39 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
     )
     args = parser.parse_args()

-    model_to_model_id_and_test_function = {
-        "smollm": ("HuggingFaceTB/SmolLM2-135M", test_text_generation),  # works
-        "qwen3": ("Qwen/Qwen3-0.6B", test_text_generation),  # works
-        "olmo": ("allenai/OLMo-1B-hf", test_text_generation),  # works
-        "gemma3": ("unsloth/gemma-3-1b-it", test_text_generation),  # does not export
-        "phi4": (
+    _text_generation_mapping = {
+        "llama3.2-1b": ("NousResearch/Llama-3.2-1B", test_text_generation),
+        "qwen3-0.6b": ("Qwen/Qwen3-0.6B", test_text_generation),
+        "qwen3-1.7b": ("Qwen/Qwen3-1.7B", test_text_generation),
+        "gemma3-1b": (
+            "unsloth/gemma-3-1b-it",
+            test_text_generation,
+        ),  # does not export for CoreML
+        "phi4-mini": (
             "microsoft/Phi-4-mini-instruct",
             test_text_generation,
-        ),  # fails to lower
-        "llama3": ("NousResearch/Llama-3.2-1B", test_text_generation),  # works
-        "bert": ("google-bert/bert-base-uncased", test_fill_mask),  # works
-        "roberta": ("FacebookAI/xlm-roberta-base", test_fill_mask),  # works
-        "distilbert": ("distilbert/distilbert-base-uncased", test_fill_mask),  # works
-        "whisper": ("openai/whisper-tiny", test_whisper),  # works
+        ),  # fails to lower for CoreML
+        "smollm2-135m": ("HuggingFaceTB/SmolLM2-135M", test_text_generation),
+        "smollm3-3b": ("HuggingFaceTB/SmolLM3-3B", test_text_generation),
+        "olmo": ("allenai/OLMo-1B-hf", test_text_generation),
+    }
+
+    _mask_fill_mapping = {
+        "bert": ("google-bert/bert-base-uncased", test_fill_mask),
+        "roberta": ("FacebookAI/xlm-roberta-base", test_fill_mask),
+        "distilbert": ("distilbert/distilbert-base-uncased", test_fill_mask),
+    }
+
+    _misc_model_mapping = {
+        "whisper": ("openai/whisper-tiny", test_whisper),
         "t5": ("google-t5/t5-small", test_t5),  # CoreML runtime failure
-        "vit": ("google/vit-base-patch16-224", test_vit),  # works
+        "vit": ("google/vit-base-patch16-224", test_vit),
     }
+
+    model_to_model_id_and_test_function = (
+        _text_generation_mapping | _mask_fill_mapping | _misc_model_mapping
+    )
+
     if args.model not in model_to_model_id_and_test_function:
         raise ValueError(
             f"Unknown model name: {args.model}. Available models: {model_to_model_id_and_test_function.keys()}"