Skip to content

Commit 43f0a24

Browse files
Extend TensorComputeOp to allow scalar inputs (#2606).
1 parent 605b5e6 commit 43f0a24

File tree

8 files changed

+97
-12
lines changed

8 files changed

+97
-12
lines changed

include/tvm/operation.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ class TensorComputeOpNode : public BaseComputeOpNode {
286286
Array<Tensor> inputs;
287287
/*! \brief region of input tensors */
288288
Array<Region> input_regions;
289+
/*! \brief scalar expression inputs */
290+
Array<Expr> scalar_inputs;
289291
/*! \brief constructor */
290292
TensorComputeOpNode() {}
291293
// override functions
@@ -314,6 +316,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
314316
v->Visit("intrin", &intrin);
315317
v->Visit("inputs", &inputs);
316318
v->Visit("input_regions", &input_regions);
319+
v->Visit("scalar_inputs", &scalar_inputs);
317320
}
318321
static Operation make(std::string name,
319322
std::string tag,
@@ -322,7 +325,8 @@ class TensorComputeOpNode : public BaseComputeOpNode {
322325
int schedulable_ndim,
323326
TensorIntrin intrin,
324327
Array<Tensor> tensors,
325-
Array<Region> regions);
328+
Array<Region> regions,
329+
Array<Expr> scalar_inputs);
326330

327331
static constexpr const char* _type_key = "TensorComputeOp";
328332
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode);

include/tvm/tensor_intrin.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ class TensorIntrinNode : public Node {
6767
* When it is a constant, it means we can only take data in that shape.
6868
*/
6969
Array<Buffer> buffers;
70+
/*! \brief List of scalar variables, used in body. These placeholders
71+
* will be bound to expressions passed in when the TensorIntrin is called
72+
* from a TensorComputeOp.
73+
*/
74+
Array<Var> scalar_params;
7075
/*! \brief The normal statement to execute the intrinsic */
7176
Stmt body;
7277
/*!
@@ -87,6 +92,7 @@ class TensorIntrinNode : public Node {
8792
v->Visit("op", &op);
8893
v->Visit("inputs", &inputs);
8994
v->Visit("buffers", &buffers);
95+
v->Visit("scalar_params", &scalar_params);
9096
v->Visit("body", &body);
9197
v->Visit("reduce_init", &reduce_init);
9298
v->Visit("reduce_update", &reduce_update);
@@ -96,6 +102,7 @@ class TensorIntrinNode : public Node {
96102
Operation op,
97103
Array<Tensor> inputs,
98104
Array<Buffer> buffers,
105+
Array<Var> scalar_params,
99106
Stmt body,
100107
Stmt reduce_init,
101108
Stmt reduce_update);
@@ -134,22 +141,29 @@ class TensorIntrinCallNode : public Node {
134141
Array<Tensor> tensors;
135142
/*! \brief regions of input tensors */
136143
Array<Region> regions;
144+
145+
137146
/*!
138147
* \brief IterVar on each reduction axis, if the
139148
* intrin will use the reduce axis
140149
*/
141150
Array<IterVar> reduce_axis;
142151

152+
/*! \brief scalar expression inputs */
153+
Array<Expr> scalar_inputs;
154+
143155
void VisitAttrs(AttrVisitor* v) final {
144156
v->Visit("intrin", &intrin);
145157
v->Visit("tensors", &tensors);
146158
v->Visit("regions", &regions);
147159
v->Visit("reduce_axis", &reduce_axis);
160+
v->Visit("scalar_inputs", &scalar_inputs);
148161
}
149162
static TensorIntrinCall make(TensorIntrin intrin,
150163
Array<Tensor> tensors,
151164
Array<Region> regions,
152-
Array<IterVar> reduce_axis);
165+
Array<IterVar> reduce_axis,
166+
Array<Expr> scalar_inputs);
153167

154168
static constexpr const char* _type_key = "TensorIntrinCall";
155169
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);

python/tvm/api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
319319
out_ndim,
320320
body.intrin,
321321
body.tensors,
322-
body.regions)
322+
body.regions,
323+
body.scalar_inputs)
323324
else:
324325
if not isinstance(body, (list, tuple)):
325326
body = [body]

