Skip to content

Commit d97274c

Browse files
MasterJH5574junrushaozxybazhspectrometerHBHHzfengsy
authored
[MetaSchedule] Post Processor: Rewrite Reduction Block (#10013)
Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Xiyou Zhou <xiyou@octoml.ai> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org> Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Xiyou Zhou <xiyou@octoml.ai> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org>
1 parent 2a91f0d commit d97274c

File tree

6 files changed

+454
-0
lines changed

6 files changed

+454
-0
lines changed

python/tvm/meta_schedule/postproc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@
1616
# under the License.
1717
"""The tvm.meta_schedule.postproc package."""
1818
from .postproc import Postproc, PyPostproc
19+
from .rewrite_reduction_block import RewriteReductionBlock
1920
from .verify_gpu_code import VerifyGPUCode
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A postprocessor that rewrites reduction block by moving the init block out."""
18+
19+
from tvm._ffi.registry import register_object
20+
from .. import _ffi_api
21+
from .postproc import Postproc
22+
23+
24+
@register_object("meta_schedule.RewriteReductionBlock")
25+
class RewriteReductionBlock(Postproc):
26+
"""A postprocessor that rewrites reduction block by moving the init block out."""
27+
28+
def __init__(self) -> None:
29+
self.__init_handle_by_constructor__(
30+
_ffi_api.PostprocRewriteReductionBlock, # type: ignore # pylint: disable=no-member
31+
)
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace tir {
23+
24+
/*! \brief The visitor that finds all the reduction block to be decomposed */
25+
struct ReductionBlockFinder : private StmtVisitor {
26+
public:
27+
/*! \brief Find all the reduction blocks that should be decomposed */
28+
static std::vector<std::pair<StmtSRef, String>> Find(const ScheduleState& self) {
29+
std::vector<std::pair<StmtSRef, String>> results;
30+
for (const auto& kv : self->mod->functions) {
31+
GlobalVar g_var = kv.first;
32+
BaseFunc base_func = kv.second;
33+
if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
34+
ReductionBlockFinder finder;
35+
finder(prim_func->body);
36+
for (const BlockNode* block : finder.results_) {
37+
results.emplace_back(self->stmt2ref.at(block), g_var->name_hint);
38+
}
39+
}
40+
}
41+
return results;
42+
}
43+
44+
private:
45+
void VisitStmt_(const ForNode* loop) final {
46+
runtime::ThreadScope thread_scope = GetThreadScope(loop);
47+
if (IsThreadIdx(thread_scope) || IsBlockIdx(thread_scope)) {
48+
thread_bound_loop_vars_.insert(loop->loop_var.get());
49+
}
50+
StmtVisitor::VisitStmt_(loop);
51+
}
52+
53+
void VisitStmt_(const BlockRealizeNode* realize) final {
54+
if (realize->block->init.defined() && AllReductionIterVarAreUnbound(realize)) {
55+
results_.push_back(realize->block.get());
56+
}
57+
StmtVisitor::VisitStmt_(realize);
58+
}
59+
60+
bool AllReductionIterVarAreUnbound(const BlockRealizeNode* realize) const {
61+
if (thread_bound_loop_vars_.empty()) {
62+
return true;
63+
}
64+
auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); };
65+
const BlockNode* block = realize->block.get();
66+
ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
67+
int n = block->iter_vars.size();
68+
for (int i = 0; i < n; ++i) {
69+
IterVar iter_var = block->iter_vars[i];
70+
PrimExpr binding = realize->iter_values[i];
71+
if (iter_var->iter_type == tir::kCommReduce) {
72+
if (UsesVar(binding, f_find)) {
73+
return false;
74+
}
75+
}
76+
}
77+
return true;
78+
}
79+
80+
/*! \brief The results of the collection */
81+
std::vector<const BlockNode*> results_;
82+
/*! \brief Loop variables that are bound to threads */
83+
std::unordered_set<const VarNode*> thread_bound_loop_vars_;
84+
};
85+
86+
/*!
87+
* \brief Find the innermost loop that the `init` of the input block could be decomposed to
88+
* \param block_sref The StmtSRef of the block to be decomposed
89+
* \return The index of the innermost loop where the `init` of the input block could be decomposed,
90+
* or -1 if the `init` does not need to be decomposed.
91+
*/
92+
int FindDecomposePoint(const StmtSRef& block_sref) {
93+
Array<StmtSRef> loop_srefs = GetLoops(block_sref);
94+
int n = loop_srefs.size();
95+
for (int i = 0; i < n; ++i) {
96+
if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) {
97+
return i;
98+
}
99+
}
100+
return -1;
101+
}
102+
103+
} // namespace tir
104+
} // namespace tvm
105+
106+
namespace tvm {
107+
namespace meta_schedule {
108+
109+
/*! \brief Rewrite reduction block by moving the init block out */
110+
class RewriteReductionBlockNode : public PostprocNode {
111+
public:
112+
// Inherited from PostprocNode
113+
void InitializeWithTuneContext(const TuneContext& context) final {}
114+
// Inherited from PostprocNode
115+
bool Apply(const tir::Schedule& sch) final;
116+
117+
void VisitAttrs(tvm::AttrVisitor* v) {}
118+
119+
static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock";
120+
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode);
121+
};
122+
123+
bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) {
124+
for (;;) {
125+
std::vector<std::pair<tir::StmtSRef, String>> results =
126+
tir::ReductionBlockFinder::Find(sch->state());
127+
int rewritten = 0;
128+
for (const auto& kv : results) {
129+
const tir::StmtSRef& block_sref = kv.first;
130+
const String& global_var_name = kv.second;
131+
int decompose_point = tir::FindDecomposePoint(block_sref);
132+
if (decompose_point == -1) {
133+
continue;
134+
}
135+
tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
136+
Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
137+
tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]);
138+
++rewritten;
139+
}
140+
if (rewritten == 0) {
141+
break;
142+
}
143+
}
144+
return true;
145+
}
146+
147+
Postproc Postproc::RewriteReductionBlock() {
148+
ObjectPtr<RewriteReductionBlockNode> n = make_object<RewriteReductionBlockNode>();
149+
return Postproc(n);
150+
}
151+
152+
TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode);
153+
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock")
154+
.set_body_typed(Postproc::RewriteReductionBlock);
155+
156+
} // namespace meta_schedule
157+
} // namespace tvm

src/meta_schedule/utils.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,19 @@ inline std::string Concat(const Array<String>& strs, const std::string& delim) {
246246
return os.str();
247247
}
248248

249+
/*!
250+
* \brief Get the BlockRV from a block StmtSRef
251+
* \param sch The schedule
252+
* \param block_sref The block StmtSRef
253+
* \param global_var_name The global variable name
254+
* \return The BlockRV
255+
*/
256+
inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref,
257+
const String& global_var_name) {
258+
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
259+
return sch->GetBlock(block->name_hint, global_var_name);
260+
}
261+
249262
/*!
250263
* \brief A helper data structure that replays a trace and collects failure counts
251264
* for each postprocessor

src/tir/schedule/utils.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,36 @@ inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_va
192192
Var(std::move(name), loop->loop_var.dtype()), iter_var_type);
193193
}
194194

195+
/*!
196+
* \brief Get the thread scope bound to the specific loop
197+
* \param loop The loop to be inspected
198+
* \return The thread scope bound to the loop
199+
*/
200+
inline runtime::ThreadScope GetThreadScope(const ForNode* loop) {
201+
if (loop->kind == ForKind::kThreadBinding) {
202+
return runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag);
203+
}
204+
return runtime::ThreadScope{-1, -1};
205+
}
206+
207+
/*!
208+
* \brief Check if the thread scope is blockIdx
209+
* \param thread_scope The thread scope to be checked
210+
* \return True if the thread scope is blockIdx
211+
*/
212+
inline bool IsBlockIdx(const runtime::ThreadScope& thread_scope) {
213+
return thread_scope.rank == 0; // The rank of blockIdx is 0
214+
}
215+
216+
/*!
217+
* \brief Check if the thread scope is threadIdx
218+
* \param thread_scope The thread scope to be checked
219+
* \return True if the thread scope is threadIdx
220+
*/
221+
inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) {
222+
return thread_scope.rank == 1 && thread_scope.dim_index >= 0;
223+
}
224+
195225
/******** Integer set ********/
196226

197227
/*!

0 commit comments

Comments
 (0)