Skip to content

Commit

Permalink
add switch for onnx exporter (#564)
Browse files Browse the repository at this point in the history
  • Loading branch information
drcut authored Sep 18, 2020
1 parent 2114356 commit cc332b2
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions mmcv/onnx/symbolic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Modified from https://github.com/pytorch/pytorch."""
import os

import numpy as np
import torch
from torch.nn.modules.utils import _pair, _single, _triple
Expand All @@ -21,14 +23,27 @@ def symbolic_fn(g, input, output_size, *args):
'Constant', value_t=torch.tensor([], dtype=torch.float32))

if scales is None:
input_size = g.op('Shape', input)
input_size_beg = sym_help._slice_helper(
g, input_size, axes=[0], ends=[2], starts=[0])
output_size = g.op(
'Cast',
output_size,
to_i=sym_help.cast_pytorch_to_onnx['Long'])
output_size = g.op('Concat', input_size_beg, output_size, axis_i=0)
if 'ONNX_BACKEND' in os.environ and os.environ[
'ONNX_BACKEND'] == 'TensorRT':
input_size = input.type().sizes()
# slice the first two dim
input_size = input_size[:2]
# convert output_size to int type
output_size = sym_help._maybe_get_const(output_size, 'is')
input_size.extend(output_size)
output_size = g.op(
'Constant',
value_t=torch.tensor(input_size, dtype=torch.int64))
else:
input_size = g.op('Shape', input)
input_size_beg = sym_help._slice_helper(
g, input_size, axes=[0], ends=[2], starts=[0])
output_size = g.op(
'Cast',
output_size,
to_i=sym_help.cast_pytorch_to_onnx['Long'])
output_size = g.op(
'Concat', input_size_beg, output_size, axis_i=0)
scales = g.op(
'Constant', value_t=torch.tensor([], dtype=torch.float32))
return g.op(
Expand Down

0 comments on commit cc332b2

Please sign in to comment.