Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 52 additions & 15 deletions tests/run_pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
import tempfile
import time
import zipfile
import random
from collections import namedtuple
from distutils.version import LooseVersion


import yaml
import numpy as np
import PIL.Image
Expand All @@ -38,7 +40,7 @@
# not needed for tf-2.0
pass

from tf2onnx import tf_loader, logging, optimizer, utils, tf_utils
from tf2onnx import tf_loader, logging, optimizer, utils, tf_utils, constants
from tf2onnx.tfonnx import process_tf_graph
from tf2onnx.tf_loader import tf_session, tf_reset_default_graph
from tf2onnx.graph import ExternalTensorStorage
Expand All @@ -62,11 +64,13 @@ def get_beach(shape):

def get_random(shape):
"""Get random input."""
np.random.seed(42)
return np.random.sample(shape).astype(np.float32)


def get_random256(shape):
"""Get random imput between 0 and 255."""
np.random.seed(42)
return np.round(np.random.sample(shape) * 256).astype(np.float32)


Expand Down Expand Up @@ -98,6 +102,7 @@ def get_ones_int32(shape):

def get_small_rand_int32(shape):
"""Get random ints in range [1, 99]"""
np.random.seed(42)
return np.random.randint(low=1, high=100, size=shape, dtype=np.int32)

def get_zeros_then_ones(shape):
Expand All @@ -111,6 +116,15 @@ def get_wav(shape):
"""Get sound data."""
return np.sin(np.linspace(-np.pi, np.pi, shape[0]), dtype=np.float32)

def get_sentences(shape):
"""Get sentences of shape"""
words = "the quick brown fox jumps over a lazy dog".split(' ')
random.seed(42)
def get_sentence():
length = random.randint(2, 7)
return ' '.join(random.choice(words) for _ in range(length))
return np.array([get_sentence() for _ in range(np.product(shape))]).reshape(shape)


_INPUT_FUNC_MAPPING = {
"get_beach": get_beach,
Expand All @@ -124,7 +138,8 @@ def get_wav(shape):
"get_zeros_int64": get_zeros_int64,
"get_ones_int32": get_ones_int32,
"get_small_rand_int32": get_small_rand_int32,
"get_zeros_then_ones": get_zeros_then_ones
"get_zeros_then_ones": get_zeros_then_ones,
"get_sentences": get_sentences,
}


Expand All @@ -142,14 +157,18 @@ def __init__(self, url, local, input_func, input_names, output_names,
check_only_shape=False, model_type="frozen", force_input_shape=False,
skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None,
skip_conversion=False, converted_model=None, signature_def=None, concrete_function=None,
large_model=False, structured_outputs=None):
large_model=False, structured_outputs=None, run_tf_frozen=None, use_custom_ops=False):
self.url = url
self.input_func = input_func
self.local = local
self.input_names = input_names
self.output_names = output_names
self.disabled = disabled
self.large_model = large_model
self.use_custom_ops = use_custom_ops
if run_tf_frozen is None:
run_tf_frozen = not self.large_model
self.run_tf_frozen = run_tf_frozen
self.structured_outputs = structured_outputs # Needed to determine output order for tf_function
self.rtol = rtol
self.atol = atol
Expand Down Expand Up @@ -242,12 +261,17 @@ def run_tensorflow(self, sess, inputs):
return result

def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
const_node_values=None):
const_node_values=None, initialized_tables=None):
"""Convert graph to tensorflow."""
if extra_opset is None:
extra_opset = []
if self.use_custom_ops:
extra_opset.append(utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1))
return process_tf_graph(tf_graph, continue_on_error=False, opset=opset,
extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
input_names=input_names, output_names=self.output_names,
const_node_values=const_node_values)
const_node_values=const_node_values,
initialized_tables=initialized_tables)

def run_caffe2(self, name, model_proto, inputs):
"""Run test again caffe2 backend."""
Expand All @@ -268,7 +292,13 @@ def run_onnxruntime(self, name, model_proto, inputs, external_tensor_storage=Non
as_text=utils.is_debug_mode(),
external_tensor_storage=external_tensor_storage)
logger.info("Model saved to %s", model_path)
m = rt.InferenceSession(model_path)
if self.use_custom_ops:
from ortcustomops import get_library_path
opt = rt.SessionOptions()
opt.register_custom_ops_library(get_library_path())
m = rt.InferenceSession(model_path, opt)
else:
m = rt.InferenceSession(model_path)
results = m.run(self.output_names, inputs)
if self.perf:
start = time.time()
Expand Down Expand Up @@ -303,19 +333,21 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops

logger.info("Load model from %s", model_path)
input_names = list(self.input_names.keys())
initialized_tables = {}
outputs = self.output_names
if self.model_type in ["checkpoint"]:
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
elif self.model_type in ["saved_model"]:
loaded = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag, self.signatures,
self.concrete_function, self.large_model,
return_concrete_func=self.large_model)
if self.large_model:
return_concrete_func=not self.run_tf_frozen,
return_initialized_tables=True)
if not self.run_tf_frozen:
# Must maintain ref to imported since concrete_func uses weak refs
# pylint: disable=unused-variable
graph_def, input_names, outputs, concrete_func, imported = loaded
graph_def, input_names, outputs, concrete_func, imported, initialized_tables = loaded
else:
graph_def, input_names, outputs = loaded
graph_def, input_names, outputs, initialized_tables = loaded
elif self.model_type in ["keras"]:
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
else:
Expand All @@ -324,7 +356,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
if utils.is_debug_mode():
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)

if self.large_model:
if not self.run_tf_frozen:
inputs = {}
for k in input_names:
v = self.input_names[k]
Expand Down Expand Up @@ -368,7 +400,10 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
np_value.dtype)
inputs[k] = np_value.astype(expected_dtype)
else:
inputs[k] = self.make_input(v).astype(expected_dtype)
if expected_dtype == "string":
inputs[k] = self.make_input(v).astype(np.str).astype(np.object)
else:
inputs[k] = self.make_input(v).astype(expected_dtype)

if self.force_input_shape:
for k, v in inputs.items():
Expand All @@ -377,7 +412,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
# run the model with tensorflow
if self.skip_tensorflow:
logger.info("TensorFlow SKIPPED")
elif not self.large_model:
elif self.run_tf_frozen:
tf_results = self.run_tensorflow(sess, inputs)
logger.info("TensorFlow OK")

Expand All @@ -395,7 +430,8 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
# convert model to onnx
onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
shape_override=shape_override, input_names=inputs.keys(),
const_node_values=const_node_values)
const_node_values=const_node_values,
initialized_tables=initialized_tables)
onnx_graph = optimizer.optimize_graph(onnx_graph)
print("ONNX", onnx_graph.dump_node_statistics())
external_tensor_storage = ExternalTensorStorage() if self.large_model else None
Expand Down Expand Up @@ -559,7 +595,8 @@ def load_tests_from_yaml(path):
kwargs = {}
for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type", "concrete_function",
"skip_tensorflow", "force_input_shape", "tf_min_version", "tag", "skip_conversion",
"converted_model", "signature_def", "large_model", "structured_outputs"]:
"converted_model", "signature_def", "large_model", "structured_outputs", "run_tf_frozen",
"use_custom_ops"]:
if settings.get(kw) is not None:
kwargs[kw] = settings[kw]

Expand Down