enforce onnx conversion in CI #3600

Merged
merged 13 commits on Mar 12, 2020
Changes from all commits
7 changes: 7 additions & 0 deletions .circleci/config.yml
@@ -18,12 +18,17 @@ jobs:
pip_constraints:
type: string
description: Constraints file that is passed to "pip install". We constrain older versions of libraries for older Python runtimes, in order to help ensure compatibility.
enforce_onnx_conversion:
type: integer
default: 0
description: Whether to raise an exception if ONNX models couldn't be saved.
executor: << parameters.executor >>
working_directory: ~/repo

# Run additional numpy checks on unit tests
environment:
TEST_ENFORCE_NUMPY_FLOAT32: 1
TEST_ENFORCE_ONNX_CONVERSION: << parameters.enforce_onnx_conversion >>

steps:
- checkout
@@ -217,6 +222,8 @@ workflows:
pyversion: 3.7.3
# Test python 3.7 with the newest supported versions
pip_constraints: test_constraints_max_tf1_version.txt
# Make sure ONNX conversion passes here (recent version of tensorflow 1.x)
enforce_onnx_conversion: 1
- build_python:
name: python_3.7.3+tf2
executor: python373
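To try the same enforcement outside CI, the job environment can be mimicked locally. A rough sketch, not part of this PR — the trainer config path and run id are placeholders, but `TEST_ENFORCE_ONNX_CONVERSION` is the variable the config above sets:

```python
# Illustrative local repro of the CI setting; config path and run id are placeholders.
import os
import subprocess

# CircleCI passes the integer parameter through as the string "1"; with it set,
# a failed ONNX export raises instead of only logging a warning.
env = dict(os.environ, TEST_ENFORCE_ONNX_CONVERSION="1")

subprocess.run(
    ["mlagents-learn", "config/trainer_config.yaml", "--run-id=onnx-check", "--train"],
    env=env,
    check=True,
)
```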
2 changes: 1 addition & 1 deletion docs/Unity-Inference-Engine.md
@@ -33,7 +33,7 @@ There are currently two supported model formats:
* ONNX (`.onnx`) files use an [industry-standard open format](https://onnx.ai/about.html) produced by the [tf2onnx package](https://github.com/onnx/tensorflow-onnx).

Export to ONNX is currently considered beta. To enable it, make sure `tf2onnx>=1.5.5` is installed in pip.
tf2onnx does not currently support tensorflow 2.0.0 or later.
tf2onnx does not currently support tensorflow 2.0.0 or later, or earlier than 1.12.0.

## Using the Unity Inference Engine

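As a rough companion to the documented support window (not part of this PR), a pre-flight check along these lines reports whether the installed versions fall inside it; the helper name is made up:

```python
# Hypothetical pre-flight check for the documented window:
# tf2onnx >= 1.5.5 and 1.12.0 <= tensorflow < 2.0.0.
from distutils.version import LooseVersion

import tensorflow as tf
import tf2onnx


def onnx_export_supported() -> bool:
    tf_version = LooseVersion(tf.__version__)
    tf_ok = LooseVersion("1.12.0") <= tf_version < LooseVersion("2.0.0")
    tf2onnx_ok = LooseVersion(tf2onnx.__version__) >= LooseVersion("1.5.5")
    return tf_ok and tf2onnx_ok
```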
55 changes: 43 additions & 12 deletions ml-agents/mlagents/model_serialization.py
@@ -1,5 +1,8 @@
from distutils.util import strtobool
import os
import logging
from typing import Any, List, Set, NamedTuple
from distutils.version import LooseVersion

try:
import onnx
@@ -18,6 +21,11 @@
from tensorflow.python.framework import graph_util
from mlagents.trainers import tensorflow_to_barracuda as tf2bc

if LooseVersion(tf.__version__) < LooseVersion("1.12.0"):
# ONNX is only tested on 1.12.0 and later
ONNX_EXPORT_ENABLED = False
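The import guard that defines ONNX_EXPORT_ENABLED in the first place is collapsed in the hunk above; roughly, the combined pattern looks like this (a paraphrase, not the module's exact code):

```python
# Paraphrase: ONNX export is attempted only when the optional dependencies
# import cleanly AND the installed tensorflow is at least 1.12.0.
from distutils.version import LooseVersion

import tensorflow as tf

try:
    import onnx  # noqa: F401
    import tf2onnx  # noqa: F401

    ONNX_EXPORT_ENABLED = True
except ImportError:
    ONNX_EXPORT_ENABLED = False

if LooseVersion(tf.__version__) < LooseVersion("1.12.0"):
    # ONNX is only tested on 1.12.0 and later
    ONNX_EXPORT_ENABLED = False
```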


logger = logging.getLogger("mlagents.trainers")

POSSIBLE_INPUT_NODES = frozenset(
@@ -67,18 +75,28 @@ def export_policy_model(
logger.info(f"Exported {settings.model_path}.nn file")

# Save to onnx too (if we were able to import it)
if ONNX_EXPORT_ENABLED and settings.convert_to_onnx:
try:
onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def)
onnx_output_path = settings.model_path + ".onnx"
with open(onnx_output_path, "wb") as f:
f.write(onnx_graph.SerializeToString())
logger.info(f"Converting to {onnx_output_path}")
except Exception:
logger.exception(
"Exception trying to save ONNX graph. Please report this error on "
"https://github.com/Unity-Technologies/ml-agents/issues and "
"attach a copy of frozen_graph_def.pb"
if ONNX_EXPORT_ENABLED:
if settings.convert_to_onnx:
try:
onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def)
onnx_output_path = settings.model_path + ".onnx"
with open(onnx_output_path, "wb") as f:
f.write(onnx_graph.SerializeToString())
logger.info(f"Converting to {onnx_output_path}")
except Exception:
# Make conversion errors fatal depending on environment variables (only done during CI)
if _enforce_onnx_conversion():
raise
logger.exception(
"Exception trying to save ONNX graph. Please report this error on "
"https://github.com/Unity-Technologies/ml-agents/issues and "
"attach a copy of frozen_graph_def.pb"
)

else:
if _enforce_onnx_conversion():
raise RuntimeError(
"ONNX conversion enforced, but couldn't import dependencies."
)
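For reference, the new branching reduces to a small decision table; the helper below is purely illustrative and is not part of this PR:

```python
# Compact restatement of the control flow above; names are illustrative.
def onnx_export_outcome(export_enabled: bool, convert_requested: bool,
                        conversion_failed: bool, enforced: bool) -> str:
    if not export_enabled:
        # Dependencies missing: fatal only when enforcement is on.
        return "raise RuntimeError" if enforced else "skip silently"
    if not convert_requested:
        return "skip silently"
    if conversion_failed:
        # Conversion errors become fatal in CI (enforced); otherwise just logged.
        return "re-raise" if enforced else "log exception"
    return "write .onnx file"
```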


@@ -203,3 +221,16 @@ def _process_graph(settings: SerializationSettings, graph: tf.Graph) -> List[str
for n in nodes:
logger.info("\t" + n)
return nodes


def _enforce_onnx_conversion() -> bool:
env_var_name = "TEST_ENFORCE_ONNX_CONVERSION"
if env_var_name not in os.environ:
return False

val = os.environ[env_var_name]
try:
# This handles e.g. "false" converting reasonably to False
return strtobool(val)
except Exception:
return False
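A quick illustration of how the helper interprets different values (assuming the import path matches the file location above; not part of this PR):

```python
# strtobool accepts y/yes/t/true/on/1 and n/no/f/false/off/0 (case-insensitive)
# and raises on anything else, which the helper turns into False.
import os

from mlagents.model_serialization import _enforce_onnx_conversion

for raw in ("1", "true", "False", "0", "not-a-bool"):
    os.environ["TEST_ENFORCE_ONNX_CONVERSION"] = raw
    print(raw, "->", bool(_enforce_onnx_conversion()))
# 1 -> True, true -> True, False -> False, 0 -> False, not-a-bool -> False
```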
3 changes: 1 addition & 2 deletions test_constraints_max_tf1_version.txt
@@ -3,6 +3,5 @@
# For projects with upper bounds, we should periodically update this list to the latest release version
grpcio>=1.23.0
numpy>=1.17.2
# Temporary workaround for https://github.com/tensorflow/tensorflow/issues/36179 and https://github.com/tensorflow/tensorflow/issues/36188
tensorflow>=1.14.0,<1.15.1
tensorflow>=1.15.2,<2.0.0
h5py>=2.10.0
2 changes: 1 addition & 1 deletion test_constraints_min_version.txt
@@ -3,5 +3,5 @@ grpcio==1.11.0
numpy==1.14.1
Pillow==4.2.1
protobuf==3.6
tensorflow==1.7
tensorflow==1.7.0
h5py==2.9.0
4 changes: 1 addition & 3 deletions test_requirements.txt
@@ -3,6 +3,4 @@ pytest>4.0.0,<6.0.0
pytest-cov==2.6.1
pytest-xdist

# Tests install onnx and tf2onnx, but this doesn't support tensorflow>=2.0.0
# Since we test tensorflow2.0 with python3.7, exclude it based on the python version
tf2onnx>=1.5.5; python_version < '3.7'
tf2onnx>=1.5.5