Skip to content

Commit 302c2e6

Browse files
committed
Make the Tensor comparator and hash aware of the producing op and output index; initial check-in of the IR generation
1 parent eee0ebe commit 302c2e6

File tree

9 files changed

+311
-6
lines changed

9 files changed

+311
-6
lines changed

include/tvm/operation.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,17 @@ inline Tensor Compute(Array<Expr> shape,
8686

8787
} // namespace tvm
8888

89+
90+
namespace std {
91+
template <>
92+
struct hash<::tvm::Tensor> {
93+
std::size_t operator()(const ::tvm::Tensor& k) const {
94+
if (k.defined() && k->op.defined()) {
95+
return k->op.hash();
96+
} else{
97+
return k.hash();
98+
}
99+
}
100+
};
101+
} // namespace std
89102
#endif // TVM_OPERATION_H_

include/tvm/tensor.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ class Tensor : public FunctionRef {
4747
* \return the pointer to the internal node container
4848
*/
4949
inline const TensorNode* operator->() const;
50+
/*!
 * \brief Check whether two tensors are equal to each other.
 * \param other The tensor to be compared.
 * \return Whether the two tensors are equal.
 */
55+
inline bool operator==(const Tensor& other) const;
5056
/*! \return The dimension of the tensor */
5157
inline size_t ndim() const;
5258
/*!
@@ -201,6 +207,17 @@ inline size_t Tensor::ndim() const {
201207
return (*this)->shape.size();
202208
}
203209

210+
inline bool Tensor::operator==(const Tensor& other) const {
  // Same underlying node: trivially equal.
  if (get() == other.get()) return true;
  // Exactly one side undefined can never compare equal
  // (the both-identical case, including both-null, was handled above).
  if (get() == nullptr || other.get() == nullptr) return false;
  // A tensor produced by an operation is identified by (op, value_index).
  if ((*this)->op.defined() || other->op.defined()) {
    return (*this)->op == other->op &&
           (*this)->value_index == other->value_index;
  }
  // Distinct nodes with no source op are considered different tensors.
  return false;
}
220+
204221
// macro to turn every operation of slice to expression
205222
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
206223
inline Expr operator Op (const Tensor::Slice& a) { \

python/tvm/_ctypes/_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ class ArgVariant(ctypes.Union):
2929
def _type_key(handle):
    """Query the "type_key" attribute of the node behind a C handle."""
    value = ArgVariant()
    type_id = ctypes.c_int()
    success = ctypes.c_int()
    check_call(_LIB.TVMNodeGetAttr(
        handle, c_str("type_key"),
        ctypes.byref(value),
        ctypes.byref(type_id),
        ctypes.byref(success)))
    return py_str(value.v_str)
3639

3740
NODE_TYPE = {

python/tvm/tensor.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import absolute_import as _abs
22
from ._ctypes._api import NodeBase, SliceBase, register_node, convert
33
from . import collections as _collections
4+
from . import _function_internal
45
from . import make as _make
56
from . import expr as _expr
67

@@ -38,6 +39,35 @@ def __call__(self, *indices):
3839
def __getitem__(self, indices):
3940
return TensorSlice(self, indices)
4041

42+
def __hash__(self):
    # Delegate to the C++ hash so hashing stays consistent with __eq__
    # (both are implemented on the C++ side over the same tensor identity).
    return _function_internal._TensorHash(self)
44+
45+
def __eq__(self, other):
    """Tensors compare equal via the C++ comparator (same source op and
    output index, per the commit intent); non-Tensor values are never equal."""
    if not isinstance(other, Tensor):
        return False
    return _function_internal._TensorEqual(self, other)
49+
4150
@property
def ndim(self):
    """int: number of dimensions (rank) of the tensor."""
    return len(self.shape)
53+
54+
55+
class Operation(NodeBase):
    """An operation that produces one or more output tensors."""

    def output(self, index):
        """Get the index-th output of the operation

        Parameters
        ----------
        index : int
            The output index.

        Returns
        -------
        out : Tensor
            The i-th output.
        """
        return _function_internal._OpGetOutput(self, index)
70+
71+
@register_node
class ComputeOp(Operation):
    """Operation constructed from a compute expression (see _ComputeOp)."""
    pass

src/c_api/c_api_lang.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,30 @@ TVM_REGISTER_API(_Tensor)
149149
args.at(4));
150150
});
151151

152+
// Python-facing structural equality of two tensors; delegates to
// Tensor::operator== (same op and value_index).
TVM_REGISTER_API(_TensorEqual)
.set_body([](const ArgStack& args, RetValue *ret) {
    *ret = args.at(0).operator Tensor() == args.at(1).operator Tensor();
  });
156+
157+
// Python-facing hash of a tensor; must stay consistent with _TensorEqual.
// NOTE(review): the size_t hash is narrowed into a signed int64_t for the
// return slot — a top-bit hash becomes negative; harmless for hashing but
// worth confirming the RetValue contract.
TVM_REGISTER_API(_TensorHash)
.set_body([](const ArgStack& args, RetValue *ret) {
    *ret = static_cast<int64_t>(
        std::hash<Tensor>()(args.at(0).operator Tensor()));
  });
162+
152163
TVM_REGISTER_API(_ComputeOp)
153164
.set_body([](const ArgStack& args, RetValue *ret) {
154165
*ret = ComputeOpNode::make(args.at(0),
155166
args.at(1),
156167
args.at(2));
157168
});
158169

170+
// Python-facing accessor: return the index-th output tensor of an operation.
TVM_REGISTER_API(_OpGetOutput)
  .set_body([](const ArgStack& args, RetValue *ret) {
    *ret = args.at(0).operator Operation().output(
        args.at(1).operator size_t());
  });
175+
159176

160177
TVM_REGISTER_API(_IterVar)
161178
.set_body([](const ArgStack& args, RetValue *ret) {

src/pass/schedule_ops.cc

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,225 @@
55
#include <tvm/ir.h>
66
#include <tvm/ir_mutator.h>
77
#include <tvm/ir_pass.h>
8+
#include <tvm/ir_visitor.h>
89
#include "./scope.h"
910

1011
namespace tvm {
1112
namespace ir {
1213
namespace {
14+
15+
/*!
16+
* \brief use message passing to calculate the assignment of each Var inside the loop body.
17+
* \param s The schedule to be used.
18+
* \param dom_map The domain map of each iteration variable's domain
19+
* \param p_state The message passing state
20+
* IterVar->The assignment.
21+
*/
22+
void PassUpOffset(const Schedule& s,
23+
const std::unordered_map<IterVar, Range>& dom_map,
24+
std::unordered_map<IterVar, Expr>* p_state) {
25+
auto& state = *p_state;
26+
for (size_t i = s->relations.size(); i != 0; --i) {
27+
IterVarRelation rel = s->relations[i - 1];
28+
if (rel.as<SplitNode>()) {
29+
const SplitNode* s = rel.as<SplitNode>();
30+
Expr outer = state.at(s->outer);
31+
Expr inner = state.at(s->outer);
32+
Expr factor = dom_map.at(s->outer)->extent;
33+
Expr offset = inner + outer * factor;
34+
Expr outer_min = dom_map.at(s->parent)->min;
35+
if (!is_zero(outer_min)) {
36+
offset = outer_min + offset;
37+
}
38+
state[s->parent] = offset;
39+
} else if (rel.as<FuseNode>()) {
40+
const FuseNode* s = rel.as<FuseNode>();
41+
Expr value = state.at(s->fused);
42+
Expr factor = dom_map.at(s->outer)->extent;
43+
state[s->outer] = value / factor;
44+
state[s->inner] = value % factor;
45+
} else {
46+
LOG(FATAL) << "unknown relation type";
47+
}
48+
}
49+
}
50+
51+
/*!
52+
* \brief split the expr by addition.
53+
* \param expr The expression to be splitted.
54+
* \param loop_level The loop level of each Variable
55+
* \param result vector of (level, expr)
56+
* The level gives the mimimum loop level this expression need to be computed.
57+
* The Expr gives the expression content.
58+
*/
59+
void SplitByAdd(Expr expr,
60+
const std::unordered_map<const Variable*, size_t>& loop_level,
61+
std::vector<std::pair<size_t, Expr> > *result) {
62+
const Add* op = expr.as<Add>();
63+
if (op != nullptr) {
64+
SplitByAdd(op->a, loop_level, result);
65+
SplitByAdd(op->b, loop_level, result);
66+
} else {
67+
size_t max_level = 0;
68+
auto fvisit = [&max_level, &loop_level](const NodeRef& n) {
69+
const Variable* op = n.as<Variable>();
70+
if (op != nullptr) {
71+
auto it = loop_level.find(op);
72+
if (it != loop_level.end()) {
73+
max_level = std::max(max_level, it->second);
74+
}
75+
}
76+
};
77+
PostOrderVisit(expr, fvisit);
78+
result->push_back(std::make_pair(max_level, expr));
79+
}
80+
}
81+
82+
/*!
83+
* \brief combine the nest stmt, whose body is not defined.
84+
* \param nest A list of For and LetStmt, whose body is not defined.
85+
* \param body body
86+
*/
87+
Stmt CombineNest(std::vector<Stmt>&& nest, Stmt body) {
88+
while (!nest.empty()) {
89+
Stmt s = std::move(nest.back());
90+
nest.pop_back();
91+
if (s.as<For>()) {
92+
auto n = std::make_shared<For>(*s.as<For>());
93+
n->body = body;
94+
body = Stmt(n);
95+
} else if (s.as<LetStmt>()) {
96+
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
97+
n->body = body;
98+
body = Stmt(n);
99+
} else if (s.as<AttrStmt>()) {
100+
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
101+
n->body = body;
102+
body = Stmt(n);
103+
} else {
104+
LOG(FATAL) << "not supported nest type";
105+
}
106+
}
107+
return body;
108+
}
109+
110+
/*!
111+
* \brief Make the loop nest of the correspondings schedule.
112+
* \param sch The schedule.
113+
* \param dom_map The domain map.
114+
*/
115+
std::vector<Stmt> MakeLoopNest(
116+
const Schedule& sch,
117+
const std::unordered_map<IterVar, Range>& dom_map) {
118+
// optional, use let to define some CSE in dom_map.
119+
auto leaf_iter_vars = sch->leaf_iter_vars;
120+
std::unordered_map<IterVar, Expr> offset;
121+
std::unordered_map<const Variable*, size_t> loop_level;
122+
123+
// create the loop nest
124+
std::vector<Stmt> nest;
125+
nest.resize(leaf_iter_vars.size() + 1, Stmt());
126+
127+
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
128+
auto iv = leaf_iter_vars[i];
129+
// initialize the offset and loop_level
130+
offset[iv] = iv->var;
131+
loop_level[iv->var.as<Variable>()] = i + 1;
132+
133+
nest[i] = AttrStmt::make(iv->var, "scope", iv, Stmt());
134+
if (iv->thread_tag.length() == 0) {
135+
Range dom = dom_map.at(iv);
136+
nest[i] = For::make(iv->var, dom->min, dom->extent,
137+
ForType::Serial, DeviceAPI::None, nest[i]);
138+
}
139+
}
140+
// message passing to get offset of root iter vars.
141+
PassUpOffset(sch, dom_map, &offset);
142+
for (IterVar iv : sch->op->root_iter_vars()) {
143+
Expr value = offset.at(iv);
144+
if (value.same_as(iv->var)) continue;
145+
using Entry = std::pair<size_t, Expr>;
146+
std::vector<Entry> splits;
147+
SplitByAdd(value, loop_level, &splits);
148+
149+
Expr offset = 0;
150+
for (size_t i = 0; i <= leaf_iter_vars.size(); ++i) {
151+
auto iv = leaf_iter_vars[i];
152+
for (const auto& kv : splits) {
153+
if (kv.first == i) {
154+
offset = offset + splits[i].second;
155+
}
156+
}
157+
std::ostringstream os;
158+
os << iv->var->name_hint << ".at.l" << i;
159+
Var base_offset(os.str());
160+
nest[i] = LetStmt::make(base_offset, offset, nest[i]);
161+
offset = base_offset;
162+
}
163+
nest.back() = LetStmt::make(iv->var, offset, nest.back());
164+
}
165+
return nest;
166+
}
167+
168+
/*!
 * \brief Make the provide body of the given operation.
 * \param op The operation.
 */
