M2M100 support for ONNX export (huggingface#15193)
* Add M2M100 support for ONNX export

* Remove unused imports

* Add M2M100 to tests

* Fix protobuf issue
michaelbenayoun authored Mar 2, 2022
1 parent d1a2907 commit 4bfe75b
Showing 6 changed files with 177 additions and 29 deletions.
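For orientation (not part of the diff itself): with these changes, M2M100 exports through the standard `transformers.onnx` tooling, e.g. `python -m transformers.onnx --model=facebook/m2m100_418M --feature=seq2seq-lm onnx/` on the CLI. A minimal programmatic sketch of the config this commit adds, using the checkpoint and task registered in the tests below:

```python
from transformers import M2M100Config
from transformers.models.m2m_100 import M2M100OnnxConfig

# Build the ONNX export config introduced by this commit and inspect the
# symbolic (dynamic) axes it declares for the exporter.
config = M2M100Config.from_pretrained("facebook/m2m100_418M")
onnx_config = M2M100OnnxConfig(config, task="seq2seq-lm")

print(onnx_config.inputs)   # input_ids / attention_mask / decoder_* with named dynamic axes
print(onnx_config.outputs)
```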
29 changes: 15 additions & 14 deletions docs/source/serialization.mdx
@@ -55,6 +55,7 @@ Ready-made configurations include the following architectures:
- GPT Neo
- I-BERT
- LayoutLM
- M2M100
- Marian
- mBART
- OpenAI GPT-2
@@ -584,12 +585,12 @@ traced_model(tokens_tensor, segments_tensors)

### Deploying HuggingFace TorchScript models on AWS using the Neuron SDK

AWS introduced the [Amazon EC2 Inf1](https://aws.amazon.com/ec2/instance-types/inf1/)
instance family for low-cost, high-performance machine learning inference in the cloud.
The Inf1 instances are powered by the AWS Inferentia chip, a custom-built hardware accelerator
specializing in deep learning inference workloads.
[AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/#)
is the SDK for Inferentia that supports tracing and optimizing transformers models for
deployment on Inf1. The Neuron SDK provides:


@@ -600,13 +601,13 @@ deployment on Inf1. The Neuron SDK provides:

#### Implications

Transformers models based on the [BERT (Bidirectional Encoder Representations from Transformers)](https://huggingface.co/docs/transformers/master/model_doc/bert)
architecture, or its variants such as [distilBERT](https://huggingface.co/docs/transformers/master/model_doc/distilbert)
and [roBERTa](https://huggingface.co/docs/transformers/master/model_doc/roberta),
will run best on Inf1 for non-generative tasks such as Extractive Question Answering,
Sequence Classification, and Token Classification. Alternatively, text generation
tasks can be adapted to run on Inf1, according to this [AWS Neuron MarianMT tutorial](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/transformers-marianmt.html).
More information about models that can be converted out of the box on Inferentia can be
found in the [Model Architecture Fit section of the Neuron documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/models/models-inferentia.html#models-inferentia).

#### Dependencies
@@ -618,8 +619,8 @@ Using AWS Neuron to convert models requires the following dependencies and environment:

#### Converting a Model for AWS Neuron

Using the same script as in [Using TorchScript in Python](https://huggingface.co/docs/transformers/master/en/serialization#using-torchscript-in-python)
to trace a "BertModel", you import the `torch.neuron` framework extension to access
the components of the Neuron SDK through a Python API.

```python
…
```

@@ -643,5 +644,5 @@ torch.neuron.trace(model, [token_tensor, segments_tensors])

This change enables the Neuron SDK to trace the model and optimize it to run on Inf1 instances.

To learn more about AWS Neuron SDK features, tools, example tutorials and latest updates,
please see the [AWS NeuronSDK documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html).
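The Python block above is collapsed in this diff view. As a hedged reconstruction of what the serialization guide shows there (dummy-input names follow the TorchScript section of the same page; the collapsed original may differ in details):

```python
import torch
import torch.neuron  # Neuron SDK framework extension (ships with AWS torch-neuron)
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

# Dummy inputs in the style of the TorchScript tutorial: token ids plus segment ids
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
first_sep = tokenized_text.index("[SEP]")
segments_ids = [0] * (first_sep + 1) + [1] * (len(tokenized_text) - first_sep - 1)
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

# torch.jit.trace is swapped for torch.neuron.trace, which compiles the graph
# for Inferentia rather than producing plain TorchScript
torch.neuron.trace(model, [tokens_tensor, segments_tensors])
```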
4 changes: 2 additions & 2 deletions src/transformers/models/m2m_100/__init__.py
@@ -21,7 +21,7 @@


_import_structure = {
"configuration_m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
"configuration_m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config", "M2M100OnnxConfig"],
"tokenization_m2m_100": ["M2M100Tokenizer"],
}

@@ -36,7 +36,7 @@


if TYPE_CHECKING:
-    from .configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
+    from .configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config, M2M100OnnxConfig
from .tokenization_m2m_100 import M2M100Tokenizer

if is_torch_available():
129 changes: 129 additions & 0 deletions src/transformers/models/m2m_100/configuration_m2m_100.py
@@ -13,8 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" M2M100 model configuration"""
from collections import OrderedDict
from typing import Any, Mapping, Optional

from ... import PreTrainedTokenizer
from ...configuration_utils import PretrainedConfig
from ...file_utils import TensorType, is_torch_available
from ...onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
from ...onnx.utils import compute_effective_axis_dimension
from ...utils import logging


@@ -153,3 +159,126 @@ def __init__(
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )


class M2M100OnnxConfig(OnnxSeq2SeqConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
            ]
        )

        if self.use_past:
            common_inputs["decoder_input_ids"] = {0: "batch"}
            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
        else:
            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}

        if self.use_past:
            self.fill_with_past_key_values_(common_inputs, direction="inputs")
        return common_inputs

    # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
    # A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
    # answering are not supported for M2M100, but this name is preserved to be able to check that the copy matches what
    # was done for BART so that it can be updated if need be.
    def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Copied from OnnxConfig.generate_dummy_inputs
        # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
        # If dynamic axis (-1), forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
        batch_size = compute_effective_axis_dimension(
            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
        )

        # If dynamic axis (-1), forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
        seq_length = compute_effective_axis_dimension(
            seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
        )

        # Generate dummy inputs according to the computed batch and sequence length
        dummy_input = [" ".join([tokenizer.unk_token] * seq_length)] * batch_size
        common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
        return common_inputs

    # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
    def _generate_dummy_inputs_for_default_and_seq2seq_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        # Generate decoder inputs
        decoder_seq_length = seq_length if not self.use_past else 1
        decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, decoder_seq_length, is_pair, framework
        )
        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
        common_inputs = dict(**encoder_inputs, **decoder_inputs)

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            batch, encoder_seq_length = common_inputs["input_ids"].shape
            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
            encoder_shape = (
                batch,
                num_encoder_attention_heads,
                encoder_seq_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )
            decoder_past_length = decoder_seq_length + 3
            decoder_shape = (
                batch,
                num_decoder_attention_heads,
                decoder_past_length,
                self._config.hidden_size // num_decoder_attention_heads,
            )

            common_inputs["decoder_attention_mask"] = torch.cat(
                [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
            )

            common_inputs["past_key_values"] = []
            # If the numbers of encoder and decoder layers are both present in the model configuration, both are considered
            num_encoder_layers, num_decoder_layers = self.num_layers
            min_num_layers = min(num_encoder_layers, num_decoder_layers)
            max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
            remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

            for _ in range(min_num_layers):
                common_inputs["past_key_values"].append(
                    (
                        torch.zeros(decoder_shape),
                        torch.zeros(decoder_shape),
                        torch.zeros(encoder_shape),
                        torch.zeros(encoder_shape),
                    )
                )
            # TODO: test this.
            shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
            for _ in range(min_num_layers, max_num_layers):
                common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
        return common_inputs

    generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
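A hedged sketch of exercising the aliased `generate_dummy_inputs` above (the checkpoint is the one registered in the tests below; exact tokenized lengths depend on the tokenizer's special tokens):

```python
from transformers import M2M100Config, M2M100Tokenizer
from transformers.file_utils import TensorType
from transformers.models.m2m_100 import M2M100OnnxConfig

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
config = M2M100Config.from_pretrained("facebook/m2m100_418M")
onnx_config = M2M100OnnxConfig(config, task="seq2seq-lm")

# use_past defaults to False for plain seq2seq-lm, so decoder_* dummies mirror
# the encoder inputs instead of collapsing to a single decoding step.
dummy = onnx_config.generate_dummy_inputs(tokenizer, batch_size=2, seq_length=8, framework=TensorType.PYTORCH)
print({name: tuple(t.shape) for name, t in dummy.items() if hasattr(t, "shape")})
```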
39 changes: 26 additions & 13 deletions src/transformers/onnx/convert.py
@@ -117,21 +117,34 @@ def export_pytorch(

    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
    # so we check the torch version for backwards compatibility
-    if parse(torch.__version__) <= parse("1.10.99"):
+    if parse(torch.__version__) < parse("1.10"):
        # export can work with named args but the dict containing named args
        # has to be the last element of the args tuple.
-        onnx_export(
-            model,
-            (model_inputs,),
-            f=output.as_posix(),
-            input_names=list(config.inputs.keys()),
-            output_names=onnx_outputs,
-            dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
-            do_constant_folding=True,
-            use_external_data_format=config.use_external_data_format(model.num_parameters()),
-            enable_onnx_checker=True,
-            opset_version=opset,
-        )
+        try:
+            onnx_export(
+                model,
+                (model_inputs,),
+                f=output.as_posix(),
+                input_names=list(config.inputs.keys()),
+                output_names=onnx_outputs,
+                dynamic_axes={
+                    name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())
+                },
+                do_constant_folding=True,
+                use_external_data_format=config.use_external_data_format(model.num_parameters()),
+                enable_onnx_checker=True,
+                opset_version=opset,
+            )
+        except RuntimeError as err:
+            message = str(err)
+            if (
+                message
+                == "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export without setting use_external_data_format parameter."
+            ):
+                message = "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export without setting use_external_data_format parameter or try with torch 1.10+."
+                raise RuntimeError(message)
+            else:
+                raise err
    else:
        onnx_export(
            model,
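For context on the new `except` branch: the string it matches is PyTorch's literal error for ONNX graphs whose protobuf serialization exceeds 2GB, which is why the message is preserved verbatim. Whether a checkpoint is likely to hit that ceiling can be estimated up front from its parameter count with the same helper used in the call above; a hedged sketch:

```python
from transformers import M2M100ForConditionalGeneration
from transformers.onnx import OnnxConfig

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
# Static helper: True when the serialized weights would exceed the 2GB
# protobuf limit and the export must fall back to external data files.
print(OnnxConfig.use_external_data_format(model.num_parameters()))
```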
4 changes: 4 additions & 0 deletions src/transformers/onnx/features.py
@@ -12,6 +12,7 @@
from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.ibert import IBertOnnxConfig
from ..models.layoutlm import LayoutLMOnnxConfig
from ..models.m2m_100 import M2M100OnnxConfig
from ..models.marian import MarianOnnxConfig
from ..models.mbart import MBartOnnxConfig
from ..models.roberta import RobertaOnnxConfig
@@ -184,6 +185,9 @@ class FeaturesManager:
"causal-lm-with-past",
onnx_config_cls=MarianOnnxConfig,
),
"m2m-100": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
),
"roberta": supported_features_mapping(
"default",
"masked-lm",
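With the `m2m-100` entry registered above, the feature manager can resolve the architecture by name; a hedged sketch of the lookup path the `transformers.onnx` CLI follows:

```python
from transformers.onnx.features import FeaturesManager

# Map the model type to its supported ONNX features and config constructor
features = FeaturesManager.get_supported_features_for_model_type("m2m-100")
print(sorted(features))  # default, default-with-past, seq2seq-lm, seq2seq-lm-with-past

model = FeaturesManager.get_model_from_feature("seq2seq-lm", "facebook/m2m100_418M")
model_kind, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(model, feature="seq2seq-lm")
onnx_config = onnx_config_factory(model.config)
```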
1 change: 1 addition & 0 deletions tests/onnx/test_onnx_v2.py
@@ -190,6 +190,7 @@ def test_values_override(self):
("mbart", "sshleifer/tiny-mbart"),
("t5", "t5-small"),
("marian", "Helsinki-NLP/opus-mt-en-de"),
("m2m-100", "facebook/m2m100_418M"),
}

TENSORFLOW_EXPORT_DEFAULT_MODELS = {
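The new parametrized case above only runs with `RUN_SLOW=1`, since it downloads the full 418M checkpoint. The export-and-validate flow it exercises can be reproduced by hand; a hedged sketch against the `transformers.onnx` Python API of this version:

```python
from pathlib import Path

from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers.models.m2m_100 import M2M100OnnxConfig
from transformers.onnx import export, validate_model_outputs

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
onnx_config = M2M100OnnxConfig(model.config, task="seq2seq-lm")

# Export with the default opset declared by the config, then check the ONNX
# outputs against the PyTorch model within the config's absolute tolerance.
onnx_path = Path("m2m100.onnx")
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
validate_model_outputs(onnx_config, tokenizer, model, onnx_path, onnx_outputs, onnx_config.atol_for_validation)
```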
