Check TF ops for ONNX compliance #10025
@@ -16,6 +16,7 @@
import copy
import inspect
import json
import os
import random
import tempfile
@@ -24,7 +25,7 @@
from typing import List, Tuple

from transformers import is_tf_available
-from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_tf, slow
+from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_onnx, require_tf, slow


if is_tf_available():
@@ -201,6 +202,67 @@ def test_saved_model_creation(self):
            saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
            self.assertTrue(os.path.exists(saved_model_dir))

    def test_onnx_compliancy(self):
        if not self.test_onnx:
            return

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        INTERNAL_OPS = [
            "Assert",
            "AssignVariableOp",
            "EmptyTensorList",
            "ReadVariableOp",
            "ResourceGather",
            "TruncatedNormal",
            "VarHandleOp",
            "VarIsInitializedOp",
        ]
        onnx_ops = []

        with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f:
            onnx_opsets = json.load(f)["opsets"]

        for i in range(1, self.onnx_min_opset + 1):
            onnx_ops.extend(onnx_opsets[str(i)])

        for model_class in self.all_model_classes:
            model_op_names = set()

            with tf.Graph().as_default() as g:
Review thread:
Comment: Isn't it possible to reuse …
Reply: The list is not the same. The list in …
                model = model_class(config)
                model(model.dummy_inputs)
Review thread on model(model.dummy_inputs):
Comment: Again, not familiar with the TF way of working, but in PT the actual inputs do change the traced graph quite a bit. That means that … What I'm trying to say is that this test will probably check that the ops used in TF are valid for some ONNX opset; it does not by any means check that it can/will export the best production-ready graph. And the real hot path in production is almost always decoder-only with …
Reply: Here we are not testing the graph, we are loading the entire list of operators; the graph here is not optimized. To give you an example, for BERT this test loads more than 5000 operators, while the optimised graph for inference only has around 1200 nodes. The role of this test is just to be sure that the entire list of used operators is inside the list proposed here: https://github.com/onnx/tensorflow-onnx/blob/master/support_status.md
Comment: Yes I know, I was just emphasizing it. unoptimized small graph > optimized big graph
Reply: I think there is a misunderstanding: this test is only here to say "this TF op is also implemented in ONNX", nothing more, and not to test whether the optimized ONNX graph will work as expected or not. If you and Morgan prefer, I can add a slow test that will run the pipeline: …
Comment: There is no misunderstanding, I was trying to say what you just said.
Reply: Ok, so if you are trying to say the same thing there is no problem then ^^
                for op in g.get_operations():
                    model_op_names.add(op.node_def.op)

            model_op_names = sorted(model_op_names)
            incompatible_ops = []

            for op in model_op_names:
                if op not in onnx_ops and op not in INTERNAL_OPS:
                    incompatible_ops.append(op)

            self.assertEqual(len(incompatible_ops), 0, incompatible_ops)

    @require_onnx
    @slow
    def test_onnx_runtime_optimize(self):
        if not self.test_onnx:
            return

        import keras2onnx
        import onnxruntime

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model(model.dummy_inputs)

            onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)

            onnxruntime.InferenceSession(onnx_model.SerializeToString())

    @slow
    def test_saved_model_creation_extended(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
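As a companion to the "run the pipeline" discussion above, here is a minimal sketch of the convert-and-run flow that test_onnx_runtime_optimize exercises. It uses a toy Keras model rather than a transformers model so the input feed stays simple, and target_opset=10 is only an illustrative stand-in for self.onnx_min_opset; the test itself stops after building the InferenceSession, while this sketch also runs one inference.

import numpy as np
import tensorflow as tf
import keras2onnx
import onnxruntime

# Toy stand-in for the model classes exercised by the test.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])

# Convert the Keras graph to an ONNX ModelProto, as the test does.
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=10)

# The test only checks that a session can be built; running one inference
# additionally verifies that ONNX Runtime can execute the exported graph.
session = onnxruntime.InferenceSession(onnx_model.SerializeToString())
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: np.zeros((1, 8), dtype=np.float32)})
print(outputs[0].shape)  # expected: (1, 4)

Because the test is decorated with @require_onnx and @slow, it only runs when keras2onnx and onnxruntime are installed and slow tests are enabled.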
Review thread on the onnx.json dependency:
Comment: Why do you want to depend on an external file within a test? Doesn't it make sense to include that directly as a Python dict? Just feels simpler.
Reply: This is easier to maintain than a dict. Also, this list should be shared with the check script.
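Regarding the external file discussed above, here is a rough sketch of the structure utils/tf_ops/onnx.json appears to have, inferred only from how test_onnx_compliancy reads it (an "opsets" object mapping opset-version strings to lists of TF op names). The op names and the opset value below are illustrative placeholders, not the real file contents.

onnx_support = {
    "opsets": {
        # each key is an ONNX opset version; each value lists the TF ops
        # convertible once that opset is available (example names only)
        "1": ["Abs", "Add", "BiasAdd", "Cast", "ConcatV2", "MatMul"],
        "2": ["Pad"],
        "3": [],
    }
}

# Mirrors the accumulation loop in the test: collect every op supported up to
# the model's minimum opset (assumed to be 3 for this sketch).
onnx_min_opset = 3
onnx_ops = []
for i in range(1, onnx_min_opset + 1):
    onnx_ops.extend(onnx_support["opsets"][str(i)])

Keeping the list in a single JSON file, as argued above, lets both this test and the check script load the same data.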