Adding support for integrating arbitrary ModelProto in onnxscript #2616

@chapman73

Hey team,

I'm looking to submit my attempt at resolving this issue: #1882

I'm opening this as an issue first because I think the solution may be quite large and involves various design decisions which (IMO) are non-trivial. I'm also happy to start with a large PR (so there's a bird's-eye view) and then determine which parts are appropriate and which are not.


What (more or less) works -- Updating docs? Adding a util function?

The docs more or less address how this is achieved here: https://microsoft.github.io/onnxscript/auto_examples/06_plot_model_local_funs.html

Example code that works:

import numpy as np
import onnxruntime
import pandas as pd
from onnxscript import FLOAT, script
from onnxscript import opset17 as op
from onnxscript.values import Opset
from skl2onnx import to_onnx
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


pandas_dataset = pd.DataFrame(
    {
        "age": np.random.randint(18, 60, 100),
        "height": np.random.rand(100) * 100 + 100,
        "weight": np.random.rand(100) * 100 + 40,
    }
)

multi_classifier_target = pd.Series(np.random.choice([0, 1, 2, 3], 100))
# Create the pipeline
pipe = Pipeline([("scaler", StandardScaler()), ("linear", LogisticRegression())])

# Fit the pipeline to the data
pipe.fit(pandas_dataset, multi_classifier_target)

# local = CustomOpset("local", 1)  # see below in the next section
local = Opset("local", 1)

X = pandas_dataset.copy()
scaler_step = pipe.steps[0][1]
linear_step = pipe.steps[1][1]

scaler_model = to_onnx(scaler_step, X.values.astype(np.float32), target_opset={"ai.onnx.ml": 2, "": 17})
linear_model = to_onnx(
    linear_step, scaler_step.transform(X.values.astype(np.float32)), options={"zipmap": False}, target_opset={"ai.onnx.ml": 2, "": 17}
)


# custom helper that converts a ModelProto to a FunctionProto (not part of onnxscript; definition omitted here)
scaler_function_proto = convert_model_proto_to_function_proto(scaler_model, "local", "scaler")
linear_function_proto = convert_model_proto_to_function_proto(linear_model, "local", "linear")


@script(local, default_opset=op)
def sklearn_pipeline(X: FLOAT["N", "D"]) -> FLOAT["N"]:  # noqa: F821, D103
    # script has limitations like doing loops
    X = local.scaler(X)
    # outputs are labels, probabilities
    _, y = local.linear(X)
    return y[:, 1]


model = sklearn_pipeline.to_model_proto(functions=[scaler_function_proto, linear_function_proto])

session = onnxruntime.InferenceSession(model.SerializeToString())
result = session.run(None, {"X": X.values.astype(np.float32)})

print(result)

This works!

Q: does convert_model_proto_to_function_proto belong in onnxscript project? And if so where? (onnxscript/utils or onnxscript/contrib?)

It turns out that it's slightly non-trivial to convert a ModelProto to a FunctionProto if the ModelProto has initializers, since a FunctionProto has no field for initializers. A helper function for this would help users (if/when docs are added).


What doesn't work out of the box

Eager evaluation does not work. This is because the (arbitrary) FunctionProtos are not 'attached' to the custom Opset.

I'm eager to hear how you would like this implemented. Here is how I tackled it (leaving out lots of code for now):

class CustomOpset(Opset):
    def add_model_proto(self, name: str, model_proto: onnx.ModelProto):
        # very, very rough code...
        function_proto = convert_model_proto_to_function_proto(model_proto, self.domain, name)
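        # magically create an OnnxFunction somehow; dummy_function (a Python callable)
        # and ir_builder_function (an IR function built from function_proto) are
        # placeholders whose construction is not shown here
        onnx_function = OnnxFunction(
            opset=self,
            pyfun=dummy_function,
            irfun=ir_builder_function,
            source="",  # No source available for FunctionProto
            kwargs={},
        )
        # register it under the given name so that e.g. local.scaler resolves
        self.function_defs[name] = onnx_function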

Though perhaps the team has other preferences for how this might be implemented (somewhere in the evaluator)?


Looking forward to contributing and hearing from the team!

(n.b. I'm waiting for my company to review and sign the CLA, and then I will put up the relevant code snippets).