forked from open-mmlab/mmagic
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytorch2onnx.py
218 lines (194 loc) · 7.79 KB
/
pytorch2onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import warnings
import cv2
import mmcv
import numpy as np
import onnx
import onnxruntime as rt
import torch
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint
from mmedit.datasets.pipelines import Compose
from mmedit.models import build_model
def pytorch2onnx(model,
                 input,
                 model_type,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 dynamic_export=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    Args:
        model (nn.Module): Pytorch model we want to export. It must provide a
            ``forward_dummy`` method, which replaces ``forward`` for export.
        input (dict): We need to use this input to execute the model. For
            ``'mattor'`` the ``'merged'`` and ``'trimap'`` entries are read,
            for ``'restorer'`` the ``'lq'`` entry; each gets a leading batch
            dimension via ``unsqueeze(0)``.
        model_type (str): Either ``'mattor'`` or ``'restorer'``.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether print the computation graph. When ``verify`` is
            also set, the PyTorch and ONNXRuntime outputs are additionally
            displayed in OpenCV windows. Default: False.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between Pytorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether to export with dynamic batch, height
            and width axes. When ``verify`` is set, the check input is also
            rescaled by 1.1 and a flipped copy is appended as a second batch
            element. Default: False.

    Raises:
        ValueError: If ``model_type`` is not ``'mattor'`` or ``'restorer'``.
        AssertionError: If ``verify`` is set and the two outputs are not
            close (``rtol=1e-5, atol=1e-5``).
    """
    # Fail fast: only these two types have an input layout below. Any other
    # type would otherwise crash later with an UnboundLocalError on ``data``.
    if model_type not in ('mattor', 'restorer'):
        raise ValueError(
            f'Unsupported model_type {model_type!r} for ONNX export; '
            "expected 'mattor' or 'restorer'.")

    model.cpu().eval()

    if model_type == 'mattor':
        # The mattor is exported with a single input tensor: merged image
        # and trimap concatenated along dim 1.
        merged = input['merged'].unsqueeze(0)
        trimap = input['trimap'].unsqueeze(0)
        data = torch.cat((merged, trimap), 1)
    else:
        data = input['lq'].unsqueeze(0)
    # Trace the dummy forward instead of the regular one (presumably the
    # regular forward takes extra train/test arguments — confirm per model).
    model.forward = model.forward_dummy

    # pytorch has some bug in pytorch1.3, we have to fix it
    # by replacing these existing op
    register_extra_symbolics(opset_version)

    dynamic_axes = None
    if dynamic_export:
        # Batch, height and width may vary; dim 1 (presumably channels)
        # stays fixed.
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'output': {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        }

    with torch.no_grad():
        torch.onnx.export(
            model,
            data,
            output_file,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=show,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes)
    print(f'Successfully exported ONNX model: {output_file}')

    if not verify:
        return

    # check by onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    if dynamic_export:
        # Rescale the input so the check runs at a spatial size different
        # from the traced one, then append a copy flipped along the last
        # dim to make a batch of two.
        data = torch.nn.functional.interpolate(data, scale_factor=1.1)
        data = torch.cat((data, data.flip(-1)), 0)

    # get pytorch output; only the first output is compared (pred_alpha for
    # mattors)
    with torch.no_grad():
        pytorch_result = model(data)
    if isinstance(pytorch_result, (tuple, list)):
        pytorch_result = pytorch_result[0]
    pytorch_result = pytorch_result.detach().numpy()

    # get onnx output. The CPU provider is pinned explicitly: the model was
    # exported from CPU, and since onnxruntime 1.9 builds with several
    # execution providers refuse to create a session without ``providers``.
    sess = rt.InferenceSession(
        output_file, providers=['CPUExecutionProvider'])
    onnx_result = sess.run(None, {
        'input': data.detach().numpy(),
    })
    # only the first output is compared
    if isinstance(onnx_result, (tuple, list)):
        onnx_result = onnx_result[0]

    if show:
        # Display the first sample of each result: move axis 0 to the end,
        # clip to [0, 1] and reverse the channel order for OpenCV.
        pytorch_visualize = pytorch_result[0].transpose(1, 2, 0)
        pytorch_visualize = np.clip(pytorch_visualize, 0, 1)[:, :, ::-1]
        onnx_visualize = onnx_result[0].transpose(1, 2, 0)
        onnx_visualize = np.clip(onnx_visualize, 0, 1)[:, :, ::-1]
        cv2.imshow('PyTorch', pytorch_visualize)
        cv2.imshow('ONNXRuntime', onnx_visualize)
        cv2.waitKey()
        cv2.destroyAllWindows()

    # check the numerical value. An explicit raise (not ``assert``) keeps the
    # check active under ``python -O``.
    if not np.allclose(pytorch_result, onnx_result, rtol=1e-5, atol=1e-5):
        raise AssertionError(
            'The outputs are different between Pytorch and ONNX')
    print('The numerical values are same between Pytorch and ONNX')
def parse_args():
    """Define the converter's command-line interface and parse ``sys.argv``.

    Positional arguments, in order: config, checkpoint, model_type,
    img_path. Optional flags cover the trimap path, graph display, output
    path, opset version, output verification and dynamic-axis export.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    cli = argparse.ArgumentParser(description='Convert MMediting to ONNX')

    # Positional arguments; their registration order defines the CLI.
    cli.add_argument('config', help='test config file path')
    cli.add_argument('checkpoint', help='checkpoint file')
    cli.add_argument(
        'model_type',
        choices=['inpainting', 'mattor', 'restorer', 'synthesizer'],
        help='what kind of model the config belong to.')
    cli.add_argument('img_path', help='path to input image file')

    # Optional arguments.
    cli.add_argument(
        '--trimap-path',
        help='path to input trimap file, used in mattor model',
        default=None)
    cli.add_argument('--show', help='show onnx graph', action='store_true')
    cli.add_argument('--output-file', default='tmp.onnx', type=str)
    cli.add_argument('--opset-version', default=11, type=int)

    # Boolean switches that only carry a help text.
    switches = (
        ('--verify', 'verify the onnx model output against pytorch output'),
        ('--dynamic-export', 'Whether to export onnx with dynamic axis.'),
    )
    for flag, description in switches:
        cli.add_argument(flag, help=description, action='store_true')

    return cli.parse_args()
# Per model type: keys the test pipeline would load or collect that are not
# supplied here. Only the input image (and trimap) paths are passed in below.
_GT_KEYS = {
    'mattor': ('alpha', 'ori_alpha'),
    'restorer': ('gt', 'gt_path'),
}


def _strip_keys_from_pipeline(pipeline_cfgs, keys_to_remove):
    """Remove ``keys_to_remove`` from a list of pipeline step configs in place.

    For each key:
      * a step whose ``'key'`` equals it is dropped;
      * it is removed from a step's ``'keys'`` list, and the step is dropped
        once that list becomes empty;
      * it is removed from a step's ``'meta_keys'`` list.

    Args:
        pipeline_cfgs (list[dict]): Pipeline step configs; mutated in place.
        keys_to_remove (Iterable[str]): Keys to strip.
    """
    for key in keys_to_remove:
        # Iterate over a snapshot because steps may be removed meanwhile.
        for step in list(pipeline_cfgs):
            if 'key' in step and step['key'] == key:
                pipeline_cfgs.remove(step)
                # Already dropped; removing it again below would raise.
                continue
            if 'keys' in step and key in step['keys']:
                step['keys'].remove(key)
                if not step['keys']:
                    pipeline_cfgs.remove(step)
                    continue
            if 'meta_keys' in step and key in step['meta_keys']:
                step['meta_keys'].remove(key)


if __name__ == '__main__':
    args = parse_args()
    model_type = args.model_type

    # parse_args also accepts 'inpainting' and 'synthesizer', but only
    # mattors and restorers are handled below. Reject the others up front
    # instead of failing later with a NameError.
    if model_type not in _GT_KEYS:
        raise ValueError(
            f'Converting {model_type!r} models is not supported; '
            'supported model types: mattor, restorer.')
    if model_type == 'mattor' and args.trimap_path is None:
        raise ValueError('Please set `--trimap-path` to convert mattor model.')
    # Explicit check instead of ``assert`` so it still runs under -O.
    if args.opset_version != 11:
        raise ValueError('MMEditing only support opset 11 now')

    config = mmcv.Config.fromfile(args.config)
    config.model.pretrained = None
    # ONNX does not support spectral norm
    if model_type == 'mattor':
        if hasattr(config.model.backbone.encoder, 'with_spectral_norm'):
            config.model.backbone.encoder.with_spectral_norm = False
            config.model.backbone.decoder.with_spectral_norm = False
        config.test_cfg.metrics = None

    # build the model
    model = build_model(config.model, test_cfg=config.test_cfg)
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    # remove keys that are not supplied (alpha / gt) from test_pipeline
    _strip_keys_from_pipeline(config.test_pipeline, _GT_KEYS[model_type])

    # build the data pipeline
    test_pipeline = Compose(config.test_pipeline)
    # prepare data
    if model_type == 'mattor':
        data = dict(merged_path=args.img_path, trimap_path=args.trimap_path)
    else:
        data = dict(lq_path=args.img_path)
    data = test_pipeline(data)

    # convert model to onnx file
    pytorch2onnx(
        model,
        data,
        model_type,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        dynamic_export=args.dynamic_export)

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)