Skip to content

Commit 6761072

Browse files
committed
Remove duplicate map as per Jenny's comments
1 parent 95529f5 commit 6761072

File tree

2 files changed

+24
-159
lines changed

2 files changed

+24
-159
lines changed

nnvm/python/nnvm/frontend/onnx.py

Lines changed: 12 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -775,80 +775,6 @@ def _get_convert_map(opset):
775775
}
776776

777777

778-
supported_ops = set([
779-
'Constant',
780-
'Identity',
781-
'ThresholdedRelu',
782-
'ScaledTanh',
783-
'ParametricSoftplus',
784-
'ConstantFill',
785-
'FC',
786-
'Scale',
787-
'ImageScaler',
788-
'Upsample' ,
789-
'SpatialBN',
790-
'Add',
791-
'Sub',
792-
'Mul',
793-
'Div',
794-
'Neg',
795-
'Abs',
796-
'Reciprocal',
797-
'Floor',
798-
'Ceil',
799-
'Sqrt',
800-
'Relu',
801-
'LeakyRelu',
802-
'Selu',
803-
'Elu',
804-
'Exp',
805-
'Log',
806-
'Tanh',
807-
'Pow',
808-
'PRelu',
809-
'Sigmoid',
810-
'HardSigmoid',
811-
'Max',
812-
'Min',
813-
'Sum',
814-
'Mean',
815-
'Clip',
816-
'Softmax',
817-
'LogSoftmax',
818-
'Softsign',
819-
'SoftPlus',
820-
'Gemm',
821-
'MatMul',
822-
'AveragePool',
823-
'MaxPool',
824-
'Conv',
825-
'ConvTranspose',
826-
'GlobalAveragePool',
827-
'GlobalMaxPool',
828-
'BatchNormalization',
829-
'Dropout',
830-
'Flatten',
831-
'LRN',
832-
'ReduceMax',
833-
'ReduceMin',
834-
'ReduceSum',
835-
'ReduceMean',
836-
'ArgMax',
837-
'ArgMin',
838-
'Cast',
839-
'Reshape',
840-
'Concat',
841-
'Split',
842-
'Slice',
843-
'Transpose',
844-
'Gather',
845-
'Squeeze',
846-
'Unsqueeze',
847-
'Pad',
848-
'Shape',
849-
])
850-
851-
852778
class GraphProto(object):
853779
"""A helper class for handling nnvm graph copying from pb2.GraphProto.
854780
Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
@@ -899,15 +825,21 @@ def from_onnx(self, graph, opset):
899825
self._num_input += 1
900826
self._nodes[i_name] = _sym.Variable(name=i_name)
901827
# get list of unsupported ops
902-
unsupported_ops = []
828+
convert_map = _get_convert_map(opset)
829+
unsupported_ops = set()
903830
for node in graph.node:
904831
op_name = node.op_type
905-
if op_name not in supported_ops:
906-
unsupported_ops.append(op_name)
832+
if op_name not in convert_map and \
833+
op_name != 'Constant' and \
834+
op_name not in _identity_list:
835+
unsupported_ops.add(op_name)
907836
if unsupported_ops:
908-
msg = 'The following operators are not supported for frontend ONNX: {}'
909-
unsupported_ops = str(unsupported_ops).strip('[]').replace("'", '')
910-
raise tvm.error.OpNotImplemented(msg.format(unsupported_ops))
837+
msg = ['The following operators are not supported for frontend ONNX: ']
838+
for i, op_name in enumerate(unsupported_ops):
839+
msg.append(op_name)
840+
if i != len(unsupported_ops) - 1:
841+
msg.append(', ')
842+
raise tvm.error.OpNotImplemented(''.join(msg))
911843
# construct nodes, nodes are stored as directed acyclic graph
912844
for node in graph.node:
913845
op_name = node.op_type

python/tvm/relay/frontend/onnx.py

Lines changed: 12 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -879,79 +879,6 @@ def _get_convert_map(opset):
879879
'Shape': Shape.get_converter(opset),
880880
}
881881

882-
supported_ops = set([
883-
'Constant',
884-
'Identity',
885-
'ThresholdedRelu',
886-
'ScaledTanh',
887-
'ParametricSoftplus',
888-
'ConstantFill',
889-
'FC',
890-
'Scale',
891-
'Upsample' ,
892-
'SpatialBN',
893-
'Add',
894-
'Sub',
895-
'Mul',
896-
'Div',
897-
'Neg',
898-
'Abs',
899-
'Reciprocal',
900-
'Floor',
901-
'Ceil',
902-
'Sqrt',
903-
'Relu',
904-
'LeakyRelu',
905-
'Selu',
906-
'Elu',
907-
'Exp',
908-
'Log',
909-
'Tanh',
910-
'Pow',
911-
'PRelu',
912-
'Sigmoid',
913-
'HardSigmoid',
914-
'Max',
915-
'Min',
916-
'Sum',
917-
'Mean',
918-
'Clip',
919-
'Softmax',
920-
'LogSoftmax',
921-
'Softsign',
922-
'SoftPlus',
923-
'Gemm',
924-
'MatMul',
925-
'AveragePool',
926-
'MaxPool',
927-
'Conv',
928-
'ConvTranspose',
929-
'GlobalAveragePool',
930-
'GlobalMaxPool',
931-
'BatchNormalization',
932-
'Dropout',
933-
'Flatten',
934-
'LRN',
935-
'ReduceMax',
936-
'ReduceMin',
937-
'ReduceSum',
938-
'ReduceMean',
939-
'ReduceProd',
940-
'ArgMax',
941-
'ArgMin',
942-
'Cast',
943-
'Reshape',
944-
'Concat',
945-
'Split',
946-
'Slice',
947-
'Transpose',
948-
'Gather',
949-
'Squeeze',
950-
'Unsqueeze',
951-
'Pad',
952-
'Shape',
953-
])
954-
955882

956883
class GraphProto(object):
957884
"""A helper class for handling Relay expression copying from pb2.GraphProto.
@@ -1025,15 +952,21 @@ def from_onnx(self, graph, opset):
1025952
dtype = d_type
1026953
self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
1027954
# get list of unsupported ops
1028-
unsupported_ops = []
955+
convert_map = _get_convert_map(opset)
956+
unsupported_ops = set()
1029957
for node in graph.node:
1030958
op_name = node.op_type
1031-
if op_name not in supported_ops:
1032-
unsupported_ops.append(op_name)
959+
if op_name not in convert_map and \
960+
op_name != 'Constant' and \
961+
op_name not in _identity_list:
962+
unsupported_ops.add(op_name)
1033963
if unsupported_ops:
1034-
unsupported_ops = str(unsupported_ops).strip('[]').replace("'", '')
1035-
msg = 'The following operators are not supported for frontend ONNX: {}'
1036-
raise tvm.error.OpNotImplemented(msg.format(unsupported_ops))
964+
msg = ['The following operators are not supported for frontend ONNX: ']
965+
for i, op_name in enumerate(unsupported_ops):
966+
msg.append(op_name)
967+
if i != len(unsupported_ops) - 1:
968+
msg.append(', ')
969+
raise tvm.error.OpNotImplemented(''.join(msg))
1037970
# construct nodes, nodes are stored as directed acyclic graph
1038971
for node in graph.node:
1039972
op_name = node.op_type

0 commit comments

Comments
 (0)