Skip to content

Commit

Permalink
Check TF ops for ONNX compliance (#10025)
Browse files Browse the repository at this point in the history
* Add check-ops script

* Finish implementing check_tf_ops and start the test

* Make the test mandatory only for BERT

* Update tf_ops folder

* Remove useless classes

* Add the ONNX test for GPT2 and BART

* Add an onnxruntime slow test + better opset flexibility

* Fix test + apply style

* fix tests

* Switch min opset from 12 to 10

* Update src/transformers/file_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Fix GPT2

* Remove extra shape_list usage

* Fix GPT2

* Address Morgan's comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
  • Loading branch information
jplu and LysandreJik authored Feb 15, 2021
1 parent 93bd2f7 commit c8d3fa0
Show file tree
Hide file tree
Showing 33 changed files with 468 additions and 17 deletions.
14 changes: 14 additions & 0 deletions src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@
_faiss_available = False


# ONNX support is reported as available only when both `keras2onnx` and `onnxruntime` can be located
# by the import system AND an installed `onnx` distribution exposes version metadata.
_onnx_available = (
    importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None
)
try:
    _onnx_version = importlib_metadata.version("onnx")
    logger.debug(f"Successfully imported onnx version {_onnx_version}")
except importlib_metadata.PackageNotFoundError:
    # No `onnx` distribution metadata found: disable ONNX regardless of the two checks above.
    _onnx_available = False


_scatter_available = importlib.util.find_spec("torch_scatter") is not None
try:
_scatter_version = importlib_metadata.version("torch_scatter")
Expand Down Expand Up @@ -230,6 +240,10 @@ def is_tf_available():
return _tf_available


def is_onnx_available():
    """Return the module-level ``_onnx_available`` flag computed when this module was imported."""
    return _onnx_available


def is_flax_available():
return _flax_available

Expand Down
19 changes: 3 additions & 16 deletions src/transformers/models/gpt2/modeling_tf_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,16 +1030,7 @@ def call(
)
- 1
)

def get_seq_element(sequence_position, input_batch):
return tf.strided_slice(
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
)

