Skip to content

Commit d1be914

Browse files
committed
[TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes.
te::Tensor is a useful object for tensor expression, but brings unnecessary reverse dependency in TIR nodes such as Provide and Realize. This PR is a first step to remove this dependency. We will use Buffer in all the places where the te::Tensor was used. The rough correspondence are: - Provide -> BufferStore - Realize -> BufferRealize - HalideCall -> BufferLoad. After this change, we can now use IRModule of PrimFuncs cleanly to represent TIR at any point of the optimizations. Buffer will serve as the abstraction for the TIR data models to represent the intermediate storages and their constraints. We still keep Realize/HalideCall and Provide as TIR nodes for now to make the change minimal. Right after ScheduleOps, we call SchedulePostProcToPrimFunc to canonicalize the temporary IR generated by TE (which contains these nodes) to the TIR. The TIR optimizations are now mostly migrated to the pass manager. Followup PRs are needed to migrate the remaining few passes.
1 parent 3264895 commit d1be914

40 files changed

+932
-422
lines changed

include/tvm/arith/bound.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,15 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
7878
/*!
7979
* \brief Infer a regular domain that covers all the calls or provides within the given statement.
8080
* \param body The given statement.
81-
* \param tensor The name of the calls or provides.
82-
* \param consider_calls If calls (read) are considered.
83-
* \param consider_provides If provides (write) are considered.
81+
* \param buffer The buffer to check the access info.
82+
* \param consider_loads If loads are considered.
83+
* \param consider_stores If stores are considered.
8484
* \return The domain that covers all the calls or provides within the given statement.
8585
*/
86-
Domain DomainTouched(Stmt body,
87-
const te::Tensor &tensor,
88-
bool consider_calls,
89-
bool consider_provides);
86+
Domain DomainTouched(const Stmt& body,
87+
const tir::Buffer& buffer,
88+
bool consider_loads,
89+
bool consider_stores);
9090

9191
} // namespace arith
9292
} // namespace tvm

include/tvm/runtime/memory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class ObjAllocatorBase {
7070
static_assert(std::is_base_of<Object, T>::value,
7171
"make can only be used to create Object");
7272
T* ptr = Handler::New(static_cast<Derived*>(this),
73-
std::forward<Args>(args)...);
73+
std::forward<Args>(args)...);
7474
ptr->type_index_ = T::RuntimeTypeIndex();
7575
ptr->deleter_ = Handler::Deleter();
7676
return ObjectPtr<T>(ptr);

include/tvm/te/schedule_pass.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#define TVM_TE_SCHEDULE_PASS_H_
3030

3131
#include <tvm/te/schedule.h>
32+
#include <tvm/tir/function.h>
3233