python/tvm/tensor_intrin.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,23 @@ class TensorIntrin(NodeBase):
5050
decl_tensor_intrin: Construct a TensorIntrin
5151
"""
5252
def __call__(self, *args, **kwargs):
53-
tensors = [x.tensor for x in args]
54-
regions = [_get_region(x) for x in args]
53+
tensors = [x.tensor for x in args if isinstance(x, _tensor.TensorSlice)]
54+
scalar_inputs = [x for x in args if not isinstance(x, _tensor.TensorSlice)]
55+
regions = [_get_region(x) for x in args if isinstance(x, _tensor.TensorSlice)]
5556
reduce_axis = []
5657
if "reduce_axis" in kwargs:
5758
reduce_axis = kwargs["reduce_axis"]
5859
if not isinstance(reduce_axis, (list, tuple)):
5960
reduce_axis = [reduce_axis]
6061
reduce_axis = _api.convert(reduce_axis)
61-
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)
62+
if len(scalar_inputs) > 0:
63+
scalar_inputs = _api.convert(scalar_inputs)
64+
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)
6265

6366
def decl_tensor_intrin(op,
6467
fcompute,
6568
name="tensor_intrin",
66-
binds=None):
69+
binds=None, scalar_params=None):
6770
"""Declare a tensor intrinsic function.
6871
6972
Parameters
@@ -96,6 +99,9 @@ def decl_tensor_intrin(op,
9699
requirement of the function. By default, a new compact buffer is created
97100
for each tensor in the argument.
98101
102+
scalar_params: list of Var, optional
103+
Scalar variables used by op. Their values are supplied as scalar_inputs when the tensor intrinsic is called; when given, fcompute receives this list as a third argument.
104+
99105
Returns
100106
-------
101107
intrin: TensorIntrin
@@ -122,11 +128,15 @@ def decl_tensor_intrin(op,
122128
offset_factor=cfg.offset_factor))
123129
binds_list.append(buf)
124130

125-
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
131+
if scalar_params:
132+
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):], scalar_params)
133+
else:
134+
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
135+
scalar_params = []
126136
if isinstance(body, (_expr.Expr, _stmt.Stmt)):
127137
body = [body]
128138
body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
129139
if len(body) < 3:
130140
body += [None] * (3 - len(body))
131141
return _api_internal._TensorIntrin(
132-
name, op, inputs, binds_list, *body)
142+
name, op, inputs, binds_list, scalar_params, *body)

src/lang/tensor.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
8383
Operation op,
8484
Array<Tensor> inputs,
8585
Array<Buffer> buffers,
86+
Array<Var> scalar_params,
8687
Stmt body,
8788
Stmt reduce_init,
8889
Stmt reduce_update) {
@@ -91,6 +92,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
9192
n->op = std::move(op);
9293
n->inputs = std::move(inputs);
9394
n->buffers = std::move(buffers);
95+
n->scalar_params = std::move(scalar_params);
9496
n->body = std::move(body);
9597
n->reduce_init = std::move(reduce_init);
9698
n->reduce_update = std::move(reduce_update);
@@ -110,12 +112,14 @@ TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
110112
TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
111113
Array<Tensor> tensors,
112114
Array<Region> regions,
113-
Array<IterVar> reduce_axis) {
115+
Array<IterVar> reduce_axis,
116+
Array<Expr> scalar_inputs) {
114117
auto n = make_node<TensorIntrinCallNode>();
115118
n->intrin = std::move(intrin);
116119
n->tensors = std::move(tensors);
117120
n->regions = std::move(regions);
118121
n->reduce_axis = std::move(reduce_axis);
122+
n->scalar_inputs = std::move(scalar_inputs);
119123
return TensorIntrinCall(n);
120124
}
121125

