Skip to content

Commit f5d63d7

Browse files
committed
[TensorIR] Fix tensorir after rebase
1 parent 6b058af commit f5d63d7

File tree

11 files changed

+479
-23
lines changed

11 files changed

+479
-23
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
from .._ffi.node import NodeBase
3+
from .. import make as _make, api as _api, _api_internal
4+
from .tree_node import register_tensorir_node
5+
6+
@register_tensorir_node
class TensorIntrinsic(NodeBase):
    """FFI handle of a tensor intrinsic (op + lowering function + name)."""

    def __init__(self, op, intrin_func, name):
        """Build the underlying node through the ``_make._TensorIntrinsic`` maker."""
        self.__init_handle_by_constructor__(
            _make._TensorIntrinsic, op, intrin_func, name)

    def __call__(self, inputs, outputs):
        """Instantiate this intrinsic on concrete input/output tensor regions."""
        instantiate = _api_internal._TensorIntrinsic_Instantiate
        return instantiate(self, inputs, outputs)
13+
14+
def decl_tensor_intrin(op, intrin_func, name):
    """Declare a tensor intrinsic.

    Parameters
    ----------
    op : Operation
        The operation describing the semantic form of the intrinsic.
    intrin_func : PackedFunc
        The lowering function: (input regions, output regions) -> stmt or tree node.
    name : str
        The name of the intrinsic.

    Returns
    -------
    intrin : TensorIntrinsic
        The created intrinsic node.
    """
    intrin = _make._TensorIntrinsic(op, intrin_func, name)
    return intrin
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from .._ffi.node import NodeBase
2+
from .. import make as _make, api as _api, intrin as _intrin, expr as _expr, ir_pass as _ir_pass, \
3+
_api_internal
4+
from ..build_module import current_build_config
5+
6+
from .tree_node import register_tensorir_node
7+
8+
@register_tensorir_node
class TensorRegion(NodeBase):
    """A rectangular sub-region of a tensor, one Range per dimension.

    Constructed from a tensor slice; every slice dimension must use step 1.
    """

    def __init__(self, tensor_slice):
        ranges = []
        for x in tensor_slice.indices:
            assert x.step is None or x.step == 1, "Only support step = 1"
            ranges.append(_api.Range(x.start, x.stop))
        self.__init_handle_by_constructor__(_make.TensorRegion,
                                            tensor_slice.tensor,
                                            ranges)

    def __getitem__(self, item):
        """Create a view of a sub-region, indexed by a tuple of slices."""
        assert len(item) == self.data.ndim, "The dimension of index is wrong"

        mins = []
        extents = []
        for x in item:
            # Validate before consuming the slice bounds.
            assert x.step is None or x.step == 1, "Only support step == 1"
            mins.append(x.start)
            extents.append(x.stop - x.start)

        return _api_internal._TensorRegion_MakeView(self, mins, extents)

    def emit_buffer_bind(self, ib, **kwargs):
        """Emit buffer_bind_scope Attr Stmt to an IRBuilder.

        Returns the declared Buffer bound to this region.
        """
        data = self.data

        # skip ones todo(lmzheng): fix this to match inputs placeholder
        shape = [_ir_pass.Simplify(x.extent) for x in self.ranges]
        while isinstance(shape[0], _expr.IntImm) and shape[0].value == 1 and len(shape) > 1:
            shape = shape[1:]

        # Fall back to int32 when the leading extent is a plain Python int.
        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"

        B = _api.decl_buffer(shape, dtype=data.dtype, name="B" + data.name,
                             elem_offset=_api.var(data.name + "_offset", dtype=shape_dtype),
                             **kwargs)
        # tvm_tuple packs (min, extent) pairs for every dimension of the region.
        ranges = []
        for x in self.ranges:
            ranges.append(x.min)
            ranges.append(x.extent)
        ib.scope_attr([B, data], "buffer_bind_scope",
                      _intrin.call_intrin("handle", "tvm_tuple", *ranges))

        return B

