Skip to content

Commit a1ecbc8

Browse files
Added make_custom_op_from_attrs_str function.
1 parent 0ada22f commit a1ecbc8

File tree

1 file changed

+22
-19
lines changed

1 file changed

+22
-19
lines changed

python/tvm/relay/op/nn/custom_operation.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,33 @@
2020
from tvm.ir import Attrs
2121

2222
from . import _make
23-
from typing import List, Tuple
24-
from tvm.custom_operation_config import (
25-
CustomOpConfigInfo, CustomOperationConfig
26-
)
23+
from typing import List
24+
from tvm.custom_operation_config import CustomOperationConfig
2725
import tvm._ffi
2826
import json
2927

3028

3129
MAX_TENSOR_INPUTS = 5
3230

3331

32+
def make_custom_op_from_attrs_str(tensor_inputs: List[expr.ExprWithOp],
33+
custom_op_attr_str: str):
34+
35+
if len(tensor_inputs) == 1:
36+
return _make.custom_op_1(*tensor_inputs, custom_op_attr_str)
37+
elif len(tensor_inputs) == 2:
38+
return _make.custom_op_2(*tensor_inputs, custom_op_attr_str)
39+
elif len(tensor_inputs) == 3:
40+
return _make.custom_op_3(*tensor_inputs, custom_op_attr_str)
41+
elif len(tensor_inputs) == 4:
42+
return _make.custom_op_4(*tensor_inputs, custom_op_attr_str)
43+
elif len(tensor_inputs) == 5:
44+
return _make.custom_op_5(*tensor_inputs, custom_op_attr_str)
45+
else:
46+
msg = "Unsupported number of input tensor arguments (%d)." % (len(tensor_inputs))
47+
raise AssertionError(msg)
48+
49+
3450
def custom_op(inputs, input_types, name, code, func_name, datatype, compiler_flags):
3551
"""
3652
Create a Relay IR node for the custom operation. Specifically, a
@@ -82,20 +98,7 @@ def custom_op(inputs, input_types, name, code, func_name, datatype, compiler_fla
8298
}
8399

84100
custom_op_attr_str = json.dumps(custom_op_attrs)
85-
86-
if len(tensor_inputs) == 1:
87-
return _make.custom_op_1(*tensor_inputs, custom_op_attr_str)
88-
elif len(tensor_inputs) == 2:
89-
return _make.custom_op_2(*tensor_inputs, custom_op_attr_str)
90-
elif len(tensor_inputs) == 3:
91-
return _make.custom_op_3(*tensor_inputs, custom_op_attr_str)
92-
elif len(tensor_inputs) == 4:
93-
return _make.custom_op_4(*tensor_inputs, custom_op_attr_str)
94-
elif len(tensor_inputs) == 5:
95-
return _make.custom_op_5(*tensor_inputs, custom_op_attr_str)
96-
else:
97-
msg = "Unsupported number of input tensor arguments (%d)." % (len(tensor_inputs))
98-
raise AssertionError(msg)
101+
return make_custom_op_from_attrs_str(tensor_inputs, custom_op_attr_str)
99102

100103

101104
def is_valid_attribute(input):
@@ -110,7 +113,7 @@ def is_valid_attribute(input):
110113

111114
if input_type == list and type(input[0]) in [int, float]:
112115
for elem in input:
113-
if type(elem) != type(input[0]):
116+
if not isinstance(elem, type(input[0])):
114117
return False
115118
return True
116119

0 commit comments

Comments
 (0)