Skip to content

Commit

Permalink
feat(components): PyTorch - Convert to ONNX from PyTorch ScriptModule (
Browse files Browse the repository at this point in the history
  • Loading branch information
Ark-kun authored Mar 8, 2021
1 parent dfa7563 commit fc7afdd
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from kfp.components import create_component_from_func, InputPath, OutputPath


def convert_to_onnx_from_pytorch_script_module(
model_path: InputPath('PyTorchScriptModule'),
converted_model_path: OutputPath('OnnxModel'),
list_of_input_shapes: list,
):
'''Creates fully-connected network in PyTorch ScriptModule format'''
import torch
model = torch.jit.load(model_path)
example_inputs = [
torch.ones(*input_shape)
for input_shape in list_of_input_shapes
]
example_outputs = model.forward(*example_inputs)
torch.onnx.export(
model=model,
args=example_inputs,
f=converted_model_path,
verbose=True,
training=torch.onnx.TrainingMode.EVAL,
example_outputs=example_outputs,
)


if __name__ == '__main__':
convert_to_onnx_from_pytorch_script_module_op = create_component_from_func(
convert_to_onnx_from_pytorch_script_module,
output_component_file='component.yaml',
base_image='pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime',
packages_to_install=[],
annotations={
"author": "Alexey Volkov <alexey.volkov@ark-kun.com>",
},
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
name: Convert to onnx from pytorch script module
description: Creates fully-connected network in PyTorch ScriptModule format
metadata:
annotations: {author: Alexey Volkov <alexey.volkov@ark-kun.com>}
inputs:
- {name: model, type: PyTorchScriptModule}
- {name: list_of_input_shapes, type: JsonArray}
outputs:
- {name: converted_model, type: OnnxModel}
implementation:
container:
image: pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime
command:
- sh
- -ec
- |
program_path=$(mktemp)
printf "%s" "$0" > "$program_path"
python3 -u "$program_path" "$@"
- |
def _make_parent_dirs_and_return_path(file_path: str):
import os
os.makedirs(os.path.dirname(file_path), exist_ok=True)
return file_path
def convert_to_onnx_from_pytorch_script_module(
model_path,
converted_model_path,
list_of_input_shapes,
):
'''Creates fully-connected network in PyTorch ScriptModule format'''
import torch
model = torch.jit.load(model_path)
example_inputs = [
torch.ones(*input_shape)
for input_shape in list_of_input_shapes
]
example_outputs = model.forward(*example_inputs)
torch.onnx.export(
model=model,
args=example_inputs,
f=converted_model_path,
verbose=True,
training=torch.onnx.TrainingMode.EVAL,
example_outputs=example_outputs,
)
import json
import argparse
_parser = argparse.ArgumentParser(prog='Convert to onnx from pytorch script module', description='Creates fully-connected network in PyTorch ScriptModule format')
_parser.add_argument("--model", dest="model_path", type=str, required=True, default=argparse.SUPPRESS)
_parser.add_argument("--list-of-input-shapes", dest="list_of_input_shapes", type=json.loads, required=True, default=argparse.SUPPRESS)
_parser.add_argument("--converted-model", dest="converted_model_path", type=_make_parent_dirs_and_return_path, required=True, default=argparse.SUPPRESS)
_parsed_args = vars(_parser.parse_args())
_outputs = convert_to_onnx_from_pytorch_script_module(**_parsed_args)
args:
- --model
- {inputPath: model}
- --list-of-input-shapes
- {inputValue: list_of_input_shapes}
- --converted-model
- {outputPath: converted_model}

0 comments on commit fc7afdd

Please sign in to comment.