python/tvm/tensorir/tree_node.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from .. import make as _make, expr as _expr
2+
from ..api import _init_api
3+
from .._ffi.node import NodeBase, register_node
4+
5+
def register_tensorir_node(type_key=None):
    """Register a TensorIR node type.

    Usable as a bare class decorator (the type key becomes
    ``"tensorir." + class name``) or called with an explicit string key.

    Parameters
    ----------
    type_key : str or cls
        The type key of the node, or the node class itself.
    """
    if isinstance(type_key, str):
        return register_node(type_key)
    # Decorator form: `type_key` is the class being registered.
    return register_node("tensorir." + type_key.__name__)(type_key)
17+
18+
19+
@register_tensorir_node
class ScheduleTreeNode(NodeBase):
    """Base class of nodes in the schedule tree."""

    def __str__(self):
        # PrintTreeNode is injected into this module's globals by the
        # _init_api call at the bottom of the file.
        text = PrintTreeNode(self)
        return text
23+
24+
25+
@register_tensorir_node
class AxisTreeNode(ScheduleTreeNode):
    """Schedule tree node representing a single loop axis."""

    def __init__(self, loop_var, min, extent, axis_type, children):
        # `min` shadows the builtin; the name is kept to preserve the
        # keyword-argument interface for callers.
        self.__init_handle_by_constructor__(
            _make.AxisTreeNode, loop_var, min, extent, axis_type, children)
34+
35+
36+
@register_tensorir_node
class BlockTreeNode(ScheduleTreeNode):
    """Schedule tree node representing a computation block."""

    def __init__(self, args, vars, inputs, outputs, stmt, children):
        # A bare expression body is wrapped into an Evaluate statement.
        if isinstance(stmt, _expr.Expr):
            stmt = _make.Evaluate(stmt)
        self.__init_handle_by_constructor__(
            _make.BlockTreeNode, args, vars, inputs, outputs, stmt, children)
47+
48+
# Load the PackedFuncs registered under "tvm.tensorir.tree_node" on the C++
# side into this module's globals (presumably including PrintTreeNode, which
# ScheduleTreeNode.__str__ relies on).
_init_api('tvm.tensorir.tree_node')

