
Commit 78ea652

tqchen authored and icemelon committed
[PASS] Schedule Ops init working version (#6)
* [PASS] Schedule Ops init working version
* bugfix in PassUp
1 parent 302c2e6 commit 78ea652

33 files changed: +499 −222 lines

HalideIR

Submodule HalideIR updated from 5d1bd10 to 1ec478b

include/tvm/expr.h

Lines changed: 3 additions & 0 deletions
@@ -17,6 +17,7 @@ namespace tvm {
 
 using Halide::Type;
 using Halide::Float;
+using Halide::Bool;
 using Halide::Int;
 using Halide::UInt;
 using Halide::Handle;
@@ -29,6 +30,8 @@ using Halide::Internal::Stmt;
 using Halide::Internal::IRPrinter;
 using Halide::Internal::Variable;
 
+using Halide::Internal::make_const;
+
 /*! \brief a named variable in TVM */
 class Var : public Halide::VarExpr {
  public:
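
The two new using-declarations make Bool and make_const visible inside the tvm namespace. A minimal sketch of what that enables (variable names are illustrative, not code from this commit):

    Expr one  = make_const(Int(32), 1);     // Halide::Internal::make_const
    Expr flag = make_const(Bool(), true);   // Halide::Bool() is a one-lane boolean type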

include/tvm/ir_pass.h

Lines changed: 10 additions & 8 deletions
@@ -18,6 +18,16 @@
 namespace tvm {
 namespace ir {
 
+
+/*!
+ * \brief Schedule s' dependent operations.
+ *
+ * \param s The schedule to be realized
+ * \param dom_map The domain of each iter vars.
+ * \return the result Stmt
+ */
+Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
+
 /*!
  * \brief verifies whether the IR stmt or Expr is in SSA form.
  *  That is: each VarExpr is defined and assigned once(in Let/For)
@@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f,
             Expr body,
             Stmt stmt);
 
-/*!
- * \brief Schedule s' dependent operations.
- *
- * \param s The schedule to be realized
- * \return the result Stmt
- */
-Stmt ScheduelOps(Schedule s);
-
 }  // namespace ir
 }  // namespace tvm
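
The commit renames the misspelled ScheduelOps to ScheduleOps and threads the inferred iteration domains through as an explicit argument. A minimal sketch of the intended two-step lowering, assuming a Schedule `sch` has already been built (not code from this commit):

    Map<IterVar, Range> dom_map = schedule::InferBound(sch);  // bound inference, schedule_pass.h
    Stmt body = ir::ScheduleOps(sch, dom_map);                // materialize the schedule as a Stmt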

include/tvm/operation.h

Lines changed: 41 additions & 2 deletions
@@ -12,6 +12,36 @@
 
 namespace tvm {
 
+/*!
+ * \brief A placeholder op represents an input placeholder.
+ */
+class PlaceholderOpNode : public OperationNode {
+ public:
+  /*! \brief The shape of the input */
+  Array<Expr> shape;
+  /*! \brief The data type of the input. */
+  Type dtype;
+
+  int num_outputs() const final {
+    return 1;
+  }
+  Array<IterVar> root_iter_vars() const final;
+  Type output_dtype(size_t i) const final;
+  Array<Expr> output_shape(size_t i) const final;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("shape", &shape);
+    v->Visit("dtype", &dtype);
+  }
+  static Operation make(std::string name,
+                        Array<Expr> shape,
+                        Type dtype);
+
+  static constexpr const char* _type_key = "PlaceholderOp";
+  TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
+};
+
 /*!
  * \brief A Compute op that compute a tensor on certain domain.
  */
@@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode {
   /*! \brief constructor */
   ComputeOpNode() {}
 
-  size_t num_outputs() const final {
+  int num_outputs() const final {
     return 1;
   }
   Array<IterVar> root_iter_vars() const final;
-  std::string output_name(size_t i) const final;
   Type output_dtype(size_t i) const final;
   Array<Expr> output_shape(size_t i) const final;
 
@@ -49,6 +78,16 @@ class ComputeOpNode : public OperationNode {
 /*! \brief The compute function to specify the input source of a Tensor */
 using FCompute = std::function<Expr (const Array<Var>& i)>;
 
+/*!
+ * \brief create a place holder tensor.
+ * \param shape The shape of the tensor.
+ * \param dtype the data type of the tensor.
+ * \param name The name of the Tensor.
+ */
+Tensor Placeholder(Array<Expr> shape,
+                   Type dtype = Float(32),
+                   std::string name = "placeholder");
+
 /*!
  * \brief Construct a new tensor by computing over shape,
  *  using the computation rule: result_tensor[axis] = fcompute(axis)
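
A minimal usage sketch for the new Placeholder helper (not code from this commit; the brace-initialized Array<Expr> is an assumption): it constructs a PlaceholderOpNode internally and returns that op's single output Tensor.

    Array<Expr> shape{128, 128};
    Tensor A = Placeholder(shape);                  // defaults: Float(32), name "placeholder"
    Tensor B = Placeholder(shape, Int(8), "B");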
include/tvm/schedule_pass.h (renamed from bound.h)

Lines changed: 11 additions & 9 deletions

@@ -1,14 +1,17 @@
 /*!
  *  Copyright (c) 2016 by Contributors
- * \file bound.h
- * \brief The bound inference logics on the schedule.
+ * \file schedule_pass.h
+ * \brief Collection of Schedule pass functions.
+ *
+ *  These passes works on the schedule hyper-graph
+ *  and infers information such as bounds, check conditions
+ *  read/write dependencies between the IterVar
  */
-#ifndef TVM_SCHEDULE_BOUND_H_
-#define TVM_SCHEDULE_BOUND_H_
+#ifndef TVM_SCHEDULE_PASS_H_
+#define TVM_SCHEDULE_PASS_H_
 
-#include <tvm/expr.h>
-#include <tvm/schedule.h>
-#include <unordered_map>
+#include "./base.h"
+#include "./schedule.h"
 
 namespace tvm {
 namespace schedule {
@@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch);
 
 }  // namespace schedule
 }  // namespace tvm
-
-#endif  // TVM_SCHEDULE_BOUND_H_
+#endif  // TVM_SCHEDULE_PASS_H_

include/tvm/tensor.h

Lines changed: 12 additions & 29 deletions
@@ -28,20 +28,11 @@ using Halide::IR::FunctionRef;
  * \brief Tensor structure representing a possible input,
  *  or intermediate computation result.
  */
-class Tensor : public FunctionRef {
+class Tensor : public NodeRef {
  public:
   /*! \brief default constructor, used internally */
   Tensor() {}
-  explicit Tensor(std::shared_ptr<Node> n) : FunctionRef(n) {}
-  /*!
-   * \brief constructor of input tensor
-   * \param shape Shape of the tensor.
-   * \param name optional name of the Tensor.
-   * \param dtype The data type of the input tensor.
-   */
-  explicit Tensor(Array<Expr> shape,
-                  std::string name = "tensor",
-                  Type dtype = Float(32));
+  explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {}
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -116,11 +107,11 @@ class Tensor : public FunctionRef {
 };
 
 /*! \brief Operation that produces tensors */
-class Operation : public NodeRef {
+class Operation : public FunctionRef {
  public:
   /*! \brief default constructor */
   Operation() {}
-  explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {}
+  explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {}
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -137,12 +128,10 @@ class Operation : public NodeRef {
 };
 
 /*! \brief Node to represent a tensor */
-class TensorNode : public FunctionBaseNode {
+class TensorNode : public Node {
  public:
   /*! \brief The shape of the tensor */
   Array<Expr> shape;
-  /*! \brief optional name of the tensor */
-  std::string name;
   /*! \brief data type in the content of the tensor */
   Type dtype;
   /*! \brief the source operation, can be None */
@@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode {
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("shape", &shape);
-    v->Visit("name", &name);
     v->Visit("dtype", &dtype);
     v->Visit("op", &op);
     v->Visit("value_index", &value_index);
   }
-  const std::string& func_name() const final {
-    return name;
-  }
-  int outputs() const final {
-    return 1;
-  }
   static Tensor make(Array<Expr> shape,
-                     std::string name,
                      Type dtype,
                      Operation op,
                      int value_index);
@@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode {
 /*!
  * \brief base class of operation node.
  */
-class OperationNode : public Node {
+class OperationNode : public FunctionBaseNode {
  public:
   /*! \brief optional name of the operation */
   std::string name;
+  /*! \return name of the operation */
+  const std::string& func_name() const final {
+    return name;
+  }
+  /*! \return number of outputs of this op */
+  virtual int num_outputs() const = 0;
   /*! \return the list of iteration variable at root */
   virtual Array<IterVar> root_iter_vars() const = 0;
-  /*! \return number of outputs of this op */
-  virtual size_t num_outputs() const = 0;
-  /*! \return name of i-th output */
-  virtual std::string output_name(size_t i) const = 0;
   /*! \return type of i-th output */
   virtual Type output_dtype(size_t i) const = 0;
   /*! \return shape of i-th output */
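
The net effect: a Tensor is no longer the FunctionRef that IR nodes point at; the Operation is. A minimal sketch (not code from this commit; assumes an existing Operation `op` and that Array<Expr> can be brace-initialized):

    Array<Expr> shape{128, 128};
    Tensor t = TensorNode::make(shape, Float(32), op, /*value_index=*/0);
    // The name string now lives on the OperationNode and is exposed via func_name().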

python/tvm/function.py

Lines changed: 4 additions & 4 deletions
@@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32):
     return _function_internal._Var(name, dtype)
 
 
-def placeholder(shape, dtype = None, name="TensorObj"):
+def placeholder(shape, dtype = None, name="placeholder"):
     """Construct an empty tensor object.
 
     Parameters
@@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"):
        The created tensor
     """
     dtype = float32 if dtype is None else dtype
-    return _function_internal._Tensor(
-        shape, name, dtype, None, 0)
+    return _function_internal._Placeholder(
+        shape, dtype, name)
 
 
-def compute(shape, fcompute, name="TensorCompute"):
+def compute(shape, fcompute, name="compute"):
     """Construct a new tensor by computing over the shape domain.
 
     The compute rule is result[axis] = fcompute(axis)

python/tvm/tensor.py

Lines changed: 7 additions & 1 deletion
@@ -34,7 +34,9 @@ def __call__(self, *indices):
         else:
             raise ValueError("The indices must be expression")
 
-        return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0)
+        return _make.Call(self.dtype, self.op.name,
+                          args, _expr.Call.Halide,
+                          self.op, self.value_index)
 
     def __getitem__(self, indices):
         return TensorSlice(self, indices)
@@ -71,3 +73,7 @@ def output(self, index):
 @register_node
 class ComputeOp(Operation):
     pass
+
+@register_node
+class PlaceholderOp(Operation):
+    pass

src/c_api/c_api_ir.cc

Lines changed: 12 additions & 2 deletions
@@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For)
                       args.at(5));
   });
 
+TVM_REGISTER_API(_make_Realize)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    *ret = Realize::make(args.at(0),
+                         args.at(1),
+                         args.at(2),
+                         args.at(3),
+                         args.at(4),
+                         args.at(5));
+  });
+
+
 TVM_REGISTER_API(_make_Call)
 .set_body([](const ArgStack& args, RetValue *ret) {
     *ret = Call::make(args.at(0),
@@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt);
 REGISTER_MAKE2(AssertStmt);
 REGISTER_MAKE3(ProducerConsumer);
 REGISTER_MAKE3(Store);
-REGISTER_MAKE3(Provide);
+REGISTER_MAKE4(Provide);
 REGISTER_MAKE1(Free);
-// TODO(tqchen) Realize;
 REGISTER_MAKE2(Block);
 REGISTER_MAKE3(IfThenElse);
 REGISTER_MAKE1(Evaluate);
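
Realize gains a maker binding and Provide moves from three to four constructor arguments. A hedged sketch of what that arity suggests; the exact signatures are assumed from the HalideIR submodule update rather than shown in this commit, and op, value, indices, and bounds are placeholder variables:

    // Provide/Realize are keyed by a FunctionRef (the producing op) plus an
    // output index, rather than by a flat name string.
    Stmt provide = Provide::make(op, /*value_index=*/0, value, indices);
    Stmt realize = Realize::make(op, 0, Float(32), bounds, make_const(Bool(), true), provide);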

src/c_api/c_api_lang.cc

Lines changed: 7 additions & 1 deletion
@@ -143,7 +143,6 @@ TVM_REGISTER_API(Range)
 TVM_REGISTER_API(_Tensor)
 .set_body([](const ArgStack& args, RetValue *ret) {
     *ret = TensorNode::make(args.at(0),
-                            args.at(1),
                             args.at(2),
                             args.at(3),
                             args.at(4));
@@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash)
         std::hash<Tensor>()(args.at(0).operator Tensor()));
   });
 
+TVM_REGISTER_API(_Placeholder)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    *ret = Placeholder(args.at(0),
+                       args.at(1),
+                       args.at(2));
+  });
+
 TVM_REGISTER_API(_ComputeOp)
 .set_body([](const ArgStack& args, RetValue *ret) {
     *ret = ComputeOpNode::make(args.at(0),