3334
namespace tvm {
3435
namespace te {
@@ -54,6 +55,26 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
5455
*/
5556
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);
5657

58+
/*!
59+
* \brief Postprocessing the Stmt generated by ScheduleOps to create
60+
* a PrimFunc that can then be used for further TIR optimizations.
61+
*
62+
* Perform this translation before running any TIR optimizations.
63+
*
64+
* List of actions taken by the function:
65+
 * - Remove occurrences of te::Tensor, te::Operation in the IR
66+
* and replace them by corresponding IR nodes via tir::Buffer.
67+
* - Add annotation of extern buffers using the buffer_map field
68+
* in the PrimFunc type.
69+
*
70+
* \param arg_list Array of Tensor/Var/Buffer arguments to the function.
71+
* \param body The body of the function.
72+
* \param bindings potential Tensor to Buffer bindings for the Tensors in the body.
73+
*/
74+
PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
75+
Stmt body,
76+
Optional<Map<Tensor, Buffer>> bindings);
77+
5778
/*!
5879
* \brief To automatically inline the element-wise operations.
5980
*

include/tvm/tir/expr.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,10 @@ class CallNode : public PrimExprNode {
694694
ExternCPlusPlus = 1,
695695
/*! \brief Extern "C" without side-effect. */
696696
PureExtern = 2,
697-
/*! \brief Halide-style call, evaluates func(args). */
697+
/*!
698+
* \brief Halide-style call, evaluates func(args).
699+
* \note Deprecated, move to BufferLoad in the future.
700+
*/
698701
Halide = 3,
699702
/*! \brief Intrinsic functions. */
700703
Intrinsic = 4,
@@ -707,9 +710,15 @@ class CallNode : public PrimExprNode {
707710
Array<PrimExpr> args;
708711
/*! \brief Type of calls. */
709712
CallType call_type;
710-
/*! \brief The function to be called. */
713+
/*!
714+
* \brief The function to be called.
715+
* \note Deprecated, move to BufferLoad in the future.
716+
*/
711717
FunctionRef func;
712-
/*! \brief The output value index if func's value is a tuple. */
718+
/*!
719+
* \brief The output value index if func's value is a tuple.
720+
* \note Deprecated, move to BufferLoad in the future.
721+
*/
713722
int value_index{0};
714723

715724
void VisitAttrs(AttrVisitor* v) {

include/tvm/tir/ir_pass.h

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -164,22 +164,6 @@ Stmt Inline(Stmt stmt,
164164
Array<Var> args,
165165
PrimExpr body);
166166

167-
/*!
168-
* \brief Flatten the multi-dimensional read/write
169-
* to single dimensional Load/Store
170-
*
171-
 * \param stmt The stmt to be transformed.
172-
* \param extern_buffer Map specifies external
173-
* buffer assignment of input and outputs.
174-
* \param cache_line_size The size of CPU cache line.
175-
* \param create_bound_attribute Whether to create bound attributes.
176-
* \return Transformed stmt.
177-
*/
178-
Stmt StorageFlatten(Stmt stmt,
179-
Map<te::Tensor, Buffer> extern_buffer,
180-
int cache_line_size,
181-
bool create_bound_attribute = false);
182-
183167
/*!
184168
* \brief Try to modify the AST to support TensorCore
185169
*
@@ -202,13 +186,6 @@ Stmt RewriteForTensorCore(Stmt stmt,
202186
*/
203187
bool VerifyCompactBuffer(Stmt stmt);
204188

205-
/*!
206-
* \brief Inject prefetch instructions into stmt.
207-
* \param stmt The statement to be transformed.
208-
* \return Transformed stmt.
209-
*/
210-
Stmt InjectPrefetch(Stmt stmt);
211-
212189
/*!
213190
* \brief Decorate the stmt with a device scope, this is helpful for
214191
* hardware accelerator without thread blocks.

include/tvm/tir/stmt.h

Lines changed: 98 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,6 @@ class StoreNode : public StmtNode {
248248
* \endcode
249249
* \sa BufferLoad
250250
*/
251-
class BufferStore;
252251
class BufferStoreNode : public StmtNode {
253252
public:
254253
/*! \brief The buffer variable. */
@@ -281,6 +280,10 @@ class BufferStoreNode : public StmtNode {
281280
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode);
282281
};
283282

283+
/*!
284+
* \brief Managed reference to BufferStoreNode.
285+
* \sa BufferStoreNode
286+
*/
284287
class BufferStore : public Stmt {
285288
public:
286289
TVM_DLL explicit BufferStore(Buffer buffer,
@@ -289,8 +292,80 @@ class BufferStore : public Stmt {
289292
TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
290293
};
291294

295+
/*!
296+
* \brief Annotate the region where the buffer need to
297+
* be read and write in the body.
298+
* We only need to allocate the space for the corresponding region.
299+
*
300+
* \note There should be at most one BufferRealize for each buffer.
301+
* BufferRealize is not necessary for external buffers,
302+
* since they are assumed to be fully allocated.
303+
*
304+
* \sa BufferLoad, BufferStore
305+
*/
306+
class BufferRealizeNode : public StmtNode {
307+
public:
308+
/*! \brief The buffer variable. */
309+
Buffer buffer;
310+
/*! \brief Bounds to be realized */
311+
Array<Range> bounds;
312+
/*! \brief Only realize if condition holds. */
313+
PrimExpr condition;
314+
/*! \brief The body of realization. */
315+
Stmt body;
316+
317+
void VisitAttrs(AttrVisitor* v) {
318+
v->Visit("buffer", &buffer);
319+
v->Visit("bounds", &bounds);
320+
v->Visit("condition", &condition);
321+
v->Visit("body", &body);
322+
}
323+
324+
bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
325+
return
326+
equal(buffer, other->buffer) &&
327+
equal(bounds, other->bounds) &&
328+
equal(condition, other->condition) &&
329+
equal(body, other->body);
330+
}
331+
332+
void SHashReduce(SHashReducer hash_reduce) const {
333+
hash_reduce(buffer);
334+
hash_reduce(bounds);
335+
hash_reduce(condition);
336+
hash_reduce(body);
337+
}
338+
339+
BufferRealizeNode() = default;
340+
BufferRealizeNode(Buffer buffer,
341+
Array<Range> bounds,
342+
PrimExpr condition,
343+
Stmt body)
344+
: buffer(buffer), bounds(bounds),
345+
condition(condition), body(body) {}
346+
347+
static constexpr const char* _type_key = "BufferRealize";
348+
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode);
349+
};
350+
351+
/*!
352+
* \brief Managed reference to BufferRealizeNode.
353+
* \sa BufferRealizeNode
354+
*/
355+
class BufferRealize : public Stmt {
356+
public:
357+
TVM_DLL explicit BufferRealize(Buffer buffer,
358+
Array<Range> bounds,
359+
PrimExpr condition,
360+
Stmt body);
361+
362+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode);
363+
};
364+
292365
/*!
293366
* \brief Store value into mult-dimensional array defined by func.
367+
*
368+
* \note Deprecated, move to BufferStore in the future.
294369
*/
295370
class ProvideNode : public StmtNode {
296371
public:
@@ -430,6 +505,8 @@ class FreeNode : public StmtNode {
430505
/*!
431506
* \brief Annotate the bounds where func need to be written and read in body.
432507
* We will need to allocate space for the corresponding regions.
508+
*
509+
* \note Deprecated, move to BufferRealize in the future.
433510
*/
434511
class RealizeNode : public StmtNode {
435512
public:
@@ -747,50 +824,50 @@ class ForNode : public StmtNode {
747824
};
748825

749826
/*!
750-
* \brief A prefetch hint of func.
827+
 * \brief A prefetch hint for a buffer.
751828
*/
752829
class PrefetchNode : public StmtNode {
753830
public:
754831
/*! \brief The function to be prefetched. */
755-
FunctionRef func;
756-
/*! \brief The output value index if func's value is a tuple. */
757-
int value_index;
758-
/*! \brief The data type of the array. */
759-
DataType dtype;
832+
Buffer buffer;
760833
/*! \brief Bounds to be prefetched. */
761-
Region bounds;
834+
Array<Range> bounds;
762835

763836
void VisitAttrs(AttrVisitor* v) {
764-
v->Visit("func", &func);
765-
v->Visit("value_index", &value_index);
766-
v->Visit("dtype", &dtype);
837+
v->Visit("buffer", &buffer);
767838
v->Visit("bounds", &bounds);
768839
}
769840

770841
bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
771842
return
772-
equal(func, other->func) &&
773-
equal(value_index, other->value_index) &&
774-
equal(dtype, other->dtype) &&
843+
equal(buffer, other->buffer) &&
775844
equal(bounds, other->bounds);
776845
}
777846

778847
void SHashReduce(SHashReducer hash_reduce) const {
779-
hash_reduce(func);
780-
hash_reduce(value_index);
781-
hash_reduce(dtype);
848+
hash_reduce(buffer);
782849
hash_reduce(bounds);
783850
}
784851

785-
TVM_DLL static Stmt make(FunctionRef func,
786-
int value_index,
787-
DataType dtype,
788-
Region bounds);
852+
PrefetchNode() = default;
853+
PrefetchNode(Buffer buffer, Array<Range> bounds)
854+
: buffer(buffer), bounds(bounds) {}
789855

790856
static constexpr const char* _type_key = "Prefetch";
791857
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
792858
};
793859

860+
/*!
861+
* \brief Managed reference to PrefetchNode.
862+
* \sa PrefetchNode
863+
*/
864+
class Prefetch : public Stmt {
865+
public:
866+
TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds);
867+
868+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
869+
};
870+
794871
/*!
795872
* \brief Auxiliary data structure used in IR Pass to indicate a tensor.
796873
*/

include/tvm/tir/stmt_functor.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
9292
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9393
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9494
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
95+
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9596
virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9697
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9798
virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -121,6 +122,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
121122
IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
122123
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
123124
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
125+
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
126+
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
124127
return vtable;
125128
}
126129
};
@@ -154,6 +157,7 @@ class TVM_DLL StmtVisitor :
154157
void VisitStmt_(const AllocateNode* op) override;
155158
void VisitStmt_(const StoreNode* op) override;
156159
void VisitStmt_(const BufferStoreNode* op) override;
160+
void VisitStmt_(const BufferRealizeNode* op) override;
157161
void VisitStmt_(const FreeNode* op) override;
158162
void VisitStmt_(const AssertStmtNode* op) override;
159163
void VisitStmt_(const ProvideNode* op) override;
@@ -248,6 +252,7 @@ class TVM_DLL StmtMutator :
248252
Stmt VisitStmt_(const AllocateNode* op) override;
249253
Stmt VisitStmt_(const StoreNode* op) override;
250254
Stmt VisitStmt_(const BufferStoreNode* op) override;
255+
Stmt VisitStmt_(const BufferRealizeNode* op) override;
251256
Stmt VisitStmt_(const FreeNode* op) override;
252257
Stmt VisitStmt_(const AssertStmtNode* op) override;
253258
Stmt VisitStmt_(const ProvideNode* op) override;