src/tensorir/intrinsic.cc

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*!
2+
* Copyright (c) 2019 by Contributors
3+
* \brief Tensor intrinsics
4+
*/
5+
6+
#include <tvm/packed_func_ext.h>
7+
#include "tree_node.h"
8+
#include "tree_builder.h"
9+
#include "intrinsic.h"
10+
#include "util.h"
11+
12+
namespace tvm {
13+
namespace tensorir {
14+
15+
// maker
16+
// Construct a TensorIntrinsic from an operation (semantic form),
// a lowering function and a name.
TensorIntrinsic TensorIntrinsicNode::make(Operation op, PackedFunc intrin_func, std::string name) {
  auto node = make_node<TensorIntrinsicNode>();

  node->name = std::move(name);
  node->intrin_func = std::move(intrin_func);
  node->op = std::move(op);

  // todo (lmzheng): build BlockTreeNode `from` for checking and untensorize

  return TensorIntrinsic(node);
}
27+
28+
// Typed accessor: the underlying node of this reference is always a
// TensorIntrinsicNode, so the static_cast is safe.
const TensorIntrinsicNode* TensorIntrinsic::operator->() const {
  return static_cast<const TensorIntrinsicNode*>(node_.get());
}
31+
32+
// Instantiate the intrinsic on concrete input/output regions.
// The user-provided intrin_func may return either a ScheduleTreeNode
// (returned as-is) or a Stmt (canonicalized and wrapped into a BlockTreeNode).
// Any other return value is a fatal error.
ScheduleTreeNode TensorIntrinsic::Instantiate(Array<TensorRegion> inputs,
                                              Array<TensorRegion> outputs) const {
  NodeRef ret = operator->()->intrin_func(inputs, outputs);

  if (ret->derived_from<ScheduleTreeNodeNode>()) {
    return Downcast<ScheduleTreeNode>(ret);
  } else if (ret->derived_from<StmtNode>()) {
    Stmt stmt = Downcast<Stmt>(ret);
    Array<Expr> args;
    Array<Var> vars;
    Set<Var> used_vars;
    Map<Var, Expr> var_map;
    arith::Analyzer analyzer;

    // Gather every variable appearing in the range bounds of the regions.
    auto gather_range_vars = [&used_vars](const Array<TensorRegion>& regions) {
      for (const auto& x : regions) {
        for (const auto& ran : x->ranges) {
          used_vars.insert(GatherVars(ran->min));
          used_vars.insert(GatherVars(ran->extent));
        }
      }
    };
    gather_range_vars(inputs);
    gather_range_vars(outputs);

    // Canonicalize outputs; var_map maps the old vars to canonical expressions.
    Array<TensorRegion> new_outputs;
    std::tie(args, vars, new_outputs, var_map) = CreateOutputRegions(
        outputs, used_vars, &analyzer);

    // Rewrite the input regions in terms of the canonical variables.
    Array<TensorRegion> new_inputs;
    for (const auto& x : inputs) {
      Array<Range> ranges;
      for (const auto& ran : x->ranges) {
        ranges.push_back(Range::make_by_min_extent(
            SubstituteAndEquationSimplify(ran->min, var_map, &analyzer),
            SubstituteAndEquationSimplify(ran->extent, var_map, &analyzer)));
      }
      new_inputs.push_back(TensorRegionNode::make(x->data, ranges));
    }

    // Rewrite the statement body itself.
    stmt = SubstituteAndEquationSimplify(stmt, var_map, &analyzer);

    return BlockTreeNodeNode::make(args, vars, new_inputs, new_outputs,
                                   stmt, Array<ScheduleTreeNode>{});
  } else {
    LOG(FATAL) << "The intrin func returns invalid value";
  }
  return ScheduleTreeNode(nullptr);  // unreachable: keeps the compiler happy
}
87+
88+
} // namespace tensorir
89+
} // namespace tvm

src/tensorir/intrinsic.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*!
2+
* Copyright (c) 2019 by Contributors
3+
* \brief Tensor intrinsics
4+
*/
5+
6+
#ifndef TVM_TENSORIR_INTRINSIC_H_
7+
#define TVM_TENSORIR_INTRINSIC_H_
8+
9+
#include "tree_node.h"
10+
11+
namespace tvm {
12+
namespace tensorir {
13+
14+
using runtime::PackedFunc;
15+
16+
// A tensor intrinsic replaces a block to another block
class TensorIntrinsic;

// Node container holding the data of a tensor intrinsic.
class TensorIntrinsicNode : public Node {
 public:
  Operation op;            // semantic form
  PackedFunc intrin_func;  // (Array<TensorRegion>, Array<TensorRegion>) -> ScheduleTreeNode or Stmt,
                           // todo(lmzheng): use TypedPackedFunc?
  std::string name;        // name of the intrinsic

  void VisitAttrs(AttrVisitor *v) final {
    v->Visit("op", &op);
    //v->Visit("intrin_func", &intrin_func); // todo(lmzheng): fix AttrVisitor
    v->Visit("name", &name);
  }

  // Create a new tensor intrinsic node.
  TVM_DLL static TensorIntrinsic make(Operation op, PackedFunc intrin_func, std::string name);

  static constexpr const char *_type_key = "tensorir.TensorIntrinsic";
  TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinsicNode, Node);
};
36+
37+
// Reference class wrapping TensorIntrinsicNode.
class TensorIntrinsic : public NodeRef {
 public:
  TensorIntrinsic() {}
  explicit TensorIntrinsic(NodePtr<Node> n): NodeRef(n) {}

  const TensorIntrinsicNode* operator->() const;
  // Instantiate this intrinsic on concrete input/output regions; the result
  // is either the intrin_func's tree node or a block wrapping its statement.
  ScheduleTreeNode Instantiate(Array<TensorRegion> inputs, Array<TensorRegion> outputs) const;

