
Check TF ops for ONNX compliance #10025

Merged · 18 commits · Feb 15, 2021
Conversation

@jplu (Contributor) commented Feb 5, 2021

What does this PR do?

This PR adds a quick test and a script to check whether a model is compliant with ONNX opset 12. The script is only for checking a SavedModel, while the quick test runs over a manually built graph. For now, only BERT is forced to be compliant with ONNX, but the test can be enabled for any other model.

The logic can also be extended to any other framework/SDK we might think of, such as TFLite or NNAPI.
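
For illustration, the quick test boils down to something like this (a sketch only; the JSON path and its layout here are illustrative, not the exact files added by the PR):

import json

import tensorflow as tf

def check_onnx_compliance(model_class, config, allowed_ops_path="onnx.json"):
    """Build the model graph on the fly and compare its ops to an ONNX allowlist."""
    with open(allowed_ops_path) as f:
        allowed_ops = set(json.load(f))  # assumed layout: a flat list of supported TF op names

    with tf.Graph().as_default() as g:
        model = model_class(config)
        model(model.dummy_inputs)  # trace a forward pass to populate the graph
        used_ops = {op.node_def.op for op in g.get_operations()}

    unsupported = used_ops - allowed_ops
    assert not unsupported, f"TF ops with no ONNX equivalent: {sorted(unsupported)}"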

@sgugger (Collaborator) left a comment

LGTM but let's wait for Morgan's review. Would also be helpful to know which models should have the test activated (GPT-2 for instance?)

@jplu (Contributor, Author) commented Feb 5, 2021

That's the moment to create that list :)

@mfuntowicz (Member) commented Feb 8, 2021

Really like the idea of improving our ONNX compatibility in a more reliable way.

That said, I'm not sure this is the easiest way to approach the problem.
Maybe it would be more suitable to simply attempt to export the model through keras2onnx and report the errors. This would also let us test compatibility with various ONNX opsets more easily (10 is the minimum required).

We already have the keras2onnx dependency when using the optional onnxruntime requirements.
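
Roughly something along these lines (a sketch only, with a placeholder model and opset):

import keras2onnx

def report_onnx_export(model, opset):
    """Attempt a keras2onnx export and report any failure instead of checking op-by-op."""
    try:
        keras2onnx.convert_keras(model, model.name, target_opset=opset)
        print(f"{model.name}: export succeeded for opset {opset}")
    except Exception as err:
        print(f"{model.name}: export failed for opset {opset}: {err}")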

Also, regarding the list of models to be supported, I think we would like to have:

  • BERT
  • GPT2
  • BART (cc @Narsil wdyt?)

@jplu (Contributor, Author) commented Feb 8, 2021

That said, I'm not sure this is the easiest way to approach the problem.
Maybe it would be more suitable to simply attempt to export the model through keras2onnx and report the errors. This would also let us test compatibility with various ONNX opsets more easily (10 is the minimum required).

I'm not in favour of running the export directly in the tests, as it is less flexible and not compatible with other frameworks/SDKs. With the proposed approach we can add any other opset without problems, but I don't think going below 12 is a good idea, as only a few of the current models are compliant with opset < 12. Also, the default in the convert script is 11, not 10, so maybe we can propose opset 11 to stay aligned. I think the models proposed below are compliant, but lowering the opset would largely reduce the number of models that are compliant with ONNX.

Also, regarding the list of models to be supported, I think we would like to have:
BERT
GPT2
BART (cc @Narsil wdyt?)

These three are ok on my side!

@jplu (Contributor, Author) commented Feb 8, 2021

Added BART and GPT2 as mandatory green tests.

@mfuntowicz (Member) commented Feb 8, 2021

Testing ONNX operator support might be more complicated than this.

Each operator supports a set of different input/output shapes, combined with different data types and varying support for dynamic axes...

I would go for the simplest, well-tested solution: using the official converter to report incompatibilities.

@Narsil (Contributor) left a comment

I think this PR is fine, but I want to stress some points about the claims of those tests, which do not reflect (at least not to my knowledge) the complexity of optimizing a graph for inference.

It also feels odd to add many empty tests for many models that are not used.

As for BART, it's a good example of an encoder-decoder, so having it might help improve confidence in ONNX exportability, but I don't think it's required for this PR.

"VarIsInitializedOp",
]

with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f:
Contributor:

Why do you want to depend on an external file within a test?
Doesn't it make sense to include that directly as a Python dict?

Just feels simpler.

Contributor (Author):

This is easier to maintain than a dict. The list also needs to be shared with the check script.

for model_class in self.all_model_classes:
    model_op_names = set()

    with tf.Graph().as_default() as g:
Contributor:

Isn't it possible to reuse onnx_compliancy? They seem different, and that feels like an open opportunity for errors between the two.

Contributor (Author):

The list is not the same. The list in onnx_compliancy is bigger because that script checks a SavedModel, and a SavedModel contains extra operators that are specific to that format. Here we check the graph created on the fly, not a SavedModel, so the SavedModel-specific operators are not needed.
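
For reference, a sketch of how a SavedModel-based check can enumerate operators (illustrative only, not the exact onnx_compliancy script):

from tensorflow.core.protobuf import saved_model_pb2

def saved_model_op_names(saved_model_pb_path):
    """Collect every op name used in a SavedModel, including tf.function bodies."""
    saved_model = saved_model_pb2.SavedModel()
    with open(saved_model_pb_path, "rb") as f:
        saved_model.ParseFromString(f.read())
    op_names = set()
    for meta_graph in saved_model.meta_graphs:
        # Ops from the main graph...
        op_names.update(node.op for node in meta_graph.graph_def.node)
        # ...and from the functions attached to it
        for function in meta_graph.graph_def.library.function:
            op_names.update(node.op for node in function.node_def)
    return op_names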


with tf.Graph().as_default() as g:
    model = model_class(config)
    model(model.dummy_inputs)
@Narsil (Contributor) commented Feb 8, 2021

Again, I'm not familiar with how TF works here, but in PT the actual inputs change the traced graph quite a bit.

That means the use_cache graph and the first-pass graph look quite different. Also, keeping seq_length variable where it could be fixed (for instance, input_ids is necessarily [B, 1] for the use_cache graph on a decoder) can lead to greatly different performance down the road.

What I'm trying to say is that this test will probably check that the ops used in TF are valid for some ONNX opset; it by no means checks that we can/will export the best production-ready graph.

And the real hot path in production is almost always decoder-only with use_cache (even input_ids [1, 1]) within a generation loop (I don't think TF has an optimized generation loop yet).

@jplu (Contributor, Author) commented Feb 8, 2021

Here we are not testing the graph; we are loading the entire list of operators, and the graph here is not optimized. To give you an example, for BERT this test loads more than 5000 operators, while the optimized graph for inference has only around 1200 nodes. The role of this test is just to make sure the entire list of used operators is inside the list proposed here: https://github.com/onnx/tensorflow-onnx/blob/master/support_status.md

Contributor:

Yes, I know, I was just emphasizing it.
Also, ONNX graph optimization can only go so far. It cannot know about past_values if they are not passed within those dummy inputs.

unoptimized small graph > optimized big graph

  • big as in large sequence length, not sheer node count
  • Again talking about PT here, I didn't check with TF yet, but results are probably similar.

@jplu (Contributor, Author) commented Feb 8, 2021

I think there is a misunderstanding: this test is only here to say "this TF op is also implemented in ONNX", nothing more. It is not for testing whether the optimized ONNX graph will work as expected.

If you and Morgan prefer, I can add a slow test that runs the following pipeline:

  1. SavedModel creation
  2. ONNX conversion with keras2onnx
  3. Run an inference with onnxruntime

@Narsil (Contributor) commented Feb 8, 2021

I think there is a misunderstanding: this test is only here to say "this TF op is also implemented in ONNX", nothing more.

There is no misunderstanding, I was trying to say what you just said.

Contributor (Author):

Ok, so if you are trying to say the same thing, there is no problem then ^^

@jplu (Contributor, Author) commented Feb 8, 2021

The problem with using the converter is that we cannot get the full list of incompatible operators: it stops at the first one it encounters, which would be too annoying IMO. I think we can also assume that as long as an operator belongs to this list, https://github.com/onnx/tensorflow-onnx/blob/master/support_status.md, it is compliant. So far, this assumption has held for all of our models.

Unless you know a case for which it is not true?

Also, I'm afraid that adding a dependency on onnxruntime would turn the quick test into a slow test, which reduces the traceability of a potential change that breaks it.

If @LysandreJik, @sgugger and @patrickvonplaten agree on making the TF tests depend on the keras2onnx and onnxruntime packages, I can add a slow test that runs the following pipeline:

  1. Create a SavedModel
  2. Convert this SavedModel into ONNX with keras2onnx
  3. Run the converted model with onnxruntime

@LysandreJik (Member) commented Feb 8, 2021

We can add a test depending on keras2onnx or onnxruntime with a @require_onnx decorator. If you decide to go down this road, then depending on how long those tests take, we'll probably put them in the slow suite (which is okay, no need to test the model opsets on each PR).
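
Such a decorator is usually just a skip guard, along these lines (a sketch; the actual helper would live in transformers.testing_utils and may differ):

import importlib.util
import unittest

def is_onnx_available():
    # Both packages are needed for the export tests
    return (
        importlib.util.find_spec("keras2onnx") is not None
        and importlib.util.find_spec("onnxruntime") is not None
    )

def require_onnx(test_case):
    """Skip the decorated test when the ONNX tooling is not installed."""
    return unittest.skipUnless(is_onnx_available(), "test requires keras2onnx and onnxruntime")(test_case)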

@jplu (Contributor, Author) commented Feb 9, 2021

I like the idea of adding a decorator. I will add a slow test doing this in addition to the quick test.

@jplu (Contributor, Author) commented Feb 9, 2021

I have reworked the quick test. Now we can easily specify which opset a model should be tested against for compliance. In the onnx.json file, the operators are split across multiple opsets, where each opset corresponds to the list of operators implemented in it. This should be much easier to maintain and more flexible to use.
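
As an illustration of the idea, one way to consume such a file is to merge every opset up to the target before comparing (a sketch only; the keys and operator names below are made up, not the real contents of onnx.json):

import json

def ops_supported_up_to(onnx_json_path, target_opset):
    """Union of the operators introduced by every opset <= target_opset (illustrative layout)."""
    # Assumed layout: {"opset_1": ["Add", "MatMul"], "opset_10": ["TopK"], ...}
    with open(onnx_json_path) as f:
        opsets = json.load(f)
    supported = set()
    for key, ops in opsets.items():
        if int(key.split("_")[-1]) <= target_opset:
            supported.update(ops)
    return supported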

In addition to this, I have added a slow test that runs the complete pipeline "keras model -> ONNX model -> optimized ONNX model -> quantized ONNX model".

@jplu (Contributor, Author) commented Feb 9, 2021

As proposed by @mfuntowicz I switched the min required opset version from 12 to 10 for BERT, GPT2 and BART.

@LysandreJik (Member) left a comment

Great, I like that you use keras2onnx directly.

Will wait for @mfuntowicz's review before merging.

  • Do you have an idea of how long the slow tests take?
  • Based on the information gathered, would it be possible (in a next PR) to have a doc page referencing the opset compliance/ONNX support for each model?

src/transformers/file_utils.py (review comment, outdated and resolved)
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
@jplu (Contributor, Author) commented Feb 10, 2021

Do you have an idea of how long the slow tests take?

Depending on the model, between 1 and 5 minutes.

Based on the information gathered, would it be possible (in a next PR) to have a doc page referencing the opset compliance/ONNX support for each model?

Do you mean an entire page about ONNX, or just a paragraph about it in the doc of every model?

I think it is also important to mention that the TFGPT2ForSequenceClassification model cannot be converted to ONNX for now. The reason is the tf.map_fn function, which internally creates a tf.while loop with an iterator of type tf.variant, and that is not allowed in ONNX.
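
For context, the kind of rewrite that avoids this is replacing the per-example tf.map_fn lookup with a batched tf.gather; a minimal sketch (tensor names are illustrative, not the actual modeling code):

import tensorflow as tf

# logits: [batch, seq_len, num_labels]; sequence_lengths: [batch], index of the last real token.

def last_token_logits_map_fn(logits, sequence_lengths):
    # tf.map_fn builds a tf.while loop over the batch; its tf.variant iterator blocks the export.
    return tf.map_fn(
        lambda args: args[0][args[1]],
        (logits, sequence_lengths),
        fn_output_signature=tf.float32,
    )

def last_token_logits_gather(logits, sequence_lengths):
    # Same result with a single batched gather, no while loop involved.
    return tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)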

onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)

with tempfile.TemporaryDirectory() as tmpdirname:
    keras2onnx.save_model(onnx_model, os.path.join(tmpdirname, "model.onnx"))
Member:

This will only work for models < 2GB.

Contributor (Author):

The model configs in the tests make the models tiny, so as this is only for testing, this case should never happen.

Comment on lines 270 to 278
onnxruntime.InferenceSession(os.path.join(tmpdirname, "model.onnx"), sess_option)
onnx_model = onnx.load(os.path.join(tmpdirname, "model-optimized.onnx"))
quantized_model = quantize(
    model=onnx_model,
    quantization_mode=QuantizationMode.IntegerOps,
    force_fusions=True,
    symmetric_weight=True,
)
onnx.save_model(quantized_model, os.path.join(tmpdirname, "model-quantized.onnx"))
Member:

Ok, this is the old way of quantizing the model, not the one recommended anymore, which might lead to strange results.

I would not try quantization here, as it is hardware dependent and we have no control over the machine actually running this test.
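
For reference, a sketch of the newer dynamic-quantization entry point in onnxruntime (paths are illustrative, and whether quantization belongs in this test at all is the point above):

import os

from onnxruntime.quantization import QuantType, quantize_dynamic

def quantize_onnx_model(tmpdirname):
    """Dynamically quantize an exported ONNX model's weights to int8."""
    model_path = os.path.join(tmpdirname, "model.onnx")
    quantized_path = os.path.join(tmpdirname, "model-quantized.onnx")
    # quantize_dynamic reads the float model and writes the quantized one
    quantize_dynamic(model_path, quantized_path, weight_type=QuantType.QInt8)
    return quantized_path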

with tempfile.TemporaryDirectory() as tmpdirname:
    keras2onnx.save_model(onnx_model, os.path.join(tmpdirname, "model.onnx"))
    sess_option = onnxruntime.SessionOptions()
    sess_option.optimized_model_filepath = os.path.join(tmpdirname, "model-optimized.onnx")
@mfuntowicz (Member) commented Feb 10, 2021

You can remove this, no need to save the optimized graph on disk.

@jplu (Contributor, Author) commented Feb 10, 2021

LGTM on my side!

@LysandreJik I have fixed the issue with TFGPT2ForSequenceClassification, so now it is compliant with ONNX.

@mfuntowicz I should have addressed your comments, please double check ^^

@mfuntowicz (Member)

LGTM 👍🏻

@jplu (Contributor, Author) commented Feb 15, 2021

@LysandreJik Feel free to merge if the recent changes look ok for you!

@LysandreJik (Member) left a comment

LGTM! The changes you applied for TF GPT-2 could also be applied to other causal models that have this approach in their sequence classification architecture (such as CTRL), but they're not ONNX compliant so this can be done another time.

@jplu (Contributor, Author) commented Feb 15, 2021

@LysandreJik Yes, this is exactly that :) I plan to apply this update to the other causal models one by one 😉

@LysandreJik LysandreJik merged commit c8d3fa0 into huggingface:master Feb 15, 2021
@jplu jplu deleted the check-tf-ops branch February 15, 2021 16:36