123 changes: 108 additions & 15 deletions .ci/scripts/test_huggingface_optimum_model.py
@@ -1,7 +1,11 @@
import argparse
import gc
import logging
import math
import subprocess
import tempfile
from pathlib import Path
from typing import List

import torch
from datasets import load_dataset
@@ -15,6 +19,7 @@
)
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageClassification,
AutoProcessor,
AutoTokenizer,
@@ -37,6 +42,56 @@ def cli_export(command, model_dir):
print(f"Export failed with error: {e}")


def check_causal_lm_output_quality(
model_id: str, generated_tokens: List[int], max_perplexity_threshold: float = 100.0
):
"""
Evaluates the quality of text generated by a causal language model by calculating its perplexity.

Args:
model_id: HuggingFace model identifier (e.g., "google/gemma2-2b")
generated_tokens: The tokens generated by the exported model to evaluate
max_perplexity_threshold: Maximum acceptable perplexity (lower is better)

Returns:
bool: True if the perplexity of the generated tokens is within the acceptable threshold, False otherwise
"""
logging.info(f"Starting perplexity check with model '{model_id}' ...")
# Load model
model = AutoModelForCausalLM.from_pretrained(
model_id,
low_cpu_mem_usage=True,
use_cache=False,
torch_dtype=torch.bfloat16,
)

with torch.no_grad():
outputs = model(input_ids=generated_tokens, labels=generated_tokens)

# Get the loss (negative log-likelihood)
loss = outputs.loss.item()

# Calculate perplexity (exp of the average negative log-likelihood)
perplexity = math.exp(loss)

is_quality_ok = perplexity <= max_perplexity_threshold
if is_quality_ok:
logging.info(
f"✓ Perplexity check passed: {perplexity:.2f} <= {max_perplexity_threshold}"
)
else:
logging.warning(
f"✗ Perplexity check failed: {perplexity:.2f} > {max_perplexity_threshold}"
)

# Clean up immediately
del model
del outputs
gc.collect()

return is_quality_ok


def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only=False):
command = [
"optimum-cli",
@@ -51,7 +106,19 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
"--output_dir",
model_dir,
]
if "coreml" in recipe:
if "xnnpack" in recipe:
command += [
"--use_custom_sdpa",
"--use_custom_kv_cache",
]
if quantize:
command += [
"--qlinear",
"8da4w",
"--qembedding",
"8w",
]
elif "coreml" in recipe:
command += [
"--disable_dynamic_shapes",
]
@@ -63,7 +130,9 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
"8w",
]
else:
assert not quantize, "Quantization is not supported for non-CoreML recipes yet"
assert (
not quantize
), "Quantization is only supported for XnnPack and CoreML recipes at the moment."

if not run_only:
cli_export(command, model_dir)
@@ -77,6 +146,14 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
max_seq_len=64,
)
print(f"\nGenerated text:\n\t{generated_text}")
generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

# Free memory before loading eager for quality check
del model
del tokenizer
gc.collect()

assert check_causal_lm_output_quality(model_id, generated_tokens) is True


def test_fill_mask(model_id, model_dir, recipe, *, quantize=True, run_only=False):
@@ -278,23 +355,39 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
)
args = parser.parse_args()

model_to_model_id_and_test_function = {
"smollm": ("HuggingFaceTB/SmolLM2-135M", test_text_generation), # works
"qwen3": ("Qwen/Qwen3-0.6B", test_text_generation), # works
"olmo": ("allenai/OLMo-1B-hf", test_text_generation), # works
"gemma3": ("unsloth/gemma-3-1b-it", test_text_generation), # does not export
"phi4": (
_text_generation_mapping = {
"llama3.2-1b": ("NousResearch/Llama-3.2-1B", test_text_generation),
"qwen3-0.6b": ("Qwen/Qwen3-0.6B", test_text_generation),
"qwen3-1.7b": ("Qwen/Qwen3-1.7B", test_text_generation),
"gemma3-1b": (
"unsloth/gemma-3-1b-it",
test_text_generation,
), # does not export for CoreML
Contributor:

IIRC, it does not export for portable/xnnpack either unless you swap the SDPA with the custom one.

So I think the export issue is general, not CoreML specific, but using custom SDPA "fixes" the export issue.

Contributor Author:

Does CoreML use custom sdpa / custom kv cache update?

"phi4-mini": (
"microsoft/Phi-4-mini-instruct",
test_text_generation,
), # fails to lower
"llama3": ("NousResearch/Llama-3.2-1B", test_text_generation), # works
"bert": ("google-bert/bert-base-uncased", test_fill_mask), # works
"roberta": ("FacebookAI/xlmcl-roberta-base", test_fill_mask), # works
"distilbert": ("distilbert/distilbert-base-uncased", test_fill_mask), # works
"whisper": ("openai/whisper-tiny", test_whisper), # works
), # fails to lower for CoreML
"smollm2-135m": ("HuggingFaceTB/SmolLM2-135M", test_text_generation),
"smollm3-3b": ("HuggingFaceTB/SmolLM3-3B", test_text_generation),
"olmo": ("allenai/OLMo-1B-hf", test_text_generation),
}

_mask_fill_mapping = {
"bert": ("google-bert/bert-base-uncased", test_fill_mask),
"roberta": ("FacebookAI/xlmcl-roberta-base", test_fill_mask),
"distilbert": ("distilbert/distilbert-base-uncased", test_fill_mask),
}

_misc_model_mapping = {
"whisper": ("openai/whisper-tiny", test_whisper),
"t5": ("google-t5/t5-small", test_t5), # CoreML runime failure
"vit": ("google/vit-base-patch16-224", test_vit), # works
"vit": ("google/vit-base-patch16-224", test_vit),
}

model_to_model_id_and_test_function = (
_text_generation_mapping | _mask_fill_mapping | _misc_model_mapping
)

if args.model not in model_to_model_id_and_test_function:
raise ValueError(
f"Unknown model name: {args.model}. Available models: {model_to_model_id_and_test_function.keys()}"
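
For reference, the quality gate introduced in this file scores the exported model's output by asking an eager HuggingFace reference model for its loss on the generated tokens and converting that loss to perplexity. A minimal standalone sketch of the same computation, using a hypothetical model ID and prompt rather than values from the diff:

import math

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM2-135M"  # assumption: any small causal LM works here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Tokens to score; in CI these come from the exported model's generated text.
input_ids = tokenizer(
    "Simply put, the theory of relativity states that", return_tensors="pt"
).input_ids

with torch.no_grad():
    # With labels == input_ids, the HF causal LM loss is the mean
    # negative log-likelihood over the sequence.
    loss = model(input_ids=input_ids, labels=input_ids).loss.item()

perplexity = math.exp(loss)
print(f"perplexity = {perplexity:.2f}")  # the CI gate passes when this is <= 100.0

Lower perplexity means the reference model assigns higher likelihood to the generated text, so the check only fails when the exported model's output is far from what the eager model considers plausible.
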
101 changes: 40 additions & 61 deletions .github/workflows/trunk.yml
@@ -732,23 +732,26 @@ jobs:
echo "::endgroup::"
done

test-huggingface-transformers:
test-huggingface-transformers-xnnpack:
# NB: Don't run this on fork PRs because they won't have access to the secret and would fail anyway
if: ${{ !github.event.pull_request.head.repo.fork }}
name: test-huggingface-transformers
name: test-huggingface-transformers-xnnpack
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
secrets: inherit
strategy:
matrix:
hf_model_id: [
google/gemma-3-1b-it,
Qwen/Qwen3-0.6B,
HuggingFaceTB/SmolLM2-135M,
meta-llama/Llama-3.2-1B,
allenai/OLMo-1B-hf,
config: [
# XNNPack.
llama3.2-1b|xnnpack|--quantize,
qwen3-0.6b|xnnpack|--quantize,
qwen3-1.7b|xnnpack|--quantize,
gemma3-1b|xnnpack|--quantize,
phi4-mini|xnnpack|--quantize,
smollm2-135m|xnnpack|--quantize,
smollm3-3b|xnnpack|--quantize
]
fail-fast: false
with:
@@ -760,6 +763,12 @@
timeout: 90
upload-artifact: profiling-artifacts-${{ strategy.job-index }}
script: |
set -eux
IFS='|' read -r MODEL RECIPE QUANTIZE <<< "${{ matrix.config }}"
echo "Model: $MODEL"
echo "Recipe: $RECIPE"
echo "Quantize: $QUANTIZE"

echo "::group::Set up ExecuTorch"
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
@@ -797,82 +806,52 @@
pip list
echo "::endgroup::"

echo "::group::Export to ExecuTorch"
# Pass matrix variable as environment variable
export MODEL_ID="${{ matrix.hf_model_id }}"
export OUTPUT_DIR="$(pwd)/${MODEL_ID}_custom_sdpa_kv_cache_8da4w"
pushd optimum-executorch

ARGS=(
"--model" "${MODEL_ID}"
"--task" "text-generation"
"--recipe" "xnnpack"
"--use_custom_sdpa"
"--use_custom_kv_cache"
"--qlinear" "8da4w"
"--qembedding" "8w"
"--output_dir" "${OUTPUT_DIR}"
)

optimum-cli export executorch "${ARGS[@]}"

ls -FlAGhp ${OUTPUT_DIR}
popd
echo "::endgroup::"

echo "::group::Inference using python API"
pushd optimum-executorch
python -c "
import os
from optimum.executorch import ExecuTorchModelForCausalLM
from transformers import AutoTokenizer

model_id = os.getenv('MODEL_ID')
pte_dir = os.getenv('OUTPUT_DIR')
print(f'Loading model {model_id} from {pte_dir}.')
model = ExecuTorchModelForCausalLM.from_pretrained(pte_dir)
generated_text = model.text_generation(
tokenizer=AutoTokenizer.from_pretrained(model_id),
prompt='Simply put, the theory of relativity states that',
max_seq_len=64
)
print(generated_text)
"
popd
echo "::group::Run tests"
export OUTPUT_DIR="$(pwd)/${MODEL}_${RECIPE}_${QUANTIZE}"
python .ci/scripts/test_huggingface_optimum_model.py --model ${MODEL} --recipe ${RECIPE} ${QUANTIZE} --model_dir ${OUTPUT_DIR}
echo "::endgroup::"

echo "::group::Inference using executor_runner with ETDump"
echo "::group::Generate artifacts for performance profiling"
./cmake-out/executor_runner \
--model_path ${OUTPUT_DIR}/model.pte \
--etdump_path ${OUTPUT_DIR}/etdump.etdp

export TSV_PATH=artifacts-to-be-uploaded/${MODEL_ID}_op_prof.tsv
export TSV_PATH=artifacts-to-be-uploaded/${MODEL}_op_prof.tsv
mkdir -p $(dirname "$TSV_PATH")
python3 -m devtools.inspector.inspector_cli \
--etdump_path ${OUTPUT_DIR}/etdump.etdp \
--tsv_path ${TSV_PATH}

echo "::endgroup::"

test-huggingface-optimum-coreml:
test-huggingface-transformers-coreml:
# NB: Don't run this on fork PRs because they won't have access to the secret and would fail anyway
if: ${{ !github.event.pull_request.head.repo.fork }}
name: test-huggingface-optimum-coreml
name: test-huggingface-transformers-coreml
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
permissions:
id-token: write
contents: read
secrets: inherit
# Models below selected based on https://huggingface.co/models?pipeline_tag=text-generation&num_parameters=min:0,max:3B&sort=trending.
strategy:
matrix:
config: [
qwen3|coreml_fp32_gpu|--quantize,
smollm|coreml_fp32_gpu|--quantize,
llama3|coreml_fp32_gpu|--quantize,
olmo|coreml_fp32_gpu|--quantize,
# roberta|coreml_fp32_gpu|--quantize, roberta requires special HF access
# XNNPack.
llama3.2-1b|xnnpack|--quantize,
qwen3-0.6b|xnnpack|--quantize,
qwen3-1.7b|xnnpack|--quantize,
gemma3-1b|xnnpack|--quantize,
phi4-mini|xnnpack|--quantize,
smollm2-135m|xnnpack|--quantize,
smollm3-3b|xnnpack|--quantize,
# CoreML.
llama3.2-1b|coreml_fp32_gpu|--quantize,
qwen3-0.6b|coreml_fp32_gpu|--quantize,
qwen3-1.7b|xnnpack|--quantize,
smollm2-135m|coreml_fp32_gpu|--quantize,
olmo-1b|coreml_fp32_gpu|--quantize,
bert|coreml_fp32_gpu|--quantize,
distilbert|coreml_fp32_gpu|--quantize,
distilbert|coreml_fp32_gpu|--quantize
]
fail-fast: false
with: