
Commit ae9b5ca

tqchen authored and trevor-m committed
[TIR][REFACTOR] RewriteForTensorCore -> te/schedule (apache#5379)
* [TIR][REFACTOR] RewriteForTensorCore -> te/schedule

  RewriteForTensorCore depends on the schedule information, which makes it
  differ from a typical pass (which should get all of its information from
  the input TIR). As a result, we refactor it into a SchedulePostProc step
  for now. We should revisit it later as we introduce more support for
  tensor core patterns in the TIR.

* Fix VTA to fit the new IR pattern
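For orientation, the change as seen from Python: the rewrite keeps its signature but moves namespaces (here stmt, sch and binds stand for the values produced by the usual ScheduleOps/get_binds sequence; a sketch, not part of the diff):

    # Before this commit: exposed as a TIR "pass", even though it needs
    # schedule context that a TIR pass should not depend on.
    stmt = tvm.tir.ir_pass.RewriteForTensorCore(stmt, sch, binds)

    # After this commit: an explicit te schedule post-processing step.
    stmt = tvm.te.schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)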
1 parent 5266afc commit ae9b5ca

7 files changed, +98 −82 lines changed

include/tvm/te/schedule_pass.h

Lines changed: 32 additions & 17 deletions
@@ -34,6 +34,23 @@
 namespace tvm {
 namespace te {

+/*!
+ * \brief To automatically inline the element-wise operations.
+ *
+ * \param sch The schedule to be inlined.
+ */
+void AutoInlineElemWise(Schedule sch);
+
+/*!
+ * \brief To automatically inline operations with injective writes
+ *   (i.e. writes without reduction or sequential loops). Note
+ *   that in this case, guarantees about contiguity, transpose, stride,
+ *   alignment and memory footprint in general do not hold.
+ *
+ * \param sch The schedule to be inlined.
+ */
+TVM_DLL void AutoInlineInjective(Schedule sch);
+
 /*!
  * \brief Infer the bound of all iteration variables related to the schedule.
  *
@@ -55,6 +72,21 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
  */
 Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);

+
+/*!
+ * \brief Try to modify the AST generated by ScheduleOps to support TensorCore.
+ *
+ * \param stmt The stmt to be transformed.
+ * \param schedule The original schedule.
+ * \param extern_buffer Map specifying the external
+ *        buffer assignment of inputs and outputs.
+ * \return Transformed stmt.
+ */
+Stmt SchedulePostProcRewriteForTensorCore(
+    Stmt stmt,
+    Schedule schedule,
+    Map<Tensor, Buffer> extern_buffer);
+
 /*!
  * \brief Postprocessing the Stmt generated by ScheduleOps to create
  * a PrimFunc that can then be used for further TIR optimizations.
@@ -75,23 +107,6 @@ PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
                                     Stmt body,
                                     Optional<Map<Tensor, Buffer>> bindings);

-/*!
- * \brief To automatically inline the element-wise operations.
- *
- * \param sch The schedule to be inlined.
- */
-void AutoInlineElemWise(Schedule sch);
-
-/*!
- * \brief To automatically inline operations with injective writes
-  *   (i.e. writes without reduction or sequential loops). Note
- *   that in this case, guarantees about contiguity, transpose, stride,
- *   alignment and memory footprint in general do not hold.
- *
- * \param sch The schedule to be inlined.
- */
-TVM_DLL void AutoInlineInjective(Schedule sch);
-
 }  // namespace te
 }  // namespace tvm
 #endif  // TVM_TE_SCHEDULE_PASS_H_
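Both inline helpers are unchanged by this commit, only relocated within the header; they are also exposed to Python under tvm.te.schedule. A minimal sketch of AutoInlineInjective on a toy two-stage pipeline (all workload names illustrative):

    import tvm
    from tvm import te

    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    C = te.compute((n,), lambda i: B[i] * 2.0, name="C")

    s = te.create_schedule(C.op)
    # Inline the injective stage B into its consumer C before lowering.
    tvm.te.schedule.AutoInlineInjective(s)
    print(tvm.lower(s, [A, C], simple_mode=True))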

include/tvm/tir/ir_pass.h

Lines changed: 0 additions & 13 deletions
@@ -164,19 +164,6 @@ Stmt Inline(Stmt stmt,
             Array<Var> args,
             PrimExpr body);

-/*!
- * \brief Try to modify the AST to support TensorCore.
- *
- * \param stmt The stmt to be transformed.
- * \param schedule The original schedule.
- * \param extern_buffer Map specifying the external
- *        buffer assignment of inputs and outputs.
- * \return Transformed stmt.
- */
-Stmt RewriteForTensorCore(Stmt stmt,
-                          te::Schedule schedule,
-                          Map<te::Tensor, Buffer> extern_buffer);
-
 /*!
  * \brief Verify if there is any argument bound to compact buffer.
  *

python/tvm/driver/build_module.py

Lines changed: 27 additions & 18 deletions
@@ -84,23 +84,43 @@ def get_binds(args, compact=False, binds=None):
     return binds, arg_list


-def form_body(sch):
+def form_irmodule(sch, args, name, binds):
     """According to the given schedule, form a function.

     Parameters
     ----------
     sch : tvm.te.schedule.Schedule
-        The given scheduler to form the raw body
+        The given scheduler to form the raw body
+
+    args : list of Buffer or Tensor or Var
+        The argument lists to the function.
+
+    name : str
+        The name of result function.
+
+    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
+        The binds information

     Returns
     -------
     The body formed according to the given schedule
     """
     # normalize schedule first
+    cfg = BuildConfig.current()
     sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
-    return stmt
+
+    compact = ir_pass.VerifyCompactBuffer(stmt)
+    binds, arg_list = get_binds(args, compact, binds)
+
+    stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
+    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
+
+    func = func.with_attr("global_symbol", name)
+    if cfg.restricted_func:
+        func = func.with_attr("tir.noalias", True)
+    return tvm.IRModule({name: func})


 def _wrap_as_prim_func_pass(flist, name):
@@ -166,24 +186,13 @@ def lower(sch,

     # Phase 0
     if isinstance(sch, schedule.Schedule):
-        stmt = form_body(sch)
-
-        for f in lower_phase0:
-            stmt = f(stmt)
-
-        compact = ir_pass.VerifyCompactBuffer(stmt)
-        binds, arg_list = get_binds(args, compact, binds)
-        stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
-
-        # Start the new style pass manager.
-        func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
-        func = func.with_attr("global_symbol", name)
-        if cfg.restricted_func:
-            func = func.with_attr("tir.noalias", True)
-        mod = tvm.IRModule({name: func})
+        mod = form_irmodule(sch, args, name, binds)
+    else:
+        mod = sch

     # Phase 1
     pass_list = [
+        _wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
         tvm.tir.transform.InjectPrefetch(),
         tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
         tvm.tir.transform.NarrowDataType(32),
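For callers of tvm.lower nothing changes; the TensorCore rewrite now simply happens inside form_irmodule. A minimal end-to-end sketch (toy workload, names illustrative):

    import tvm
    from tvm import te

    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # lower() delegates to form_irmodule(), which schedules, runs the
    # TensorCore post-processing, and wraps the result in an IRModule.
    mod = tvm.lower(s, [A, B], name="main")
    print(mod)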

python/tvm/te/hybrid/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -30,7 +30,7 @@
 # 2. Support multi-level HalideIR
 import inspect
 import tvm._ffi
-from tvm.driver.build_module import form_body
+import tvm.te.schedule
 from tvm._ffi.base import decorate

 from .module import HybridModule
@@ -87,8 +87,10 @@ def build(sch, inputs, outputs, name="hybrid_func"):
     The built result is wrapped in a HybridModule.
     The usage of HybridModule is roughly the same as normal TVM-built modules.
     """
+    sch = sch.normalize()
+    bounds = tvm.te.schedule.InferBound(sch)
+    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)

-    stmt = form_body(sch)
     src = _Dump(stmt, inputs, outputs, name)

     return HybridModule(src, name)
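With form_body removed from tvm.driver.build_module, any downstream code that imported it can reproduce the old schedule-to-Stmt behavior with a small local helper (a sketch mirroring the three lines inlined above):

    import tvm.te.schedule

    def form_body(sch):
        """Equivalent of the removed build_module.form_body helper."""
        sch = sch.normalize()
        bounds = tvm.te.schedule.InferBound(sch)
        return tvm.te.schedule.ScheduleOps(sch, bounds)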

src/tir/pass/tensor_core.cc renamed to src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc

Lines changed: 25 additions & 14 deletions
@@ -18,9 +18,12 @@
  */

 /*!
- * \file tensor_core.cc
+ * \file schedule_postproc_rewrite_for_tensor_core.cc
+ *
+ * \brief Rewrite the Stmt generated by ScheduleOps
+ *  to accommodate TensorCore.
  */
-// IR Passes for TensorCore CodeGen
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/te/operation.h>
@@ -32,12 +35,11 @@
 #include <tvm/target/target.h>
 #include <tvm/runtime/device_api.h>
 #include <unordered_map>
-#include "ir_util.h"
 #include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"

 namespace tvm {
-namespace tir {
+namespace te {

 using namespace te;
 using runtime::StorageRank;
@@ -86,10 +88,10 @@ class MMAMatcher: public StmtVisitor {
   }

   void VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::pragma_tensor_core) {
+    if (op->attr_key == tir::attr::pragma_tensor_core) {
       tensor_core_on_ = true;
       StmtVisitor::VisitStmt_(op);
-    } else if (op->attr_key == attr::realize_scope) {
+    } else if (op->attr_key == tir::attr::realize_scope) {
       storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       this->VisitStmt(op->body);
     } else {
@@ -414,18 +416,18 @@ class BufferAnalyser : public StmtExprVisitor {
   }

   void VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::thread_extent) {
+    if (op->attr_key == tir::attr::thread_extent) {
       if (const IntImmNode* value = op->value.as<IntImmNode>()) {
         thread_extent_.insert(
             std::make_pair(
                 op->node.as<IterVarNode>()->var->name_hint,
                 value->value));
       }
       StmtExprVisitor::VisitStmt_(op);
-    } else if (op->attr_key == attr::realize_scope) {
+    } else if (op->attr_key == tir::attr::realize_scope) {
       storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       this->VisitStmt(op->body);
-    } else if (op->attr_key == attr::buffer_dim_align) {
+    } else if (op->attr_key == tir::attr::buffer_dim_align) {
       te::Tensor tensor = Downcast<te::Tensor>(op->node);
       const CallNode* tuple = op->value.as<CallNode>();
       CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
@@ -850,7 +852,7 @@ class TensorCoreIRMutator : public StmtExprMutator {

   Stmt VisitStmt_(const AttrStmtNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    if (op->attr_key == attr::realize_scope) {
+    if (op->attr_key == tir::attr::realize_scope) {
       auto node = op->node.as<te::OperationNode>();
       if (node != nullptr) {
         if (!frag_reg_.count(node->name)) {
@@ -1186,9 +1188,10 @@ class TensorCoreIRMutator : public StmtExprMutator {
   int warp_threads_y_{-1};
 };

-Stmt RewriteForTensorCore(Stmt stmt,
-                          Schedule schedule,
-                          Map<Tensor, Buffer> extern_buffer) {
+Stmt SchedulePostProcRewriteForTensorCore(
+    Stmt stmt,
+    Schedule schedule,
+    Map<Tensor, Buffer> extern_buffer) {
   // Check if current lower target is CUDA
   auto target = tvm::Target::Current(true);
   if (target.defined() && target->target_name != "cuda") {
@@ -1223,5 +1226,13 @@ Stmt RewriteForTensorCore(Stmt stmt,
   return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt));
 }

-}  // namespace tir
+TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore")
+.set_body_typed([](Stmt stmt,
+                   Schedule schedule,
+                   Map<te::Tensor, Buffer> extern_buffer) {
+  return SchedulePostProcRewriteForTensorCore(
+      stmt, schedule, extern_buffer);
+});
+
+}  // namespace te
 }  // namespace tvm
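Since the pass is registered under the "schedule." prefix, it is reachable from Python as tvm.te.schedule.SchedulePostProcRewriteForTensorCore. A sketch of the manual post-processing sequence that form_irmodule performs, on an illustrative matmul (the rewrite is a no-op unless the current target is CUDA and the schedule carries the tensor-core pragma):

    import tvm
    from tvm import te
    from tvm.te import schedule
    from tvm.driver.build_module import get_binds

    n = 16
    A = te.placeholder((n, n), name="A")
    B = te.placeholder((n, n), name="B")
    k = te.reduce_axis((0, n), name="k")
    C = te.compute((n, n),
                   lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
                   name="C")

    sch = te.create_schedule(C.op).normalize()
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds)
    binds, arg_list = get_binds([A, B, C])
    stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)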

src/tir/pass/ffi_api.cc

Lines changed: 0 additions & 8 deletions
@@ -75,14 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute")
   }
 });

-TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
-.set_body_typed
-([](const Stmt& stmt,
-    const te::Schedule& schedule,
-    const Map<te::Tensor, Buffer>& extern_buffer) {
-  return RewriteForTensorCore(stmt, schedule, extern_buffer);
-});
-
 TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
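The old ir_pass.RewriteForTensorCore packed function is gone; only the schedule.-prefixed registration from this commit remains. A quick sanity check, assuming a TVM build that includes this commit:

    import tvm

    assert tvm.get_global_func("ir_pass.RewriteForTensorCore",
                               allow_missing=True) is None
    assert tvm.get_global_func(
        "schedule.SchedulePostProcRewriteForTensorCore") is not None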

vta/python/vta/ir_pass.py

Lines changed: 10 additions & 10 deletions
@@ -638,7 +638,7 @@ def inject_conv2d_transpose_skip(stmt_in):
     selects = []

     def _find_basics(op):
-        if isinstance(op, tvm.tir.Call):
+        if isinstance(op, tvm.tir.BufferLoad):
             calls.append(op)
         elif isinstance(op, tvm.tir.Select):
             selects.append(op)
@@ -664,18 +664,18 @@ def _do_fold(op):
             body = op.body.body
             while isinstance(body, tvm.tir.IfThenElse):
                 body = body.then_case
-            args = body.args
-            res_tensor = body.func.output(0)
+            args = body.indices
+            res_buffer = body.buffer
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
             inner = tvm.tir.AttrStmt(
-                [dout, res_tensor], 'buffer_bind_scope',
+                [dout, res_buffer], 'buffer_bind_scope',
                 tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
             return inner
         else:
             conv_call, data_call, kernel_call = calls[-3:]
-            pad_data_tensor = data_call.func.output(0)
-            kernel_tensor = kernel_call.func.output(0)
-            res_tensor = conv_call.func.output(0)
+            pad_data_tensor = data_call.buffer
+            kernel_tensor = kernel_call.buffer
+            res_tensor = conv_call.buffer

             if selects:
                 condition = selects[0].condition
@@ -696,19 +696,19 @@ def _do_fold(op):
                                          0, 0, 0))
             inner = irb.get()

-            args = conv_call.args
+            args = conv_call.indices
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
                    1, 0, 1, 0, env.BLOCK_OUT)
             inner = tvm.tir.AttrStmt(
                 [dout, res_tensor], 'buffer_bind_scope',
                 tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
-            args = kernel_call.args
+            args = kernel_call.indices
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
                    1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
             inner = tvm.tir.AttrStmt(
                 [dwgt, kernel_tensor], 'buffer_bind_scope',
                 tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
-            args = data_call.args
+            args = data_call.indices
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
                    1, 0, 1, 0, env.BLOCK_IN)
             inner = tvm.tir.AttrStmt(
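The VTA fixes track the underlying IR change: tensor reads that used to surface as halide-style tvm.tir.Call nodes, with the tensor recovered via call.func.output(0) and the indices in call.args, now surface as tvm.tir.BufferLoad nodes that carry .buffer and .indices directly. A sketch of the new matching idiom (the helper name is illustrative):

    import tvm

    def match_tensor_read(node):
        """Return (buffer, indices) for a tensor read, else None.

        Old IR: isinstance(node, tvm.tir.Call), tensor via
        node.func.output(0), indices in node.args.
        New IR: BufferLoad carries both directly.
        """
        if isinstance(node, tvm.tir.BufferLoad):
            return node.buffer, node.indices
        return None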