include/tvm/tir/transform.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,27 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
5858
const std::string& name,
5959
const tvm::Array<runtime::String>& required);
6060

61+
62+
/*!
63+
* \brief Inject prefetch instructions into stmt.
64+
*
65+
* \return The pass.
66+
*/
67+
TVM_DLL Pass InjectPrefetch();
68+
69+
// TODO(tvm-team): consolidate configs to the PassContext
70+
/*!
71+
* \brief Flatten the multi-dimensional read/write
72+
* to single dimensional Load/Store
73+
*
74+
* \param cache_line_size The size of CPU cache line.
75+
* \param create_bound_attribute Whether to create bound attributes.
76+
*
77+
* \return The Pass
78+
*/
79+
TVM_DLL Pass StorageFlatten(int cache_line_size,
80+
bool create_bound_attribute = false);
81+
6182
/*!
6283
* \brief Inject copy intrinsics with optional pad.
6384
*

python/tvm/autotvm/feature.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import tvm._ffi
3232

3333
from tvm import target as _target
34-
from tvm.tir import ir_pass
3534
from tvm.te import schedule
3635
from tvm.driver import build_module
3736

@@ -46,10 +45,12 @@ def ana_lower(sch, args,
4645
# Phase 0
4746
bounds = schedule.InferBound(sch)
4847
stmt = schedule.ScheduleOps(sch, bounds, True)
49-
stmt = ir_pass.StorageFlatten(stmt, binds, 64)
50-
stmt = ir_pass.CanonicalSimplify(stmt)
48+
func = schedule.SchedulePostProcToPrimFunc(args, stmt, None)
49+
mod = tvm.IRModule.from_expr(func._move())
50+
mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
51+
mod = tvm.tir.transform.Simplify()(mod._move())
5152
assert simple_mode
52-
return stmt
53+
return mod["main"].body
5354

5455
try:
5556
_get_buffer_curve_sample_flatten = tvm._ffi.get_global_func(

0 commit comments

Comments
 (0)