Add onnx export cuda support #17183

Merged · 20 commits · May 18, 2022
Changes from 11 commits

17 changes: 16 additions & 1 deletion src/transformers/onnx/convert.py
@@ -86,6 +86,7 @@ def export_pytorch(
opset: int,
output: Path,
tokenizer: "PreTrainedTokenizer" = None,
device: str = "cpu",
) -> Tuple[List[str], List[str]]:
"""
Export a PyTorch model to an ONNX Intermediate Representation (IR)
@@ -101,6 +102,8 @@ def export_pytorch(
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
device (`str`):
The device on which the ONNX model will be exported. Either CPU (default) or CUDA.

Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -136,6 +139,10 @@ def export_pytorch(
# Ensure inputs match
# TODO: Check when exporting QA we provide "is_pair=True"
model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
device = torch.device(device)
if device.type == "cuda" and torch.cuda.is_available():
    model.to(device)
    model_inputs = dict((k, v.to(device)) for k, v in model_inputs.items())
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())

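Note: the hunk above moves both the model and every dummy input tensor onto the requested device before tracing, and silently stays on CPU when CUDA is not available. A minimal standalone sketch of that pattern, assuming nothing beyond stock PyTorch (the helper name and toy tensors below are illustrative, not part of this PR):

import torch


def move_to_device(model: torch.nn.Module, inputs: dict, device: str = "cpu"):
    # Mirrors the hunk above: the model and all dummy inputs must end up on the
    # same device, otherwise tracing fails with a device-mismatch error.
    device = torch.device(device)
    if device.type == "cuda" and torch.cuda.is_available():
        model = model.to(device)
        inputs = {k: v.to(device) for k, v in inputs.items()}
    return model, inputs


# Toy usage (names are illustrative):
toy_model = torch.nn.Linear(4, 2)
toy_inputs = {"inputs_embeds": torch.randn(1, 4)}
toy_model, toy_inputs = move_to_device(toy_model, toy_inputs, device="cuda")
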
@@ -262,6 +269,7 @@ def export(
opset: int,
output: Path,
tokenizer: "PreTrainedTokenizer" = None,
device: str = "cpu",
) -> Tuple[List[str], List[str]]:
"""
Export a PyTorch or TensorFlow model to an ONNX Intermediate Representation (IR)
@@ -277,6 +285,8 @@ def export(
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
device (`str`):
The device on which the ONNX model will be exported (CUDA only works for PyTorch). The export will be done on CPU by default.

Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -288,6 +298,9 @@ def export(
"Please install torch or tensorflow first."
)

if is_tf_available() and isinstance(model, TFPreTrainedModel) and device == "cuda":
    raise RuntimeError("`tf2onnx` does not support export on CUDA device.")

if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
    raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
if tokenizer is not None:
@@ -310,7 +323,7 @@
)

if is_torch_available() and issubclass(type(model), PreTrainedModel):
return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer)
return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer, device=device)
elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
return export_tensorflow(preprocessor, model, config, opset, output, tokenizer=tokenizer)

@@ -350,6 +363,8 @@ def validate_model_outputs(
session = InferenceSession(onnx_model.as_posix(), options, providers=["CPUExecutionProvider"])

# Compute outputs from the reference model
if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
    reference_model.to("cpu")
ref_outputs = reference_model(**reference_model_inputs)
ref_outputs_dict = {}

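Note: put together, the convert.py changes enable the flow below. This is a hedged usage sketch, not code from the PR: the checkpoint name, output path, feature, and atol value are illustrative, and FeaturesManager is only used to resolve an OnnxConfig the same way the tests do. The export itself may run on CUDA (PyTorch models only); validation still runs on CPU because validate_model_outputs moves the reference model back to "cpu" and uses ONNX Runtime's CPUExecutionProvider.

from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.onnx import export, validate_model_outputs
from transformers.onnx.features import FeaturesManager

model_name = "bert-base-cased"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Resolve the ONNX config for this architecture/feature pair.
_, onnx_config_constructor = FeaturesManager.check_supported_model_or_raise(model, feature="default")
onnx_config = onnx_config_constructor(model.config)

onnx_path = Path("model.onnx")

# New in this PR: the export can run on the GPU; passing device="cuda" with a
# TensorFlow model raises a RuntimeError instead.
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path, device="cuda"
)

# The reference model is moved back to CPU inside validate_model_outputs, so the
# comparison against the CPUExecutionProvider session stays device-consistent.
validate_model_outputs(onnx_config, tokenizer, model, onnx_path, onnx_outputs, atol=1e-4)
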
12 changes: 10 additions & 2 deletions tests/onnx/test_onnx_v2.py
@@ -242,7 +242,7 @@ class OnnxExportTestCaseV2(TestCase):
Integration tests ensuring supported models are correctly exported
"""

def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"):
from transformers.onnx import export

model_class = FeaturesManager.get_model_class_for_feature(feature)
@@ -272,7 +272,7 @@ def _onnx_export(test_name, name, model_name, feature, onnx_config_class_c
with NamedTemporaryFile("w") as output:
try:
onnx_inputs, onnx_outputs = export(
preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name), device=device
)
validate_model_outputs(
onnx_config,
@@ -293,6 +293,14 @@ def _onnx_export(test_name, name, model_name, feature, onnx_config_class_c
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
    self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
@slow
@require_torch
@require_vision
@require_rjieba
def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
    self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
@slow
@require_torch
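Note: the new test_pytorch_export_on_cuda test reuses _onnx_export and only pins device="cuda"; like its CPU counterpart it is decorated with @slow. A sketch of running just these tests locally, assuming a CUDA-capable machine (RUN_SLOW is the usual gate for transformers' slow tests, not something introduced by this PR; the -k filter is illustrative):

import os
import sys

import pytest

# @slow-decorated tests are skipped unless RUN_SLOW is set.
os.environ["RUN_SLOW"] = "1"
sys.exit(pytest.main(["tests/onnx/test_onnx_v2.py", "-k", "test_pytorch_export_on_cuda", "-v"]))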