Check TF ops for ONNX compliance #10025

Merged (18 commits) on Feb 15, 2021
Changes from all commits
14 changes: 14 additions & 0 deletions src/transformers/file_utils.py
@@ -147,6 +147,16 @@
    _faiss_available = False


_onnx_available = (
    importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None
)
try:
    _onnx_version = importlib_metadata.version("onnx")
    logger.debug(f"Successfully imported onnx version {_onnx_version}")
except importlib_metadata.PackageNotFoundError:
    _onnx_available = False


_scatter_available = importlib.util.find_spec("torch_scatter") is not None
try:
    _scatter_version = importlib_metadata.version("torch_scatter")
@@ -226,6 +236,10 @@ def is_tf_available():
    return _tf_available


def is_onnx_available():
    return _onnx_available


def is_flax_available():
    return _flax_available

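A minimal usage sketch (not part of the diff): downstream code can guard ONNX-dependent imports behind the new helper, in the same way is_tf_available() is used elsewhere in the library. The snippet itself is illustrative.

from transformers.file_utils import is_onnx_available

if is_onnx_available():
    # keras2onnx and onnxruntime were found, and the onnx package is installed.
    import keras2onnx
    import onnxruntime
else:
    print("Install keras2onnx, onnxruntime and onnx to enable the ONNX checks.")
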
19 changes: 3 additions & 16 deletions src/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -1030,16 +1030,7 @@ def call(
                    )
                    - 1
                )

-                def get_seq_element(sequence_position, input_batch):
-                    return tf.strided_slice(
-                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
-                    )
-
-                result = tf.map_fn(
-                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
-                )
-                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
+                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
            else:
                sequence_lengths = -1
                logger.warning(
@@ -1049,16 +1040,12 @@ def get_seq_element(sequence_position, input_batch):
        loss = None

        if inputs["labels"] is not None:
-            if input_ids is not None:
-                batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
-            else:
-                batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
            assert (
-                self.config.pad_token_id is not None or batch_size == 1
+                self.config.pad_token_id is not None or logits_shape[0] == 1
            ), "Cannot handle batch sizes > 1 if no padding token is defined."

            if not tf.is_tensor(sequence_lengths):
-                in_logits = logits[0:batch_size, sequence_lengths]
+                in_logits = logits[0 : logits_shape[0], sequence_lengths]

            loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
        pooled_logits = in_logits if in_logits is not None else logits
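
A quick sketch (illustrative shapes, not from the PR) of why the single tf.gather call with batch_dims=1 replaces the removed strided-slice/map_fn loop: both select, for each batch row, the logits at that row's last non-padding position.

import tensorflow as tf

logits = tf.random.normal([2, 5, 3])    # [batch_size, seq_len, num_labels]
sequence_lengths = tf.constant([4, 2])  # last non-padding index per row

# One vectorized gather instead of tf.map_fn over tf.strided_slice:
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)  # shape [2, 3]

# Sanity check against plain indexing:
expected = tf.stack([logits[0, 4], logits[1, 2]])
tf.debugging.assert_near(in_logits, expected)
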
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -28,6 +28,7 @@
    is_datasets_available,
    is_faiss_available,
    is_flax_available,
    is_onnx_available,
    is_pandas_available,
    is_scatter_available,
    is_sentencepiece_available,
@@ -160,6 +161,13 @@ def require_git_lfs(test_case):
    return test_case


def require_onnx(test_case):
    """
    Decorator marking a test that requires ONNX (keras2onnx and onnxruntime).
    """
    if not is_onnx_available():
        return unittest.skip("test requires ONNX")(test_case)
    else:
        return test_case
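
A hypothetical usage sketch (the test class below is invented for illustration): like the other require_* helpers in this file, the decorator turns the test into a skip when keras2onnx or onnxruntime is not installed.

import unittest

from transformers.testing_utils import require_onnx


@require_onnx
class OnnxToolingTest(unittest.TestCase):
    def test_runtime_importable(self):
        # Only runs when is_onnx_available() returned True.
        import onnxruntime

        self.assertTrue(hasattr(onnxruntime, "InferenceSession"))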


def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch.
1 change: 1 addition & 0 deletions tests/test_modeling_tf_albert.py
@@ -241,6 +241,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
        else ()
    )
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFAlbertModelTester(self)
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_bart.py
@@ -178,6 +178,8 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
    all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_onnx = True
    onnx_min_opset = 10

    def setUp(self):
        self.model_tester = TFBartModelTester(self)
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_bert.py
@@ -274,6 +274,8 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
        else ()
    )
    test_head_masking = False
    test_onnx = True
    onnx_min_opset = 10

    # special case for ForPreTraining model
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
1 change: 1 addition & 0 deletions tests/test_modeling_tf_blenderbot.py
@@ -177,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
    all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFBlenderbotModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_blenderbot_small.py
@@ -179,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
    all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFBlenderbotSmallModelTester(self)
64 changes: 63 additions & 1 deletion tests/test_modeling_tf_common.py
@@ -16,6 +16,7 @@

import copy
import inspect
import json
import os
import random
import tempfile
@@ -24,7 +25,7 @@
from typing import List, Tuple

from transformers import is_tf_available
-from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_tf, slow
+from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_onnx, require_tf, slow


if is_tf_available():
@@ -201,6 +202,67 @@ def test_saved_model_creation(self):
            saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
            self.assertTrue(os.path.exists(saved_model_dir))

    def test_onnx_compliancy(self):
        if not self.test_onnx:
            return

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        INTERNAL_OPS = [
            "Assert",
            "AssignVariableOp",
            "EmptyTensorList",
            "ReadVariableOp",
            "ResourceGather",
            "TruncatedNormal",
            "VarHandleOp",
            "VarIsInitializedOp",
        ]
        onnx_ops = []

        with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f:

Narsil (Contributor): Why do you want to depend on an external file within a test? Doesn't it make more sense to include that list directly as a Python dict? It just feels simpler.

jplu (Author): This is easier to maintain than a dict. Also, this list should be shared with the check script.

            onnx_opsets = json.load(f)["opsets"]

        for i in range(1, self.onnx_min_opset + 1):
            onnx_ops.extend(onnx_opsets[str(i)])

        for model_class in self.all_model_classes:
            model_op_names = set()

            with tf.Graph().as_default() as g:

Narsil (Contributor): Isn't it possible to reuse onnx_compliancy? The two lists seem different, and that feels like an open opportunity for errors to creep in between them.

jplu (Author): The lists are not the same. The one in onnx_compliancy is bigger because that script checks a SavedModel, and a SavedModel contains additional operators specific to the SavedModel format. Here we check the graph created on the fly, not a SavedModel, so the SavedModel-specific operators are not needed.

                model = model_class(config)
                model(model.dummy_inputs)

Narsil (Contributor), Feb 8, 2021: Again, I'm not familiar with the TF way of working, but in PT the actual inputs change the traced graph quite a bit: the use_cache graph and the first-pass graph look quite different. Also, leaving seq_length variable where it could be fixed (for instance, input_ids is necessarily [B, 1] for the use_cache graph of a decoder) can lead to greatly different performance later down the road. What I'm trying to say is that this test will check that the ops used in TF are valid for some ONNX opset; it does not by any means check that we can or will export the best production-ready graph. And the real hot path in production is almost always decoder-only with use_cache (even input_ids of [1, 1]) within a generation loop (I don't think TF has the generation loop optimized yet).

jplu (Author), Feb 8, 2021: Here we are not testing the graph; we are loading the entire list of operators, and the graph here is not optimized. To give you an example: for BERT, this test loads more than 5,000 operators, while the optimized graph for inference has only around 1,200 nodes. The role of this test is just to make sure that the entire list of used operators is inside the list proposed at https://github.com/onnx/tensorflow-onnx/blob/master/support_status.md.

Narsil (Contributor): Yes, I know; I was just emphasizing it. Also, ONNX graph optimization can only go so far: it cannot know about past_values if they are not passed within those dummy inputs. An unoptimized small graph beats an optimized big graph:

  • big as in large sequence lengths, not sheer node count
  • again, I'm talking about PT here; I didn't check with TF yet, but the results are probably similar

jplu (Author), Feb 8, 2021: I think there is a misunderstanding: this test is only here to say "this TF op is also implemented in ONNX", nothing more, and not to test whether the optimized ONNX graph will work as expected. If you and Morgan prefer, I can add a slow test that runs the full pipeline:

  1. SavedModel creation
  2. ONNX conversion with keras2onnx
  3. An inference run with onnxruntime

Narsil (Contributor), Feb 8, 2021: There is no misunderstanding; I was trying to say what you just said.

jplu (Author): OK, so if you are trying to say the same thing, then there is no problem ^^

                for op in g.get_operations():
                    model_op_names.add(op.node_def.op)

            model_op_names = sorted(model_op_names)
            incompatible_ops = []

            for op in model_op_names:
                if op not in onnx_ops and op not in INTERNAL_OPS:
                    incompatible_ops.append(op)

            self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
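
For readers without the repository at hand, a sketch of the layout that utils/tf_ops/onnx.json is assumed to have, inferred from the json.load(f)["opsets"] and onnx_opsets[str(i)] accesses above (the op names here are illustrative, not the file's real contents):

# Keys are opset numbers as strings; each value lists the TF op names that
# become exportable at that opset. The test accumulates the lists from "1"
# up to str(self.onnx_min_opset).
assumed_onnx_json = {
    "opsets": {
        "1": ["Add", "Identity", "MatMul", "Relu", "Softmax"],
        "2": ["GlobalLpPool"],
        "10": ["TopKV2"],
    }
}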

    @require_onnx
    @slow
    def test_onnx_runtime_optimize(self):
        if not self.test_onnx:
            return

        import keras2onnx
        import onnxruntime

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model(model.dummy_inputs)

            onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)

            onnxruntime.InferenceSession(onnx_model.SerializeToString())
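
The slow test stops at session creation. A hedged sketch of going one step further, as in step 3 of the discussion above, and running a forward pass; it reuses onnx_model from the convert_keras call, and since input names, shapes, and dtypes vary per exported model, they are derived from the session rather than hard-coded:

import numpy as np

session = onnxruntime.InferenceSession(onnx_model.SerializeToString())
feed = {}
for inp in session.get_inputs():
    # Replace dynamic dimensions (None or symbolic names) with 1.
    shape = [dim if isinstance(dim, int) else 1 for dim in inp.shape]
    # inp.type is a string such as "tensor(int32)" or "tensor(float)".
    if "int64" in inp.type:
        dtype = np.int64
    elif "int32" in inp.type:
        dtype = np.int32
    else:
        dtype = np.float32
    feed[inp.name] = np.zeros(shape, dtype=dtype)
outputs = session.run(None, feed)  # list of output arrays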

    @slow
    def test_saved_model_creation_extended(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
1 change: 1 addition & 0 deletions tests/test_modeling_tf_convbert.py
@@ -239,6 +239,7 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
    )
    test_pruning = False
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFConvBertModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_ctrl.py
@@ -174,6 +174,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
    all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
    all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFCTRLModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_distilbert.py
@@ -184,6 +184,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
        else None
    )
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFDistilBertModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_dpr.py
@@ -188,6 +188,7 @@ class TFDPRModelTest(TFModelTesterMixin, unittest.TestCase):
    test_missing_keys = False
    test_pruning = False
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFDPRModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_electra.py
@@ -206,6 +206,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
        else ()
    )
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFElectraModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_flaubert.py
@@ -292,6 +292,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
        (TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
    )  # TODO (PVP): Check other models whether language generation is also applicable
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFFlaubertModelTester(self)
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_funnel.py
@@ -339,6 +339,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
        else ()
    )
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFFunnelModelTester(self)
@@ -382,6 +383,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
        (TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
    )
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFFunnelModelTester(self, base=True)
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_gpt2.py
@@ -333,6 +333,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
    )
    all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
    test_head_masking = False
    test_onnx = True
    onnx_min_opset = 10

    def setUp(self):
        self.model_tester = TFGPT2ModelTester(self)
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_led.py
@@ -195,6 +195,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
    all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFLEDModelTester(self)
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_longformer.py
@@ -297,6 +297,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
        if is_tf_available()
        else ()
    )
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFLongformerModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_lxmert.py
@@ -362,6 +362,7 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):

    all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFLxmertModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_marian.py
@@ -179,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
    all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFMarianModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_mbart.py
@@ -181,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
    all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFMBartModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_mobilebert.py
@@ -56,6 +56,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
        else ()
    )
    test_head_masking = False
    test_onnx = False

    class TFMobileBertModelTester(object):
        def __init__(
1 change: 1 addition & 0 deletions tests/test_modeling_tf_mpnet.py
@@ -199,6 +199,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
        else ()
    )
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFMPNetModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_openai.py
@@ -203,6 +203,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
        (TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
    )  # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFOpenAIGPTModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_pegasus.py
@@ -177,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
    all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFPegasusModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_roberta.py
@@ -186,6 +186,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
        else ()
    )
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFRobertaModelTester(self)
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_t5.py
@@ -249,6 +249,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
    all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
    all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFT5ModelTester(self)
@@ -427,6 +428,7 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
    is_encoder_decoder = False
    all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFT5EncoderOnlyModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_transfo_xl.py
@@ -164,6 +164,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
    # TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
    test_resize_embeddings = False
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFTransfoXLModelTester(self)