src/op/tensor_compute_op.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ Operation TensorComputeOpNode::make(std::string name,
5858
int schedulable_ndim,
5959
TensorIntrin intrin,
6060
Array<Tensor> tensors,
61-
Array<Region> regions) {
61+
Array<Region> regions,
62+
Array<Expr> scalar_inputs) {
6263
auto n = make_node<TensorComputeOpNode>();
6364
n->name = std::move(name);
6465
n->tag = std::move(tag);
@@ -68,6 +69,7 @@ Operation TensorComputeOpNode::make(std::string name,
6869
n->intrin = std::move(intrin);
6970
n->inputs = std::move(tensors);
7071
n->input_regions = std::move(regions);
72+
n->scalar_inputs = std::move(scalar_inputs);
7173
return Operation(n);
7274
}
7375

@@ -184,6 +186,19 @@ Stmt TensorComputeOpNode::BuildProvide(
184186
std::unordered_map<const Variable*, Expr> vmap;
185187
ir::ArgBinder binder(&vmap);
186188

189+
// Bind the intrinsic's placeholder variables (scalar_params) to the scalar
190+
// expressions passed in the call to the TensorIntrin (scalar_inputs).
191+
Array<Expr> user_expr = this->scalar_inputs;
192+
Array<Var> scalar_params = this->intrin->scalar_params;
193+
Array<Expr> sp_expr;
194+
for (auto sp : scalar_params) {
195+
Expr esp = sp;
196+
sp_expr.push_back(esp);
197+
}
198+
CHECK_EQ(sp_expr.size(), user_expr.size());
199+
// TODO(jdavies-huawei): what name should be used here?
200+
binder.BindArray(sp_expr, user_expr, this->name);
201+
187202
size_t tloc = stage->leaf_iter_vars.size();
188203
ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);
189204

src/schedule/schedule_dataflow_rewrite.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,10 +410,15 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
410410
new_regions.push_back(region);
411411
}
412412

413+
Array<Expr> new_scalar_inputs;
414+
for (Expr old_input : tensor_op->scalar_inputs) {
415+
new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input));
416+
}
417+
413418
Operation cache_op = TensorComputeOpNode::make(
414419
tensor_op->name + "." + scope, tensor_op->tag, new_axis,
415420
tensor_op->reduce_axis, tensor_op->schedulable_ndim,
416-
tensor_op->intrin, tensor_op->inputs, new_regions);
421+
tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);
417422

418423
// axis will be used in generating compute op
419424
Array<IterVar> compute_axis = tensor_op->axis;

tests/python/unittest/test_lang_schedule.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,43 @@ def intrin_func(ins, outs):
209209
assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin)
210210
assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized)
211211

212+
def test_tensor_intrin_scalar_params():
213+
n = tvm.var("n")
214+
x = tvm.placeholder((n,), name='x')
215+
v = tvm.var("v")
216+
w = tvm.var("w")
217+
z = tvm.compute((n,), lambda i: x[i]*v + w, name='z')
218+
219+
def intrin_func(ins, outs, sp):
220+
assert(isinstance(ins[0], tvm.schedule.Buffer))
221+
assert(ins[0].shape[0] == n)
222+
assert(sp[0] == v)
223+
assert(sp[1] == w)
224+
return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])
225+
226+
with tvm.build_config(offset_factor=1):
227+
intrin = tvm.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w])
228+
assert intrin.op == z.op
229+
assert intrin.reduce_init is None
230+
assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
231+
assert(intrin.buffers[0].shape[0] == n)
232+
assert tuple(intrin.scalar_params) == tuple((v, w))
233+
234+
A = tvm.placeholder((10,10), name='A')
235+
# Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
236+
C = tvm.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
237+
s = tvm.create_schedule(C.op)
238+
stmt = tvm.lower(s, [A, C], simple_mode=True)
239+
assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate)
240+
assert len(stmt.body.body.body.value.args) == 5
241+
assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
242+
assert str(stmt.body.body.body.value.args[4]) == "(i + j)"
212243

213244
if __name__ == "__main__":
214245
test_singleton()
215246
test_pragma()
216247
test_tensor_intrin()
248+
test_tensor_intrin_scalar_params()
217249
test_rfactor()
218250
test_schedule_create()
219251
test_reorder()

0 commit comments

Comments
 (0)