Distribute and complete onnxruntime tests (decoder models) #2278

Merged on May 28, 2025 (27 commits)

Commits (all by IlyasMoutawwakil)
2cc29a0 (May 23) added test_decoders.py
4da3df2 (May 23) fix position ids for single batch and more complete decoder testing f…
d208f6b (May 23) support merging seq2seq models when used as decoders and add more tests
608da13 (May 23) fix pipe tests
f34f6e7 (May 23) update phi min transformers version (broken by cache position refacto…
2e1a700 (May 23) remove deprecated bloom modeling
00ec0c7 (May 23) update opt onnx config to the one with position ids
88dc4a8 (May 23) remove all complex deprecated modeling
5f9419e (May 23) get_supported_model_type_for_task should only return suooprted model …
478fd57 (May 24) update min transformers
6aa3a17 (May 24) use transformers like api for use_cache and add can_use_cache and is_…
7da7015 (May 24) testing
f9f7395 (May 25) fix
8785be6 (May 25) fix
5f81515 (May 25) remove unnecessary
6e3bff1 (May 25) simply qwen3
aacf172 (May 25) docs
2b0137f (May 25) qwen-moe
7041a89 (May 25) model type shenanigans
de244d5 (May 25) fix
088b265 (May 26) use test models from optimum-internal-hf with proper metadata
af7c6bb (May 27) Update optimum/onnxruntime/modeling_decoder.py
0149349 (May 27) keep supported model types
2d9d7ea (May 27) Merge branch 'distribute-tests' of https://github.com/huggingface/opt…
fbcf3a1 (May 27) Merge branch 'main' into distribute-tests
2cf507c (May 27) optimum model
8821d97 (May 27) fix failing test by forcing export
6 changes: 4 additions & 2 deletions .github/workflows/test_onnxruntime.yml

```diff
@@ -27,9 +27,11 @@ jobs:
       matrix:
         python-version: [3.9]
         runs-on: [ubuntu-22.04]
-        test_file: [
+        test_file:
+          [
             test_timm.py,
-            test_modeling.py, # todo: split into test_encoder, test_decoder and test_encoder_decoder
+            test_decoder.py,
+            test_modeling.py,
             test_diffusion.py,
             test_optimization.py,
             test_quantization.py,
```
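The new test_decoder.py entry gives the decoder tests their own CI job, running in parallel with the other files. A rough local equivalent of that matrix entry, as a sketch (the tests/onnxruntime/ prefix is an assumption based on the usual repository layout):

```python
import pytest

# Each CI matrix entry runs exactly one test file, so the new decoder tests
# get their own job instead of sharing one with test_modeling.py.
# The tests/onnxruntime/ prefix is assumed, not taken from the diff.
raise SystemExit(pytest.main(["tests/onnxruntime/test_decoder.py", "-v"]))
```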
2 changes: 2 additions & 0 deletions docs/source/exporters/onnx/overview.mdx

```diff
@@ -92,6 +92,8 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - PoolFormer
 - PVT
 - Qwen2(Qwen1.5)
+- Qwen3
+- Qwen3-MoE
 - RegNet
 - RemBERT
 - ResNet
```
18 changes: 9 additions & 9 deletions optimum/exporters/onnx/base.py

```diff
@@ -938,19 +938,19 @@ def post_process_exported_models(
             path, models_and_onnx_configs, onnx_files_subpaths
         )

-        # Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task
-        if len(onnx_files_subpaths) >= 3 and self.use_past is True:
-            decoder_path = Path(path, onnx_files_subpaths[1])
-            decoder_with_past_path = Path(path, onnx_files_subpaths[2])
-            decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
+        # Attempt to merge only if the decoder was exported without/with past
+        onnx_decoder_path = Path(path, ONNX_DECODER_NAME + ".onnx")
+        onnx_decoder_with_past_path = Path(path, ONNX_DECODER_WITH_PAST_NAME + ".onnx")
+        decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
+        if onnx_decoder_path.is_file() and onnx_decoder_with_past_path.is_file() and self.use_past is True:
             try:
+                from ...onnx import merge_decoders
+
                 # The decoder with past does not output the cross attention past key values as they are constant,
                 # hence the need for strict=False
-                from ...onnx import merge_decoders
-
                 merge_decoders(
-                    decoder=decoder_path,
-                    decoder_with_past=decoder_with_past_path,
+                    decoder=onnx_decoder_path,
+                    decoder_with_past=onnx_decoder_with_past_path,
                     save_path=decoder_merged_path,
                     strict=False,
                 )
```
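The rewrite keys the merge on the presence of the decoder files themselves rather than on positional indices into onnx_files_subpaths, which made it work for seq2seq models used as decoders too. A minimal sketch of the same merge through the public helper (the .onnx file names below follow optimum's conventional decoder naming and stand in for a real export directory):

```python
# Minimal sketch of the merge performed above, via the public helper.
# File names are placeholders following optimum's usual decoder naming.
from optimum.onnx import merge_decoders

merge_decoders(
    decoder="decoder_model.onnx",                      # without-past decoder
    decoder_with_past="decoder_with_past_model.onnx",  # with-past decoder
    save_path="decoder_model_merged.onnx",
    # strict=False because the with-past decoder does not re-emit the
    # constant cross-attention key/value outputs.
    strict=False,
)
```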
41 changes: 13 additions & 28 deletions optimum/exporters/onnx/model_configs.py

```diff
@@ -337,7 +337,7 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):


 # OPT does not take position_ids as input for transfomers < v4.46, needs it for transformers >= v4.46
-if is_transformers_version(">=", "4.45.99"):
+if is_transformers_version(">=", "4.46.0"):

     class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
         DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
@@ -352,7 +352,6 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):

 class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14  # Llama now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
-
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
@@ -371,6 +370,14 @@ class Qwen2OnnxConfig(LlamaOnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.37.0")


+class Qwen3OnnxConfig(LlamaOnnxConfig):
+    MIN_TRANSFORMERS_VERSION = version.parse("4.51.0")
+
+
+class Qwen3MoeOnnxConfig(LlamaOnnxConfig):
+    MIN_TRANSFORMERS_VERSION = version.parse("4.51.0")
+
+
 class GemmaOnnxConfig(LlamaOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
@@ -385,7 +392,7 @@ class GraniteOnnxConfig(LlamaOnnxConfig):
 class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14  # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
-    MIN_TRANSFORMERS_VERSION = version.parse("4.36.0")
+    MIN_TRANSFORMERS_VERSION = version.parse("4.42.0")
@@ -430,33 +437,11 @@ class BloomOnnxConfig(TextDecoderOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (
         BloomDummyPastKeyValuesGenerator,
     ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
-
-    DEFAULT_ONNX_OPSET = 14  # Bloom uses F.scaled_dot_product_attention
+    MIN_TRANSFORMERS_VERSION = version.parse("4.44.0")
     DUMMY_PKV_GENERATOR_CLASS = BloomDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")
-
-    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
-        if is_transformers_version(">=", "4.44"):
-            super().add_past_key_values(inputs_or_outputs, direction)
-        else:
-            if direction not in ["inputs", "outputs"]:
-                raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
-
-            if direction == "inputs":
-                decoder_sequence_name = "past_sequence_length"
-                name = "past_key_values"
-            else:
-                decoder_sequence_name = "past_sequence_length + 1"
-                name = "present"
-
-            for i in range(self._normalized_config.num_layers):
-                inputs_or_outputs[f"{name}.{i}.key"] = {
-                    0: "batch_size x num_heads",
-                    2: decoder_sequence_name,
-                }
-                inputs_or_outputs[f"{name}.{i}.value"] = {
-                    0: "batch_size x num_heads",
-                    1: decoder_sequence_name,
-                }
+    DEFAULT_ONNX_OPSET = 14  # Bloom uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention
```
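Qwen3 and Qwen3-MoE inherit everything from LlamaOnnxConfig and only pin a minimum transformers version, so no new export logic is involved. A quick way to exercise the new configs end to end, as a sketch (the checkpoint ID is only an example):

```python
# Sketch: exporting and running a qwen3-type checkpoint through ONNX Runtime.
# The model ID is an example; export=True forces a fresh ONNX export, as the
# PR's final commit does to fix a failing test.
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

model_id = "Qwen/Qwen3-0.6B"  # example qwen3 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("Hello there,", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```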
4 changes: 3 additions & 1 deletion optimum/exporters/onnx/utils.py

```diff
@@ -85,11 +85,13 @@
     "phi",
     "phi3",
     "qwen2",
+    "qwen3",
+    "qwen3-moe",
     "granite",
 }


-if is_transformers_version(">=", "4.45.99"):
+if is_transformers_version(">=", "4.46.0"):
     MODEL_TYPES_REQUIRING_POSITION_IDS.add("opt")
```
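Both version gates now compare against "4.46.0" rather than the "4.45.99" workaround, matching the transformers release in which OPT started consuming position_ids. A self-contained sketch of the gating pattern (the is_transformers_version function below is a stand-in for optimum's internal helper, which supports more operators):

```python
# Stand-in for optimum's internal is_transformers_version helper, shown only
# to illustrate the gating pattern above.
from packaging import version
import transformers

def is_transformers_version(op: str, ref: str) -> bool:
    assert op == ">="  # only the operator used in the gate above
    return version.parse(transformers.__version__) >= version.parse(ref)

MODEL_TYPES_REQUIRING_POSITION_IDS = {"phi", "phi3", "qwen2", "qwen3", "qwen3-moe", "granite"}

# OPT only takes position_ids from transformers 4.46 onward, so it is added
# to the set conditionally rather than unconditionally.
if is_transformers_version(">=", "4.46.0"):
    MODEL_TYPES_REQUIRING_POSITION_IDS.add("opt")
```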
22 changes: 21 additions & 1 deletion optimum/exporters/tasks.py

```diff
@@ -997,6 +997,23 @@ class TasksManager:
             "token-classification",
             onnx="Qwen2OnnxConfig",
         ),
+        "qwen3": supported_tasks_mapping(
+            "feature-extraction",
+            "feature-extraction-with-past",
+            "text-generation",
+            "text-generation-with-past",
+            "text-classification",
+            onnx="Qwen3OnnxConfig",
+        ),
+        "qwen3-moe": supported_tasks_mapping(
+            "feature-extraction",
+            "feature-extraction-with-past",
+            "text-generation",
+            "text-generation-with-past",
+            "text-classification",
+            "token-classification",
+            onnx="Qwen3MoeOnnxConfig",
+        ),
         "llama": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
@@ -1475,12 +1492,15 @@ def get_supported_model_type_for_task(task: str, exporter: str) -> List[str]:
         """
         Returns the list of supported architectures by the exporter for a given task. Transformers-specific.
         """
-        return [
+
+        supported_model_types = [
             model_type.replace("-", "_")
             for model_type in TasksManager._SUPPORTED_MODEL_TYPE
             if task in TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter]
         ]
+
+        return supported_model_types

     @staticmethod
     def synonyms_for_task(task: str) -> Set[str]:
         synonyms = [k for k, v in TasksManager._SYNONYM_TASK_MAP.items() if v == task]
```
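Unrolling the early return into a named supported_model_types variable changes nothing functionally but leaves a seam for filtering the list later. Calling the helper is unchanged; a sketch, assuming the public TasksManager import path:

```python
# Sketch: asking the ONNX exporter which architectures support a task.
from optimum.exporters.tasks import TasksManager

supported = TasksManager.get_supported_model_type_for_task("text-generation", exporter="onnx")

# After this PR the list should include "qwen3" and "qwen3_moe"; note that
# the helper normalizes hyphens in model types to underscores.
print(sorted(supported))
```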