-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(components): PyTorch - Convert to ONNX from PyTorch ScriptModule (…
- Loading branch information
Showing
2 changed files
with
99 additions
and
0 deletions.
There are no files selected for viewing
36 changes: 36 additions & 0 deletions
36
components/PyTorch/Convert_to_OnnxModel_from_PyTorchScriptModule/component.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from kfp.components import create_component_from_func, InputPath, OutputPath | ||
|
||
|
||
def convert_to_onnx_from_pytorch_script_module( | ||
model_path: InputPath('PyTorchScriptModule'), | ||
converted_model_path: OutputPath('OnnxModel'), | ||
list_of_input_shapes: list, | ||
): | ||
'''Creates fully-connected network in PyTorch ScriptModule format''' | ||
import torch | ||
model = torch.jit.load(model_path) | ||
example_inputs = [ | ||
torch.ones(*input_shape) | ||
for input_shape in list_of_input_shapes | ||
] | ||
example_outputs = model.forward(*example_inputs) | ||
torch.onnx.export( | ||
model=model, | ||
args=example_inputs, | ||
f=converted_model_path, | ||
verbose=True, | ||
training=torch.onnx.TrainingMode.EVAL, | ||
example_outputs=example_outputs, | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
convert_to_onnx_from_pytorch_script_module_op = create_component_from_func( | ||
convert_to_onnx_from_pytorch_script_module, | ||
output_component_file='component.yaml', | ||
base_image='pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime', | ||
packages_to_install=[], | ||
annotations={ | ||
"author": "Alexey Volkov <alexey.volkov@ark-kun.com>", | ||
}, | ||
) |
63 changes: 63 additions & 0 deletions
63
components/PyTorch/Convert_to_OnnxModel_from_PyTorchScriptModule/component.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
name: Convert to onnx from pytorch script module | ||
description: Creates fully-connected network in PyTorch ScriptModule format | ||
metadata: | ||
annotations: {author: Alexey Volkov <alexey.volkov@ark-kun.com>} | ||
inputs: | ||
- {name: model, type: PyTorchScriptModule} | ||
- {name: list_of_input_shapes, type: JsonArray} | ||
outputs: | ||
- {name: converted_model, type: OnnxModel} | ||
implementation: | ||
container: | ||
image: pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime | ||
command: | ||
- sh | ||
- -ec | ||
- | | ||
program_path=$(mktemp) | ||
printf "%s" "$0" > "$program_path" | ||
python3 -u "$program_path" "$@" | ||
- | | ||
def _make_parent_dirs_and_return_path(file_path: str): | ||
import os | ||
os.makedirs(os.path.dirname(file_path), exist_ok=True) | ||
return file_path | ||
def convert_to_onnx_from_pytorch_script_module( | ||
model_path, | ||
converted_model_path, | ||
list_of_input_shapes, | ||
): | ||
'''Creates fully-connected network in PyTorch ScriptModule format''' | ||
import torch | ||
model = torch.jit.load(model_path) | ||
example_inputs = [ | ||
torch.ones(*input_shape) | ||
for input_shape in list_of_input_shapes | ||
] | ||
example_outputs = model.forward(*example_inputs) | ||
torch.onnx.export( | ||
model=model, | ||
args=example_inputs, | ||
f=converted_model_path, | ||
verbose=True, | ||
training=torch.onnx.TrainingMode.EVAL, | ||
example_outputs=example_outputs, | ||
) | ||
import json | ||
import argparse | ||
_parser = argparse.ArgumentParser(prog='Convert to onnx from pytorch script module', description='Creates fully-connected network in PyTorch ScriptModule format') | ||
_parser.add_argument("--model", dest="model_path", type=str, required=True, default=argparse.SUPPRESS) | ||
_parser.add_argument("--list-of-input-shapes", dest="list_of_input_shapes", type=json.loads, required=True, default=argparse.SUPPRESS) | ||
_parser.add_argument("--converted-model", dest="converted_model_path", type=_make_parent_dirs_and_return_path, required=True, default=argparse.SUPPRESS) | ||
_parsed_args = vars(_parser.parse_args()) | ||
_outputs = convert_to_onnx_from_pytorch_script_module(**_parsed_args) | ||
args: | ||
- --model | ||
- {inputPath: model} | ||
- --list-of-input-shapes | ||
- {inputValue: list_of_input_shapes} | ||
- --converted-model | ||
- {outputPath: converted_model} |