@@ -50,20 +50,23 @@ class TensorIntrin(NodeBase):
50
50
decl_tensor_intrin: Construct a TensorIntrin
51
51
"""
52
52
def __call__ (self , * args , ** kwargs ):
53
- tensors = [x .tensor for x in args ]
54
- regions = [_get_region (x ) for x in args ]
53
+ tensors = [x .tensor for x in args if isinstance (x , _tensor .TensorSlice )]
54
+ scalar_inputs = [x for x in args if not isinstance (x , _tensor .TensorSlice )]
55
+ regions = [_get_region (x ) for x in args if isinstance (x , _tensor .TensorSlice )]
55
56
reduce_axis = []
56
57
if "reduce_axis" in kwargs :
57
58
reduce_axis = kwargs ["reduce_axis" ]
58
59
if not isinstance (reduce_axis , (list , tuple )):
59
60
reduce_axis = [reduce_axis ]
60
61
reduce_axis = _api .convert (reduce_axis )
61
- return _api_internal ._TensorIntrinCall (self , tensors , regions , reduce_axis )
62
+ if len (scalar_inputs ) > 0 :
63
+ scalar_inputs = _api .convert (scalar_inputs )
64
+ return _api_internal ._TensorIntrinCall (self , tensors , regions , reduce_axis , scalar_inputs )
62
65
63
66
def decl_tensor_intrin (op ,
64
67
fcompute ,
65
68
name = "tensor_intrin" ,
66
- binds = None ):
69
+ binds = None , scalar_params = None ):
67
70
"""Declare a tensor intrinsic function.
68
71
69
72
Parameters
@@ -96,6 +99,9 @@ def decl_tensor_intrin(op,
96
99
requirement of the function. By default, a new compact buffer is created
97
100
for each tensor in the argument.
98
101
102
+ scalar_params: a list of variables used by op, whose values will be passed
103
+ as scalar_inputs when the tensor intrinsic is called.
104
+
99
105
Returns
100
106
-------
101
107
intrin: TensorIntrin
@@ -122,11 +128,15 @@ def decl_tensor_intrin(op,
122
128
offset_factor = cfg .offset_factor ))
123
129
binds_list .append (buf )
124
130
125
- body = fcompute (binds_list [:len (inputs )], binds_list [len (inputs ):])
131
+ if scalar_params :
132
+ body = fcompute (binds_list [:len (inputs )], binds_list [len (inputs ):], scalar_params )
133
+ else :
134
+ body = fcompute (binds_list [:len (inputs )], binds_list [len (inputs ):])
135
+ scalar_params = []
126
136
if isinstance (body , (_expr .Expr , _stmt .Stmt )):
127
137
body = [body ]
128
138
body = [_make .Evaluate (x ) if isinstance (x , _expr .Expr ) else x for x in body ]
129
139
if len (body ) < 3 :
130
140
body += [None ] * (3 - len (body ))
131
141
return _api_internal ._TensorIntrin (
132
- name , op , inputs , binds_list , * body )
142
+ name , op , inputs , binds_list , scalar_params , * body )
0 commit comments