2020from tvm .ir import Attrs
2121
2222from . import _make
23- from typing import List , Tuple
24- from tvm .custom_operation_config import (
25- CustomOpConfigInfo , CustomOperationConfig
26- )
23+ from typing import List
24+ from tvm .custom_operation_config import CustomOperationConfig
2725import tvm ._ffi
2826import json
2927
3028
3129MAX_TENSOR_INPUTS = 5
3230
3331
32+ def make_custom_op_from_attrs_str (tensor_inputs : List [expr .ExprWithOp ],
33+ custom_op_attr_str : str ):
34+
35+ if len (tensor_inputs ) == 1 :
36+ return _make .custom_op_1 (* tensor_inputs , custom_op_attr_str )
37+ elif len (tensor_inputs ) == 2 :
38+ return _make .custom_op_2 (* tensor_inputs , custom_op_attr_str )
39+ elif len (tensor_inputs ) == 3 :
40+ return _make .custom_op_3 (* tensor_inputs , custom_op_attr_str )
41+ elif len (tensor_inputs ) == 4 :
42+ return _make .custom_op_4 (* tensor_inputs , custom_op_attr_str )
43+ elif len (tensor_inputs ) == 5 :
44+ return _make .custom_op_5 (* tensor_inputs , custom_op_attr_str )
45+ else :
46+ msg = "Unsupported number of input tensor arguments (%d)." % (len (tensor_inputs ))
47+ raise AssertionError (msg )
48+
49+
3450def custom_op (inputs , input_types , name , code , func_name , datatype , compiler_flags ):
3551 """
3652 Create a Relay IR node for the custom operation. Specifically, a
@@ -82,20 +98,7 @@ def custom_op(inputs, input_types, name, code, func_name, datatype, compiler_fla
8298 }
8399
84100 custom_op_attr_str = json .dumps (custom_op_attrs )
85-
86- if len (tensor_inputs ) == 1 :
87- return _make .custom_op_1 (* tensor_inputs , custom_op_attr_str )
88- elif len (tensor_inputs ) == 2 :
89- return _make .custom_op_2 (* tensor_inputs , custom_op_attr_str )
90- elif len (tensor_inputs ) == 3 :
91- return _make .custom_op_3 (* tensor_inputs , custom_op_attr_str )
92- elif len (tensor_inputs ) == 4 :
93- return _make .custom_op_4 (* tensor_inputs , custom_op_attr_str )
94- elif len (tensor_inputs ) == 5 :
95- return _make .custom_op_5 (* tensor_inputs , custom_op_attr_str )
96- else :
97- msg = "Unsupported number of input tensor arguments (%d)." % (len (tensor_inputs ))
98- raise AssertionError (msg )
101+ return make_custom_op_from_attrs_str (tensor_inputs , custom_op_attr_str )
99102
100103
101104def is_valid_attribute (input ):
@@ -110,7 +113,7 @@ def is_valid_attribute(input):
110113
111114 if input_type == list and type (input [0 ]) in [int , float ]:
112115 for elem in input :
113- if type (elem ) != type (input [0 ]):
116+ if not isinstance (elem , type (input [0 ]) ):
114117 return False
115118 return True
116119
0 commit comments