Stmt MakeBody(const Operation& op) {
  Stmt body;
  if (op.as<ComputeOpNode>()) {
    const ComputeOpNode* compute = op.as<ComputeOpNode>();
    // Note: a Tensor's address cannot uniquely identify it — presumably why
    // the canonical op.output(0) tensor is used here; TODO confirm.
    Tensor t = op.output(0);
    // Provide writes the compute expression at the point indexed by the
    // operation's own axis variables.
    Array<Expr> args;
    for (IterVar iv : compute->axis) {
      args.push_back(iv->var);
    }
    body = Provide::make(t, {compute->body}, args);
  } else {
    LOG(FATAL) << "not supported op";
  }
  return body;
}
188+
189+
// Placeholder: eventually should wrap `body` with the realize/produce
// pipeline of the schedule; currently returns the body unchanged.
Stmt MakePipeline(const Schedule& sch, Stmt body) {
  return body;
}
192+
193+
// Inject the operator's realization on the stmt.
class InjectRealize : public IRMutator {
 public:
  explicit InjectRealize(Schedule sch)
      : sch_(sch) {}

  Stmt Mutate(Stmt stmt) final {
    const AttrStmt* op = stmt.as<AttrStmt>();
    if (op != nullptr) {
      // Keep the attribute visible while mutating the subtree, then
      // restore the scope on the way out.
      attr_scope_.Push({op->node, op->type_key}, op->value);
      stmt = IRMutator::Mutate(stmt);
      attr_scope_.Pop({op->node, op->type_key});
    } else {
      stmt = IRMutator::Mutate(stmt);
    }

    // At the schedule's attachment point, rebuild the AttrStmt with the
    // pipeline injected around its body.
    // NOTE(review): this uses the PRE-mutation op->body (the mutated `stmt`
    // is discarded), and `op` points into the original `stmt` node, which
    // may no longer be referenced after the reassignment above — confirm
    // both are intended/safe.
    if (op != nullptr &&
        op->type_key == "scope" &&
        op->node == sch_->attach_parent) {
      return AttrStmt::make(
          op->node, op->type_key, op->value,
          MakePipeline(sch_, op->body));
    } else {
      return stmt;
    }
  }

 private:
  // the schedule to be injected
  Schedule sch_;
  // attribute scope stack maintained during traversal
  Scope<AttrKey, Expr> attr_scope_;
};
225+
226+
13227
} // namespace
14228
} // namespace ir
15229
} // namespace tvm

src/schedule/bound.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,11 @@ void PassToOperation(
101101
const Tensor& tensor,
102102
const std::vector<IntSet>& dim_bounds,
103103
std::unordered_map<IterVar, std::vector<IntSet> >* result) {
104-
104+
// This is a push style operation, given output bound, push to the op IterVar bound.
105+
// It cannot handle complicated cases where op bound is coupled with bounds of
106+
// all of its outputs, without having a simple communicative union relation.
107+
//
108+
// Eventually, we need to change the inference to be a Pull style inference
105109
if (tensor->op.as<ComputeOpNode>()) {
106110
auto root_iter_vars = tensor->op->root_iter_vars();
107111
CHECK_EQ(tensor.ndim(), root_iter_vars.size());

0 commit comments

Comments
 (0)