Skip to content

Commit

Permalink
Added support for ONNX flavor (mlflow#1127)
Browse files Browse the repository at this point in the history
* Added support for ONNX flavor

* Tweaks to conform with other flavor conventions

* Lint

* Add flavor description to Models Docs

* Experimental tags

* Mark experimental and docs tweaks

* Revert R changes

* Revert spacing changes to Travis yaml

* Remove unused arg from get_conda_env

* Add test for casting evaluation behavior

* Lint

* Run tests on xenial to see if onnx import errors remain

* Revert travis to trusty and mark onnxruntime-dependent tests as release tests
  • Loading branch information
avflor authored and dbczumar committed Jun 3, 2019
1 parent d3cb5c9 commit 1d0d3e6
Show file tree
Hide file tree
Showing 7 changed files with 652 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ matrix:
- pip install -r travis/small-requirements.txt
- pip install -e .
script:
- pytest --verbose --ignore=tests/h2o --ignore=tests/keras --ignore=tests/pytorch --ignore=tests/pyfunc --ignore=tests/sagemaker --ignore=tests/sklearn --ignore=tests/spark --ignore=tests/tensorflow --ignore tests/azureml --ignore tests/projects tests
- pytest --verbose --ignore=tests/h2o --ignore=tests/keras --ignore=tests/pytorch --ignore=tests/pyfunc --ignore=tests/sagemaker --ignore=tests/sklearn --ignore=tests/spark --ignore=tests/tensorflow --ignore tests/azureml --ignore tests/onnx --ignore tests/projects tests
- language: r
name: "R"
cache: packages
Expand Down
13 changes: 13 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,19 @@ flavor as TensorFlow graphs.

For more information, see :py:mod:`mlflow.tensorflow`.

ONNX (``onnx``)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
The ``onnx`` model flavor enables logging of `ONNX models <http://onnx.ai/>`_ in MLflow format via
the :py:func:`mlflow.onnx.save_model()` and :py:func:`mlflow.onnx.log_model()` methods. These
methods also add the ``python_function`` flavor to the MLflow Models that they produce, allowing the
models to be interpreted as generic Python functions for inference via
:py:func:`mlflow.pyfunc.load_pyfunc()`. The ``python_function`` representation of an MLflow
ONNX model uses the `ONNX Runtime execution engine <https://github.com/microsoft/onnxruntime>`_ for
evaluation Finally, you can use the :py:func:`mlflow.onnx.load_model()` method to load MLflow
Models with the ``onnx`` flavor in native ONNX format.

For more information, see :py:mod:`mlflow.onnx` and `<http://onnx.ai/>`_.

Model Customization
-------------------

Expand Down
227 changes: 227 additions & 0 deletions mlflow/onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
from __future__ import absolute_import

import os
import yaml
import numpy as np

import pandas as pd

from mlflow import pyfunc
from mlflow.models import Model
import mlflow.tracking
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils import experimental
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.model_utils import _get_flavor_configuration

FLAVOR_NAME = "onnx"


@experimental
def get_default_conda_env():
"""
:return: The default Conda environment for MLflow Models produced by calls to
:func:`save_model()` and :func:`log_model()`.
"""
import onnx
import onnxruntime
return _mlflow_conda_env(
additional_conda_deps=None,
additional_pip_deps=[
"onnx=={}".format(onnx.__version__),
# The ONNX pyfunc representation requires the OnnxRuntime
# inference engine. Therefore, the conda environment must
# include OnnxRuntime
"onnxruntime=={}".format(onnxruntime.__version__),
],
additional_conda_channels=None,
)


@experimental
def save_model(onnx_model, path, conda_env=None, mlflow_model=Model()):
"""
Save an ONNX model to a path on the local file system.
:param onnx_model: ONNX model to be saved.
:param path: Local path where the model is to be saved.
:param conda_env: Either a dictionary representation of a Conda environment or the path to a
Conda environment yaml file. If provided, this decribes the environment
this model should be run in. At minimum, it should specify the dependencies
contained in :func:`get_default_conda_env()`. If `None`, the default
:func:`get_default_conda_env()` environment is added to the model.
The following is an *example* dictionary representation of a Conda
environment::
{
'name': 'mlflow-env',
'channels': ['defaults'],
'dependencies': [
'python=3.6.0',
'onnx=1.4.1',
'onnxruntime=0.3.0'
]
}
:param mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
"""
import onnx

path = os.path.abspath(path)
if os.path.exists(path):
raise MlflowException(
message="Path '{}' already exists".format(path),
error_code=RESOURCE_ALREADY_EXISTS)
os.makedirs(path)
model_data_subpath = "model.onnx"
model_data_path = os.path.join(path, model_data_subpath)

# Save onnx-model
onnx.save_model(onnx_model, model_data_path)

conda_env_subpath = "conda.yaml"
if conda_env is None:
conda_env = get_default_conda_env()
elif not isinstance(conda_env, dict):
with open(conda_env, "r") as f:
conda_env = yaml.safe_load(f)
with open(os.path.join(path, conda_env_subpath), "w") as f:
yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

pyfunc.add_to_model(mlflow_model, loader_module="mlflow.onnx",
data=model_data_subpath, env=conda_env_subpath)
mlflow_model.add_flavor(FLAVOR_NAME, onnx_version=onnx.__version__, data=model_data_subpath)
mlflow_model.save(os.path.join(path, "MLmodel"))


def _load_model(model_file):
import onnx

onnx_model = onnx.load(model_file)
# Check Formation
onnx.checker.check_model(onnx_model)
return onnx_model


class _OnnxModelWrapper:
def __init__(self, path):
import onnxruntime
self.rt = onnxruntime.InferenceSession(path)
assert len(self.rt.get_inputs()) >= 1
self.inputs = [
(inp.name, inp.type) for inp in self.rt.get_inputs()
]
self.output_names = [
outp.name for outp in self.rt.get_outputs()
]

@staticmethod
def _cast_float64_to_float32(dataframe, column_names):
for input_name in column_names:
if dataframe[input_name].values.dtype == np.float64:
dataframe[input_name] = dataframe[input_name].values.astype(np.float32)
return dataframe

@experimental
def predict(self, dataframe):
"""
:param dataframe: A Pandas DataFrame that is converted to a collection of ONNX Runtime
inputs. If the underlying ONNX model only defines a *single* input
tensor, the DataFrame's values are converted to a NumPy array
representation using the `DataFrame.values()
<https://pandas.pydata.org/pandas-docs/stable/reference/api/
pandas.DataFrame.values.html#pandas.DataFrame.values>`_ method. If the
underlying ONNX model defines *multiple* input tensors, each column
of the DataFrame is converted to a NumPy array representation.
The corresponding NumPy array representation is then passed to the
ONNX Runtime. For more information about the ONNX Runtime, see
`<https://github.com/microsoft/onnxruntime>`_.
:return: A Pandas DataFrame output. Each column of the DataFrame corresponds to an
output tensor produced by the underlying ONNX model.
"""
# ONNXRuntime throws the following exception for some operators when the input
# dataframe contains float64 values. Unfortunately, even if the original user-supplied
# dataframe did not contain float64 values, the serialization/deserialization between the
# client and the scoring server can introduce 64-bit floats. This is being tracked in
# https://github.com/mlflow/mlflow/issues/1286. Meanwhile, we explicitly cast the input to
# 32-bit floats when needed. TODO: Remove explicit casting when issue #1286 is fixed.
if len(self.inputs) > 1:
cols = [name for (name, type) in self.inputs if type == 'tensor(float)']
else:
cols = dataframe.columns if self.inputs[0][1] == 'tensor(float)' else []

dataframe = _OnnxModelWrapper._cast_float64_to_float32(dataframe, cols)
if len(self.inputs) > 1:
feed_dict = {
name: dataframe[name].values
for (name, _) in self.inputs
}
else:
feed_dict = {self.inputs[0][0]: dataframe.values}

predicted = self.rt.run(self.output_names, feed_dict)
return pd.DataFrame.from_dict(
{c: p.reshape(-1) for (c, p) in zip(self.output_names, predicted)})


def _load_pyfunc(path):
"""
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
"""
return _OnnxModelWrapper(path)


@experimental
def load_model(model_uri):
"""
Load an ONNX model from a local file (if ``run_id`` is None) or a run.
:param model_uri: The location, in URI format, of the MLflow model, for example:
- ``/Users/me/path/to/local/model``
- ``relative/path/to/local/model``
- ``s3://my_bucket/path/to/model``
- ``runs:/<mlflow_run_id>/run-relative/path/to/model``
For more information about supported URI schemes, see the
`Artifacts Documentation <https://www.mlflow.org/docs/latest/tracking.html#
supported-artifact-stores>`_.
:return: An ONNX model instance.
"""
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
onnx_model_artifacts_path = os.path.join(local_model_path, flavor_conf["data"])
return _load_model(model_file=onnx_model_artifacts_path)


@experimental
def log_model(onnx_model, artifact_path, conda_env=None):
"""
Log an ONNX model as an MLflow artifact for the current run.
:param onnx_model: ONNX model to be saved.
:param artifact_path: Run-relative artifact path.
:param conda_env: Either a dictionary representation of a Conda environment or the path to a
Conda environment yaml file. If provided, this decribes the environment
this model should be run in. At minimum, it should specify the dependencies
contained in :func:`get_default_conda_env()`. If `None`, the default
:func:`get_default_conda_env()` environment is added to the model.
The following is an *example* dictionary representation of a Conda
environment::
{
'name': 'mlflow-env',
'channels': ['defaults'],
'dependencies': [
'python=3.6.0',
'onnx=1.4.1',
'onnxruntime=0.3.0'
]
}
"""
Model.log(artifact_path=artifact_path, flavor=mlflow.onnx,
onnx_model=onnx_model, conda_env=conda_env)
Loading

0 comments on commit 1d0d3e6

Please sign in to comment.