  using ContainerType = TensorIntrinsicNode;
};
47+
48+
49+
} // namespace tensorir
50+
} // namespace tvm
51+
52+
53+
#endif // TVM_TENSORIR_INTRINSIC_H_

src/tensorir/schedule.cc

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "schedule.h"
1414
#include "tree_builder.h"
1515
#include "util.h"
16+
#include "../arithmetic/int_set.h"
1617

1718
namespace tvm {
1819

@@ -388,26 +389,26 @@ BlockTreeNode Schedule::compute_at(BlockTreeNode block, AxisTreeNode axis) {
388389
new_args[i] = block->args[i];
389390
}
390391

391-
if (const arith::IntervalSet* set = iter_domain[i].as<arith::IntervalSet>()) {
392+
if (const arith::IntervalSetNode* set = iter_domain[i].as<arith::IntervalSetNode>()) {
392393
node = AxisTreeNodeNode::make(iter_var,
393-
set->i.min,
394-
set->i.max - set->i.min + 1,
394+
set->min_value,
395+
set->max_value - set->min_value + 1,
395396
kOpaque, // todo(lmzheng): fill correct type to replace kOpaque x 3
396397
Array<ScheduleTreeNode>{last});
397398
new_args[i] = iter_var;
398-
} else if (const arith::StrideSet* set = iter_domain[i].as<arith::StrideSet>()) {
399+
} else if (const arith::StrideSetNode* set = iter_domain[i].as<arith::StrideSetNode>()) {
399400
CHECK(set->extents.size() == 1);
400-
CHECK(set->base.is_single_point());
401+
CHECK(is_one(set->base_extent));
401402
if (is_one(set->extents[0])) {
402403
node = AxisTreeNode(nullptr);
403-
new_args[i] = set->base.min;
404+
new_args[i] = set->base_min;
404405
} else {
405406
node = AxisTreeNodeNode::make(iter_var,
406407
0,
407408
set->extents[0],
408409
kOpaque,
409410
Array<ScheduleTreeNode>{last});
410-
new_args[i] = iter_var * set->strides[0] + set->base.min;
411+
new_args[i] = iter_var * set->strides[0] + set->base_min;
411412
}
412413
} else {
413414
LOG(FATAL) << "Cannot handle int set : " << iter_domain[i];
@@ -539,11 +540,11 @@ BlockTreeNode Schedule::blockize(AxisTreeNode axis) {
539540
for (size_t i = 0; i < iter.first.ndim(); ++i) {
540541
Array<IntSet> to_merge;
541542
for (const std::vector<IntSet>& y : iter.second) {
542-
const arith::IntervalSet* set = y[i].as<arith::IntervalSet>();
543+
const arith::IntervalSetNode* set = y[i].as<arith::IntervalSetNode>();
543544
CHECK(set != nullptr);
544-
IntSet b = arith::IntervalSet::make(
545-
SubstituteAndEquationSimplify(set->i.min, var_map, &analyzer),
546-
SubstituteAndEquationSimplify(set->i.max, var_map, &analyzer));
545+
IntSet b = arith::IntervalSet(
546+
SubstituteAndEquationSimplify(set->min_value, var_map, &analyzer),
547+
SubstituteAndEquationSimplify(set->max_value, var_map, &analyzer));
547548
to_merge.push_back(b);
548549
}
549550
IntSet merged = arith::Union(to_merge);
@@ -600,7 +601,7 @@ BlockTreeNode Schedule::tensorize(BlockTreeNode block, TensorIntrinsic intrin) {
600601
block->inputs, block->outputs,
601602
Stmt(NodePtr<Node>(nullptr)),
602603
Array<ScheduleTreeNode>{Downcast<ScheduleTreeNode>(ret)});
603-
} else if (ret->derived_from<HalideIR::Internal::BaseStmtNode>()) {
604+
} else if (ret->derived_from<StmtNode>()) {
604605
new_block = BlockTreeNodeNode::make(block->args, block->vars,
605606
block->inputs, block->outputs,
606607
Downcast<Stmt>(ret), Array<ScheduleTreeNode>{});

0 commit comments

Comments
 (0)