Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions onnxscript/utils/model_proto_to_function_proto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import onnx
import onnx_ir
from onnx import helper


def _initializers_to_constants(model: onnx.ModelProto) -> onnx.ModelProto:
    """Replace every graph initializer with an equivalent ``Constant`` node.

    The model is modified in place: one ``Constant`` node per initializer is
    prepended to the graph (so each value is defined before any use), graph
    inputs that merely mirrored an initializer are dropped, and the
    initializer list is cleared.

    Args:
        model: The model whose graph is rewritten (mutated in place).

    Returns:
        The same, mutated, model.
    """
    graph = model.graph
    initializer_names = {tensor.name for tensor in graph.initializer}

    # One Constant node per initializer, embedding the TensorProto directly
    # as the node's `value` attribute.
    constant_nodes = [
        helper.make_node("Constant", inputs=[], outputs=[tensor.name], value=tensor)
        for tensor in graph.initializer
    ]

    # Drop graph inputs that only existed to mirror an initializer.
    remaining_inputs = [inp for inp in graph.input if inp.name not in initializer_names]
    del graph.input[:]
    graph.input.extend(remaining_inputs)

    # Prepend the Constant nodes ahead of the original node list.
    original_nodes = list(graph.node)
    del graph.node[:]
    graph.node.extend(constant_nodes)
    graph.node.extend(original_nodes)

    # All initializers are now represented as Constant nodes; remove them.
    del graph.initializer[:]

    return model


def convert_model_proto_to_function_proto(
    model: onnx.ModelProto, domain: str, name: str
) -> onnx.FunctionProto:
    """Converts an arbitrary ModelProto to a FunctionProto.

    FunctionProtos do not support initializers (stored weights do not make
    sense in the context of a function), so the model's initializers are
    first converted to ``Constant`` nodes before the graph is wrapped into a
    function.

    Args:
        model: The model to convert. Note: it is mutated in place by the
            initializer-to-constant rewrite.
        domain: The operator domain to assign to the resulting function.
        name: The name to assign to the resulting function.

    Returns:
        A FunctionProto whose body is equivalent to the model's graph.
    """
    # Fold initializers into Constant nodes; functions cannot carry them.
    model = _initializers_to_constants(model)
    model_ir = onnx_ir.serde.deserialize_model(model)
    function_ir = onnx_ir.Function(
        domain=domain, name=name, graph=model_ir.graph, attributes={}
    )
    return onnx_ir.to_proto(function_ir)
2 changes: 2 additions & 0 deletions tests/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
70 changes: 70 additions & 0 deletions tests/utils/model_proto_to_function_proto_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest

import numpy as np
import onnxruntime as ort

from onnxscript import script
from onnxscript.onnx_opset import opset15 as op
from onnxscript.onnx_types import FLOAT
from onnxscript.utils.model_proto_to_function_proto import (
convert_model_proto_to_function_proto,
)
from onnxscript.values import Opset


class TestModelProtoToFunctionProto(unittest.TestCase):
    """End-to-end tests for convert_model_proto_to_function_proto.

    Builds small onnxscript functions, round-trips one through a ModelProto
    into a FunctionProto, embeds it in a larger model, and verifies the
    result executes correctly under onnxruntime.
    """

    def setUp(self):
        """Set up test fixtures."""
        # Create a fresh custom opset for each test
        self.local = Opset("local", 1)

        # Define test functions
        @script(self.local, default_opset=op)
        def diff_square(x, y):
            # Element-wise squared difference: (x - y) * (x - y).
            diff = x - y
            return diff * diff

        @script(self.local)
        def sum_func(z):
            # No axes argument, so all axes are reduced; keepdims=1 keeps the
            # output rank-1, matching the FLOAT[1] signature of the callers.
            return op.ReduceSum(z, keepdims=1)

        @script()
        def l2norm(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]:  # noqa: F821
            return op.Sqrt(sum_func(diff_square(x, y)))

        # Same body as l2norm; kept separate so tests can attach function
        # protos without affecting the plain l2norm fixture.
        @script()
        def l2norm_with_functions(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]:  # noqa: F821
            return op.Sqrt(sum_func(diff_square(x, y)))

        self.diff_square = diff_square
        self.sum_func = sum_func
        self.l2norm = l2norm
        self.l2norm_with_functions = l2norm_with_functions

    def test_multiple_functions_in_model_proto(self):
        """Test that multiple functions can be included in a single model proto."""
        # Add sum function to opset: round-trip sum_func through a ModelProto
        # into a FunctionProto in the "local" domain.
        sum_model = self.sum_func.to_model_proto()
        sum_function_proto = convert_model_proto_to_function_proto(
            sum_model, "local", "sum_func"
        )

        # Mix a converted FunctionProto with a plain onnxscript function
        # object in the same model's function list.
        model = self.l2norm_with_functions.to_model_proto(
            functions=[sum_function_proto, self.diff_square]
        )

        # Test execution
        session = ort.InferenceSession(model.SerializeToString())
        result = session.run(
            None,
            {
                "x": np.array([1.0, 2.0, 3.0]).astype(np.float32),
                "y": np.array([4.0, 5.0, 6.0]).astype(np.float32),
            },
        )

        # Verify result: sqrt(sum((x - y)^2)) = sqrt(9 + 9 + 9) = sqrt(27).
        self.assertEqual(len(result), 1)
        self.assertAlmostEqual(np.sqrt(27.0), result[0][0], places=5)  # L2 norm of [3, 3, 3]