result = tf.map_fn(
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
)
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
else:
sequence_lengths = -1
logger.warning(
Expand All @@ -1049,16 +1040,12 @@ def get_seq_element(sequence_position, input_batch):
loss = None

if inputs["labels"] is not None:
if input_ids is not None:
batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
else:
batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
self.config.pad_token_id is not None or logits_shape[0] == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."

if not tf.is_tensor(sequence_lengths):
in_logits = logits[0:batch_size, sequence_lengths]
in_logits = logits[0 : logits_shape[0], sequence_lengths]

loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
pooled_logits = in_logits if in_logits is not None else logits
Expand Down
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
is_datasets_available,
is_faiss_available,
is_flax_available,
is_onnx_available,
is_pandas_available,
is_scatter_available,
is_sentencepiece_available,
Expand Down Expand Up @@ -160,6 +161,13 @@ def require_git_lfs(test_case):
return test_case


def require_onnx(test_case):
    """
    Decorator marking a test that requires ONNX.

    The test is skipped when :func:`is_onnx_available` returns a falsy value; otherwise the test case is
    returned unchanged.
    """
    if is_onnx_available():
        return test_case
    return unittest.skip("test requires ONNX")(test_case)


def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch.
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFAlbertModelTester(self)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = True
onnx_min_opset = 10

def setUp(self):
self.model_tester = TFBartModelTester(self)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = True
onnx_min_opset = 10

# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_blenderbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = False

def setUp(self):
self.model_tester = TFBlenderbotModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_blenderbot_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = False

def setUp(self):
self.model_tester = TFBlenderbotSmallModelTester(self)
Expand Down
64 changes: 63 additions & 1 deletion tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import copy
import inspect
import json
import os
import random
import tempfile
Expand All @@ -24,7 +25,7 @@
from typing import List, Tuple

from transformers import is_tf_available
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_tf, slow
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_onnx, require_tf, slow


if is_tf_available():
Expand Down Expand Up @@ -201,6 +202,67 @@ def test_saved_model_creation(self):
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
self.assertTrue(os.path.exists(saved_model_dir))

def test_onnx_compliancy(self):
    """
    Check that every TF op used by each model class is covered by the ONNX opsets up to ``onnx_min_opset``.

    Does nothing unless ``self.test_onnx`` is truthy. The known ONNX op names are read from
    ``./utils/tf_ops/onnx.json`` (resolved against the current working directory) and accumulated over
    opsets ``1`` to ``self.onnx_min_opset`` inclusive. Each class in ``self.all_model_classes`` is built
    and called on its ``dummy_inputs`` inside a fresh ``tf.Graph``; the test fails, reporting the offending
    op names, if the graph contains an op that is neither in the accumulated ONNX ops nor in the
    internal-op allow-list below.
    """
    if not self.test_onnx:
        return

    # Only the config is needed; the prepared inputs are not used because the model's own dummy inputs
    # are fed to it below.
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    # Op names that are accepted without being looked up in the ONNX opsets.
    internal_ops = frozenset(
        (
            "Assert",
            "AssignVariableOp",
            "EmptyTensorList",
            "ReadVariableOp",
            "ResourceGather",
            "TruncatedNormal",
            "VarHandleOp",
            "VarIsInitializedOp",
        )
    )

    with open(os.path.join(".", "utils", "tf_ops", "onnx.json"), encoding="utf-8") as f:
        onnx_opsets = json.load(f)["opsets"]

    # The JSON keys are opset numbers stored as strings; collect every op up to the minimum opset into a
    # set so that the membership checks below are constant-time.
    onnx_ops = set()
    for opset in range(1, self.onnx_min_opset + 1):
        onnx_ops.update(onnx_opsets[str(opset)])

    allowed_ops = onnx_ops | internal_ops

    for model_class in self.all_model_classes:
        with tf.Graph().as_default() as graph:
            model = model_class(config)
            model(model.dummy_inputs)
            model_op_names = {op.node_def.op for op in graph.get_operations()}

        # Sorted so that a failure message lists the ops in a stable order.
        incompatible_ops = [op for op in sorted(model_op_names) if op not in allowed_ops]

        self.assertEqual(len(incompatible_ops), 0, incompatible_ops)

@require_onnx
@slow
def test_onnx_runtime_optimize(self):
    """
    Convert each model class with keras2onnx and load the converted model into an onnxruntime session.

    Does nothing unless ``self.test_onnx`` is truthy. There are no explicit assertions: the test passes
    when conversion (targeting opset ``self.onnx_min_opset``) and session creation complete without
    raising.
    """
    if not self.test_onnx:
        return

    # Imported locally so that merely importing this test module does not require these packages.
    import keras2onnx
    import onnxruntime

    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        # Call the model once on its dummy inputs before handing it to the converter.
        model(model.dummy_inputs)

        converted = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
        serialized = converted.SerializeToString()

        onnxruntime.InferenceSession(serialized)

@slow
def test_saved_model_creation_extended(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_convbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
)
test_pruning = False
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFConvBertModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_ctrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFCTRLModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
else None
)
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFDistilBertModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_dpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class TFDPRModelTest(TFModelTesterMixin, unittest.TestCase):
test_missing_keys = False
test_pruning = False
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFDPRModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFElectraModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_flaubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
(TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFFlaubertModelTester(self)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFFunnelModelTester(self)
Expand Down Expand Up @@ -382,6 +383,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
)
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFFunnelModelTester(self, base=True)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
)
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
test_head_masking = False
test_onnx = True
onnx_min_opset = 10

def setUp(self):
self.model_tester = TFGPT2ModelTester(self)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_led.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFLEDModelTester(self)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_longformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFLongformerModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_lxmert.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):

all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFLxmertModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = False

def setUp(self):
self.model_tester = TFMarianModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = False

def setUp(self):
self.model_tester = TFMBartModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_mobilebert.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False

class TFMobileBertModelTester(object):
def __init__(
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_mpnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFMPNetModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFOpenAIGPTModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = False

def setUp(self):
self.model_tester = TFPegasusModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFRobertaModelTester(self)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFT5ModelTester(self)
Expand Down Expand Up @@ -427,6 +428,7 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFT5EncoderOnlyModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_transfo_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
test_resize_embeddings = False
test_head_masking = False
test_onnx = False

def setUp(self):
self.model_tester = TFTransfoXLModelTester(self)
Expand Down
Loading

0 comments on commit c8d3fa0

Please sign in to comment.