Commit 8f4b7d9

[LANG] Introduce Scan, Bugfix Canonical

1 parent f8f0282

18 files changed: +773 -113 lines changed

include/tvm/ir.h

Lines changed: 31 additions & 4 deletions
@@ -49,12 +49,27 @@ struct Reduce : public ExprNode<Reduce> {
   static constexpr const char* Min = "Min";
 };
 
-/*! \brief namespace of possible attributes in AttrStmt.type_key */
-namespace attr {
 /*!
- * \brief Mark scope of iteration variable, used by Schedule.
+ * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
  */
-constexpr const char* scope = "scope";
+struct TensorKey {
+  FunctionRef f;
+  int value_index;
+
+  inline bool operator==(const TensorKey& other) const {
+    return f == other.f && value_index == other.value_index;
+  }
+  inline std::string GetName() const {
+    if (f->num_outputs() == 1) return f->func_name();
+    std::ostringstream os;
+    os << f->func_name() << ".v" << value_index;
+    return os.str();
+  }
+};
+
+/*! \brief namespace of possible attributes in AttrStmt.type_key */
+namespace attr {
+// The above attr does not pass to ir stage.
 /*!
  * \brief Mark launching extent of thread, used by device API.
  */
@@ -189,4 +204,16 @@ using Halide::Internal::Evaluate;
 }  // namespace ir
 }  // namespace tvm
 
+namespace std {
+template <>
+struct hash<::tvm::ir::TensorKey> {
+  std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
+    size_t lhs = k.f.hash();
+    size_t rhs = static_cast<size_t>(k.value_index);
+    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
+    return lhs;
+  }
+};
+}  // namespace std
+
 #endif  // TVM_IR_H_
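Note: the std::hash specialization above lets IR passes key hash maps by TensorKey, i.e. by a (function, output index) pair; the 0x9e3779b9 constant is the usual boost-style hash_combine mixing constant. Below is a minimal self-contained sketch of the same pattern (editor's illustration, not part of the commit), using a hypothetical stand-in struct in place of TensorKey, which hashes a FunctionRef rather than a string:

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Simplified stand-in for tvm::ir::TensorKey: identifies one output of a function.
struct FakeTensorKey {
  std::string func_name;  // stands in for FunctionRef f
  int value_index;
  bool operator==(const FakeTensorKey& other) const {
    return func_name == other.func_name && value_index == other.value_index;
  }
};

namespace std {
template <>
struct hash<FakeTensorKey> {
  std::size_t operator()(const FakeTensorKey& k) const {
    // Same hash_combine recipe as the TensorKey specialization above.
    std::size_t lhs = std::hash<std::string>()(k.func_name);
    std::size_t rhs = static_cast<std::size_t>(k.value_index);
    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
    return lhs;
  }
};
}  // namespace std

int main() {
  // This is what allows maps keyed by (function, output index).
  std::unordered_map<FakeTensorKey, int> buffer_map;
  buffer_map[{"scan", 0}] = 42;
  buffer_map[{"scan", 1}] = 7;
  std::cout << buffer_map[{"scan", 0}] << "\n";  // prints 42
  return 0;
}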

include/tvm/operation.h

Lines changed: 64 additions & 0 deletions
@@ -77,6 +77,55 @@ class ComputeOpNode : public OperationNode {
   TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
 };
 
+/*!
+ * \brief Symbolic scan.
+ */
+class ScanOpNode : public OperationNode {
+ public:
+  /*! \brief IterVar to scan over */
+  IterVar scan_axis;
+  /*! \brief the initialization tensors */
+  Array<Tensor> init;
+  /*! \brief the update function represented by tensor */
+  Array<Tensor> update;
+  /*! \brief The placeholder to refer as states in update. */
+  Array<Tensor> state_placeholder;
+  /*!
+   * \brief Spatial axis to indicate the spatial dimension of each output.
+   *  They correspond to the flattened spatial axes of the outputs.
+   *
+   *  [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
+   *  These are auxiliary data structures for storing the result of bound inference.
+   *  They do not correspond to splittable iterations, thus the name comes
+   *  with underscore.
+   */
+  Array<IterVar> spatial_axis_;
+  /*! \brief constructor */
+  ScanOpNode() {}
+  // override behavior.
+  int num_outputs() const final;
+  Array<IterVar> root_iter_vars() const final;
+  Type output_dtype(size_t i) const final;
+  Array<Expr> output_shape(size_t i) const final;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("scan_axis", &scan_axis);
+    v->Visit("init", &init);
+    v->Visit("update", &update);
+    v->Visit("state_placeholder", &state_placeholder);
+    v->Visit("spatial_axis_", &spatial_axis_);
+  }
+  static Operation make(std::string name,
+                        IterVar axis,
+                        Array<Tensor> init,
+                        Array<Tensor> update,
+                        Array<Tensor> state_placeholder);
+
+  static constexpr const char* _type_key = "ScanOp";
+  TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode);
+};
+
 
 /*! \brief The compute function to specify the input source of a Tensor */
 using FCompute = std::function<Expr (const Array<Var>& i)>;
@@ -100,6 +149,21 @@ Tensor Placeholder(Array<Expr> shape,
  */
 Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
 
+/*!
+ * \brief Construct new tensors by scan over scan_axis.
+ *
+ * \param scan_axis The iteration representing the scan.
+ * \param init The initialization tensors of the first K steps.
+ * \param update The update tensors indicating the updated result after each timestep.
+ * \param state_placeholder The placeholder for the states.
+ * \param name The optional name of the tensor.
+ */
+Array<Tensor> Scan(IterVar scan_axis,
+                   Array<Tensor> init,
+                   Array<Tensor> update,
+                   Array<Tensor> state_placeholder,
+                   std::string name = "scan");
+
 // same as compute, specialized for different fcompute function
 inline Tensor Compute(Array<Expr> shape,
                       std::function<Expr(Var)> f,
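Note: an editor's sketch (not part of the commit) of how the new C++ Scan entry point could express the cumsum case, mirroring the Python docstring example further down in this commit. It assumes Float(32) placeholders, the IterVar(Range, name) constructor used elsewhere in this diff, and access to the underlying loop variable via t->var; treat it as illustrative rather than tested code.

#include <tvm/operation.h>
#include <tvm/tensor.h>

using namespace tvm;

Array<Tensor> MakeCumsum() {
  Var m("m"), n("n");
  // Scan over t in [1, m): row 0 of the state comes from init.
  IterVar t(Range::make_with_min_extent(1, m - 1), "t");
  Tensor X = Placeholder({m, n}, Float(32), "X");
  Tensor s_state = Placeholder({m, n}, Float(32), "s_state");
  // init fills state rows [0, dom.min); here a single row.
  Tensor s_init = Compute(
      {1, n}, [&](const Array<Var>& i) { return X(0, i[1]); }, "s_init");
  // update produces state row t from row t-1; note update.ndim == state.ndim - 1.
  Tensor s_update = Compute(
      {n},
      [&](const Array<Var>& i) { return s_state(t->var - 1, i[0]) + X(t->var, i[0]); },
      "s_update");
  // One state tensor in, so Scan returns one output of shape (m, n).
  return Scan(t, {s_init}, {s_update}, {s_state});
}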

python/tvm/api.py

Lines changed: 52 additions & 3 deletions
@@ -14,6 +14,7 @@
 from . import _api_internal
 from . import make as _make
 from . import expr as _expr
+from . import tensor as _tensor
 from . import collections as _collections
 
 int32 = "int32"
@@ -111,7 +112,6 @@ def compute(shape, fcompute, name="compute"):
     shape: Tuple of Expr
         The shape of the tensor
 
-
     fcompute: lambda function of *indices-> value
         Specifies the input source expression
 
@@ -137,8 +137,57 @@ def compute(shape, fcompute, name="compute"):
     body = convert(body)
     op_node = _api_internal._ComputeOp(
         name, dim_var, body)
-    return _api_internal._Tensor(
-        shape, body.dtype, op_node, 0)
+    return op_node.output(0)
+
+
+def scan(axis, init, update, state_placeholder, name="scan"):
+    """Construct new tensors by scanning over axis.
+
+    Parameters
+    ----------
+    axis: IterVar
+        The scanning axis.
+
+    init: Tensor or list of Tensor
+        The initial condition of the first init.shape[0] timestamps.
+
+    update: Tensor or list of Tensor
+        The update rule of the scan, given by symbolic tensor.
+
+    state_placeholder: Tensor or list of Tensor
+        The placeholder variables used by update.
+
+    name: str, optional
+        The name hint of the tensor.
+
+    Returns
+    -------
+    tensor: tensor.Tensor or list of Tensor
+        The created tensor, or a list of tensors if multiple updates are given.
+
+    Example
+    -------
+    # The following code is equivalent to numpy.cumsum
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    t = tvm.IterVar((1, m), name="t")
+    X = tvm.placeholder((m, n), name="X")
+    s_state = tvm.placeholder((m, n))
+    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
+    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
+    res = tvm.scan(t, s_init, s_update, s_state)
+    """
+    if isinstance(init, _tensor.Tensor):
+        init = [init]
+    if isinstance(update, _tensor.Tensor):
+        update = [update]
+    if isinstance(state_placeholder, _tensor.Tensor):
+        state_placeholder = [state_placeholder]
+    if len(init) != len(update) or len(init) != len(state_placeholder):
+        raise ValueError("init, update, state_placeholder must have same length")
+    op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
+    res = [op.output(i) for i in range(len(update))]
+    return (res[0] if len(res) == 1 else res)
 
 
 def Buffer(shape, dtype=None,

python/tvm/tensor.py

Lines changed: 7 additions & 2 deletions
@@ -74,12 +74,17 @@ def output(self, index):
         """
         return _api_internal._OpGetOutput(self, index)
 
+@register_node
+class PlaceholderOp(Operation):
+    """Placeholder operation."""
+    pass
+
 @register_node
 class ComputeOp(Operation):
     """Compute operation."""
     pass
 
 @register_node
-class PlaceholderOp(Operation):
-    """Placeholder operation."""
+class ScanOp(Operation):
+    """Scan operation."""
     pass

src/api/api_lang.cc

Lines changed: 9 additions & 0 deletions
@@ -173,6 +173,15 @@ TVM_REGISTER_API(_ComputeOp)
         args[2]);
   });
 
+TVM_REGISTER_API(_ScanOp)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = ScanOpNode::make(args[0],
+                            args[1],
+                            args[2],
+                            args[3],
+                            args[4]);
+  });
+
 TVM_REGISTER_API(_OpGetOutput)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator Operation().output(

src/arithmetic/canonical.cc

Lines changed: 9 additions & 2 deletions
@@ -365,7 +365,7 @@ class Canonical::Internal : public IRMutator {
                  const ComExpr& sumb,
                  int bscale) {
     std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
-    n->base = suma->base + sumb->base;
+    n->base = suma->base + sumb->base * bscale;
     // merge of suma and sumb;
     size_t i = 0, j = 0;
     while (i < suma->elem.size() && j < sumb->elem.size()) {
@@ -417,7 +417,7 @@ class Canonical::Internal : public IRMutator {
   // convert sum to expr
   Expr Sum2Expr(const ComExpr& com, Type t) {
     Expr vsum;
-    if (com->base != 0) {
+    if (com->base > 0) {
      vsum = make_const(t, com->base);
     }
     for (const ComExprEntry& e : com->elem) {
@@ -433,6 +433,13 @@ class Canonical::Internal : public IRMutator {
         }
       }
     }
+    if (com->base < 0) {
+      if (vsum.defined()) {
+        vsum = Sub::make(vsum, make_const(t, -com->base));
+      } else {
+        vsum = make_const(t, com->base);
+      }
+    }
     for (const ComExprEntry& e : com->elem) {
       if (e.scale < 0) {
         Expr v = e.value;
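Note on the two Canonical fixes above: SumAdd combines two canonical sums as suma + sumb * bscale, and previously forgot to scale sumb's constant base, so folding something like (x + 1) - (y + 2) (bscale = -1) kept +2 instead of -2. Sum2Expr now also emits a negative base as a trailing subtraction (x - 1 rather than a sum starting with the constant -1). Below is a self-contained toy sketch of the SumAdd fix (editor's illustration, not part of the commit), with a simplified stand-in for ComExprNode that uses strings for variables instead of Exprs:

#include <iostream>
#include <map>
#include <string>

// Simplified stand-in for Canonical's ComExprNode: a constant base plus
// a set of (variable, scale) terms.
struct FakeComExpr {
  int base;                         // constant term
  std::map<std::string, int> elem;  // variable -> scale
};

// Combine a + b * bscale, as Canonical::Internal::SumAdd does.
FakeComExpr SumAdd(const FakeComExpr& a, const FakeComExpr& b, int bscale) {
  FakeComExpr n;
  n.base = a.base + b.base * bscale;  // the fix: previously a.base + b.base
  n.elem = a.elem;
  for (const auto& kv : b.elem) n.elem[kv.first] += kv.second * bscale;
  return n;
}

int main() {
  // (x + 1) - (y + 2), i.e. bscale = -1: the base must fold to 1 - 2 = -1.
  FakeComExpr a{1, {{"x", 1}}};
  FakeComExpr b{2, {{"y", 1}}};
  FakeComExpr r = SumAdd(a, b, -1);
  std::cout << "base = " << r.base << "\n";  // -1 (was +3 before the fix)
  return 0;
}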

src/codegen/codegen_cuda.cc

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
     const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
     code = f(code).operator std::string();
   }
-  LOG(INFO) << code;
+
   std::string ptx;
   if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
     const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
src/lang/operation.cc

Lines changed: 87 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include <tvm/operation.h>
 #include <tvm/tensor.h>
 #include <tvm/ir.h>
+#include <tvm/ir_pass.h>
 #include <memory>
 
 namespace tvm {
@@ -120,4 +121,90 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(ComputeOpNode);
 
+// Scan
+inline bool prove_equal(Expr lhs, Expr rhs) {
+  return is_zero(ir::Simplify(lhs - rhs));
+}
+
+int ScanOpNode::num_outputs() const {
+  return update.size();
+}
+Array<IterVar> ScanOpNode::root_iter_vars() const {
+  return Array<IterVar>{scan_axis};
+}
+
+Type ScanOpNode::output_dtype(size_t i) const {
+  return update[i]->dtype;
+}
+
+Array<Expr> ScanOpNode::output_shape(size_t i) const {
+  CHECK_LT(i, state_placeholder.size());
+  return state_placeholder[i]->shape;
+}
+
+Operation ScanOpNode::make(std::string name,
+                           IterVar axis,
+                           Array<Tensor> init,
+                           Array<Tensor> update,
+                           Array<Tensor> state_placeholder) {
+  auto n = std::make_shared<ScanOpNode>();
+  CHECK_EQ(init.size(), update.size());
+  CHECK_EQ(init.size(), state_placeholder.size());
+
+  for (size_t i = 0; i < init.size(); ++i) {
+    CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
+    CHECK_EQ(init[i]->dtype, update[i]->dtype);
+    CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
+        << "init.shape[0] needs to match scan_axis.dom.min";
+    CHECK(prove_equal(
+        state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
+        << "state_placeholder.shape[0] needs to match"
+        << " scan_axis.dom.min + scan_axis.dom.extent";
+    CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
+        << "The dimension of init needs to match state_placeholder";
+    CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
+        << "update.ndim needs to be state_placeholder.ndim - 1";
+    for (size_t k = 0; k < update[i].ndim(); ++k) {
+      CHECK(prove_equal(
+          update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
+      // setup spatial axis
+      std::ostringstream spatial_name;
+      spatial_name << name << ".out" << i << ".i" << k + 1;
+      n->spatial_axis_.push_back(
+          IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
+                  spatial_name.str()));
+    }
+    for (size_t k = 1; k < init[i].ndim(); ++k) {
+      CHECK(prove_equal(
+          init[i]->shape[k], state_placeholder[i]->shape[k]));
+    }
+  }
+
+  n->name = name;
+  n->scan_axis = axis;
+  n->init = init;
+  n->update = update;
+  n->state_placeholder = state_placeholder;
+  return Operation(n);
+}
+
+Array<Tensor> Scan(IterVar scan_axis,
+                   Array<Tensor> init,
+                   Array<Tensor> update,
+                   Array<Tensor> state_placeholder,
+                   std::string name) {
+  Operation op = ScanOpNode::make(
+      name, scan_axis, init, update, state_placeholder);
+  Array<Tensor> res;
+  for (int i = 0; i < op->num_outputs(); ++i) {
+    res.push_back(op.output(i));
+  }
+  return res;
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ScanOpNode>([](const ScanOpNode *op, IRPrinter *p) {
+    p->stream << "scan(" << op->name << ", " << op << ")";
+  });
+
 }  // namespace tvm
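Note: the checks in ScanOpNode::make pin down a shape contract between init, update, and state_placeholder: init.shape[0] == dom.min, state.shape[0] == dom.min + dom.extent, update.ndim == state.ndim - 1, and all trailing (spatial) dimensions agree. A toy restatement with concrete integers standing in for the symbolic shapes (editor's sketch, not part of the commit):

#include <cassert>
#include <vector>

int main() {
  // Concrete instance of the invariants, cumsum case: m = 5 rows, n = 3 cols.
  std::vector<int> init{1, 3};    // (dom.min, n): here dom.min = 1
  std::vector<int> state{5, 3};   // (dom.min + dom.extent, n): extent = 4
  std::vector<int> update{3};     // (n): one timestep of the state
  assert(update.size() + 1 == state.size());  // update.ndim == state.ndim - 1
  assert(init.size() == state.size());        // init.ndim == state.ndim
  for (size_t k = 0; k < update.size(); ++k)  // trailing dims must agree
    assert(update[k] == state[k + 1]);
  for (size_t k = 1; k < init.size(); ++k)
    assert(init[k] == state[k]);
  return 0;
}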
