Commit 6ceb5e0

refactor
1 parent f09e91e commit 6ceb5e0

File tree

3 files changed: +107 -140 lines changed

src/op/atomic_add.cc
Lines changed: 4 additions & 8 deletions

@@ -13,7 +13,6 @@
 #include "../target/utils.h"
 #include "../transform/atomicadd_vectorize.h"
 #include "../transform/common/loop_fusion_utils.h"
-#include "../transform/loop_partition.h"
 #include "builtin.h"
 
 namespace tvm {
@@ -362,18 +361,15 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
         level);
   }
   auto loop_layout = par_op->GetLoopLayout();
-  Var thread_var = T.thread_var;
-  Range thread_bounds = T.thread_bounds;
-  auto thread_loop =
-      PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
-  auto vectorized_thread_loop = VectorizeAtomicAdd(
-      thread_loop, thread_var, thread_bounds, GetArchInt(target));
+  auto vectorized_thread_loop =
+      VectorizeAtomicAdd(par_op->GetRoot(), T.thread_var, T.thread_bounds,
+                         GetArchInt(target), analyzer, loop_layout);
 
   if (par_op->GetPredicate(T.thread_var).defined()) {
     return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                       vectorized_thread_loop);
   }
-
+  LOG(INFO) << vectorized_thread_loop;
   return vectorized_thread_loop;
 }
 
src/transform/atomicadd_vectorize.cc
Lines changed: 101 additions & 131 deletions

@@ -5,6 +5,7 @@
 
 #include "../layout/layout.h"
 #include "../layout/utils.h"
+#include "../transform/loop_partition.h"
 #include "arith/int_operator.h"
 #include "arith/ir_visitor_with_analyzer.h"
 #include "common/loop_vectorization_utils.h"
@@ -29,6 +30,30 @@ struct AtomicAddVectorizePlanResult {
   PrimExpr condition;
 };
 
+class BufferIndiceSimplify : public StmtExprMutator {
+public:
+  BufferIndiceSimplify(arith::Analyzer *analyzer) : analyzer_(analyzer) {}
+
+private:
+  PrimExpr VisitExpr_(const BufferLoadNode *node) final {
+    auto visited = StmtExprMutator::VisitExpr_(node);
+    auto n = Downcast<BufferLoad>(visited);
+    auto nptr = n.CopyOnWrite();
+    nptr->indices = nptr->indices.Map(
+        [&](const auto &e) { return analyzer_->Simplify(e); });
+    return n;
+  }
+  Stmt VisitStmt_(const BufferStoreNode *node) final {
+    auto visited = StmtExprMutator::VisitStmt_(node);
+    auto n = Downcast<BufferStore>(visited);
+    auto nptr = n.CopyOnWrite();
+    nptr->indices = nptr->indices.Map(
+        [&](const auto &e) { return analyzer_->Simplify(e); });
+    return n;
+  }
+  arith::Analyzer *analyzer_;
+};
+
 class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer {
 public:
   AtomicAddVectorizePlanner() = default;
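
Note: the new BufferIndiceSimplify mutator only re-runs the analyzer's Simplify over every BufferLoad/BufferStore index; it relies on the caller having bound the relevant loop variables on the analyzer first. A minimal usage sketch under that assumption (the helper name, loop variable, and extent below are illustrative, not part of this commit):

#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt_functor.h>

using namespace tvm;
using namespace tvm::tir;

// Illustrative only: simplify all buffer indices in `body` after telling the
// analyzer the range of a loop variable `i`.
Stmt SimplifyBufferIndices(Stmt body, const Var &i, const PrimExpr &extent) {
  arith::Analyzer analyzer;
  analyzer.Bind(i, Range(0, extent));            // bound known to the analyzer
  return BufferIndiceSimplify(&analyzer)(body);  // rewrites load/store indices
}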
@@ -137,69 +162,75 @@ class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer {
 class AtomicAddVectorizeRewriter : public StmtExprMutator {
 public:
   AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan,
-                             Var thread_var, PrimExpr by_var, PrimExpr bx_var,
-                             const Range &thread_bounds, int stride_y,
-                             int stride_x)
+                             Var thread_var, const Range &thread_bounds)
       : vector_size_(plan.vector_size), condition_(plan.condition),
-        dynamic_(plan.dynamic), tx_var_(std::move(thread_var)),
-        by_var_(std::move(by_var)), bx_var_(std::move(bx_var)),
-        stride_y_(stride_y), stride_x_(stride_x) {
+        dynamic_(plan.dynamic), tx_var_(std::move(thread_var)) {
     const int64_t *tx_ext = as_const_int(thread_bounds->extent);
     ICHECK(tx_ext)
         << "thread_bounds->extent must be a constant for vectorization.";
     extent_tx_ = static_cast<int>(*tx_ext);
   }
 
-private:
-  /**
-   * @brief Visits a For node and rewrites the innermost loop for atomic-add
-   * vectorization.
-   *
-   * If the visited For node is the recorded innermost loop, this method
-   * validates that the loop extent is a constant, divisible by the planned
-   * vector size, and has a zero minimum. When vectorization is enabled
-   * (dynamic_ == false) it:
-   * - locates the thread index variable named "tx" inside the loop body,
-   * - creates a new outer loop variable named "<old_loop_var>_outer",
-   * - substitutes occurrences of `tx` with `tx * vector_size_` and the old
-   * loop var with `outer_var * vector_size_` so each outer iteration maps to a
-   * contiguous vector-sized chunk,
-   * - returns a new For with extent divided by vector_size_ and the
-   * transformed body.
-   *
-   * If dynamic_ is true, the method returns the (possibly mutated) inner For
-   * unchanged.
-   *
-   * Side effects:
-   * - updates inner_for_ to point to the current For node during visitation.
-   * - performs runtime checks (ICHECK) to enforce: constant extent, extent %
-   * vector_size_ == 0, and zero loop minimum; violations terminate execution.
-   *
-   * @return The original or transformed For statement as a Stmt.
-   */
-  Stmt VisitStmt_(const ForNode *node) final {
-    inner_for_ = node;
-    iter_var_ = Var(node->loop_var->name_hint + "_outer");
-    auto ret = StmtExprMutator::VisitStmt_(node);
-    if (inner_for_ == node) { // rewrite the innermost loop
-      For fnode = ret.as<For>().value();
-      auto extent_ptr = as_const_int(fnode->extent);
-      ICHECK(extent_ptr) << fnode->extent;
-      int extent = *extent_ptr;
-      ICHECK(extent % vector_size_ == 0)
-          << "extent: " << extent << " vector_size_: " << vector_size_;
-      ICHECK(is_zero(fnode->min));
-      if (!dynamic_) {
-        Map<Var, PrimExpr> vmap;
-        vmap.Set(fnode->loop_var, iter_var_);
-        Stmt body = Substitute(fnode->body, vmap);
-        return For(iter_var_, 0, extent / vector_size_, fnode->kind, body,
-                   fnode->thread_binding, fnode->annotations, fnode->span);
-      }
+  For run(For for_node, const Fragment &loop_layout,
+          arith::Analyzer *analyzer) {
+    int old_loop_depth = loop_layout->InputDim();
+    int new_loop_depth = loop_layout->OutputDim();
+
+    Array<Var> vars;
+    for (int i = 0; i < new_loop_depth; i++) {
+      Var var = Var(std::string{char('i' + i)});
+      vars.push_back(var);
+    }
+    vars.push_back(tx_var_);
+    Map<Var, PrimExpr> vmap;
+    Stmt body = std::move(for_node);
+    auto inv_loop = loop_layout->Inverse();
+    auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));
+    // The innermost loop var is scaled by the vector width.
+
+    const ForNode *loop = body.as<ForNode>();
+    ICHECK(loop != nullptr);
+    vmap.Set(loop->loop_var, indices[0] * vector_size_);
+    body = loop->body;
+    for (int i = 1; i < old_loop_depth; i++) {
+      const ForNode *loop = body.as<ForNode>();
+      ICHECK(loop != nullptr);
+      vmap.Set(loop->loop_var, indices[i]);
+      body = loop->body;
     }
-    return ret;
+    body = Substitute(body, vmap);
+
+    // The innermost loop extent shrinks by the vector width.
+
+    body = For(vars[new_loop_depth - 1],
+               make_zero(vars[new_loop_depth - 1]->dtype),
+               div(inv_loop->InputShape()[new_loop_depth - 1], vector_size_),
+               ForKind::kSerial, body);
+    analyzer->Bind(vars[new_loop_depth - 1],
+                   Range(0, div(inv_loop->InputShape()[new_loop_depth - 1],
+                                vector_size_)));
+
+    for (int i = new_loop_depth - 2; i >= 0; i--) {
+      body = For(vars[i], make_zero(vars[i]->dtype),
+                 div(inv_loop->InputShape()[i], vector_size_), ForKind::kSerial,
+                 body);
+      analyzer->Bind(vars[i], Range(0, inv_loop->InputShape()[i]));
+    }
+
+    body = BufferIndiceSimplify(analyzer)(body);
+
+    auto node = LoopPragmaUnroll(Downcast<For>(body));
+    if (loop_layout->ThreadRange().defined()) {
+      auto range = loop_layout->ThreadRange();
+      auto thread_var_with_offset = tx_var_ - range->min;
+      node.CopyOnWrite()->body =
+          Substitute(node->body, {{tx_var_, thread_var_with_offset}});
+    }
+    auto new_stmt = this->VisitStmt(node);
+    return Downcast<For>(new_stmt);
   }
 
+private:
   PrimExpr VisitExpr_(const CallNode *node) final {
     if (dynamic_) {
       return StmtExprMutator::VisitExpr_(node);
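
Note: the rewriter no longer patches the original nest in place; run() rebuilds the loop nest from the inverse of loop_layout, scales the innermost index by the vector width, and only then visits the result to rewrite the AtomicAdd calls. A condensed sketch of how the pieces are driven, mirroring the call sites in VectorizeAtomicAdd further down (the wrapper itself is illustrative, not part of the commit):

// Illustrative wrapper: plan the vector width on the thread-partitioned loop,
// then let the rewriter rebuild and vectorize the original loop nest.
static For PlanAndRewrite(const For &partitioned, const For &original,
                          const Var &thread_var, const Range &thread_bounds,
                          const Fragment &loop_layout,
                          arith::Analyzer *analyzer, int vectorize_hint) {
  AtomicAddVectorizePlanner planner;
  AtomicAddVectorizePlanResult res =
      planner.Plan(partitioned, thread_var, thread_bounds, vectorize_hint);
  if (res.vector_size == 1)
    return original;  // nothing to vectorize
  AtomicAddVectorizeRewriter rewriter(res, thread_var, thread_bounds);
  return rewriter.run(original, loop_layout, analyzer);
}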
@@ -208,57 +239,18 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
     if (node->op == builtin::call_extern() && node->args.size() >= 2) {
       if (const auto *func_name = node->args[0].as<StringImmNode>()) {
         if (func_name->value == "AtomicAdd") {
-          // Matrix[by * stride_y + i / (stride_x / (tx_txtent *
-          // vector_size_)) + tx_var_ / (stride_x / vector_size_),
-          // bx * stride_x + (i % (stride_x / (tx_extent *
-          // vector_size_)) * (tx_extent * vector_size_) + (tx_var_ %
-          // (stride / vector_size_)) * vector_size_]
-          const BufferLoadNode *old_dst_node =
+          const BufferLoadNode *temp_dst_node =
               node->args[1].as<BufferLoadNode>();
-          const BufferLoadNode *old_value_node =
+          const BufferLoadNode *temp_value_node =
               node->args[2].as<BufferLoadNode>();
-          if (!old_dst_node || !old_value_node) {
+          if (!temp_dst_node || !temp_value_node) {
             return StmtExprMutator::VisitExpr_(node);
           }
-          Array<PrimExpr> dst_indices, value_indices;
-          if ((extent_tx_ * vector_size_) > stride_x_) {
-            dst_indices.push_back(
-                by_var_ * stride_y_ +
-                iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
-                truncdiv(tx_var_, stride_x_ / vector_size_));
-            dst_indices.push_back(
-                bx_var_ * stride_x_ +
-                truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
-            value_indices.push_back(
-                iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
-                truncdiv(tx_var_ * vector_size_, stride_x_));
-            value_indices.push_back(
-                truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
-          } else {
-            dst_indices.push_back(
-                by_var_ * stride_y_ +
-                truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
-                truncdiv(tx_var_, stride_x_ / vector_size_));
-            dst_indices.push_back(
-                bx_var_ * stride_x_ +
-                truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
-                    (extent_tx_ * vector_size_) +
-                truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
-            value_indices.push_back(
-                truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
-                truncdiv(tx_var_, stride_x_ / vector_size_));
-            value_indices.push_back(
-                truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
-                    (extent_tx_ * vector_size_) +
-                truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
-          }
+          const BufferLoad dst_node =
+              Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
+          const BufferLoad value_node =
+              Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
 
-          BufferLoad dst_node =
-              BufferLoad(old_dst_node->buffer, dst_indices,
-                         old_dst_node->predicate, old_dst_node->span);
-          BufferLoad value_node =
-              BufferLoad(old_value_node->buffer, value_indices,
-                         old_value_node->predicate, old_value_node->span);
           Call address_of_dst =
               Call(DataType::Handle(), builtin::address_of(), {dst_node});
           Call address_of_value =
@@ -287,10 +279,7 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
   const int vector_size_;
   const PrimExpr condition_;
   const bool dynamic_;
-  const PrimExpr by_var_, bx_var_;
-  int stride_y_, stride_x_;
   const Var tx_var_;
-  Var iter_var_;
   int extent_tx_;
 };
 
@@ -317,11 +306,10 @@ static int GetVectorizeSizeMax(int compute_capability, DataType dtype) {
 }
 
 For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
-                       const Range &thread_bounds, int compute_capability) {
+                       const Range &thread_bounds, int compute_capability,
+                       arith::Analyzer *analyzer, const Fragment &loop_layout) {
 
   int vectorize_size_max = 1;
-  int stride_x = -1, stride_y = -1;
-  PrimExpr bx_var, by_var;
 
   PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
     if (const auto *call = obj.as<CallNode>()) {
@@ -333,40 +321,22 @@ For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
         }
       }
     }
-    if (const MulNode *mul = obj.as<MulNode>()) {
-      const VarNode *var = nullptr;
-      const IntImmNode *imm = nullptr;
-      PrimExpr var_expr;
-      if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
-        var_expr = mul->a;
-      } else if ((var = mul->b.as<VarNode>()) &&
-                 (imm = mul->a.as<IntImmNode>())) {
-        var_expr = mul->b;
-      }
-      if (var && imm) {
-        if (var->name_hint == "bx") {
-          stride_x = imm->value;
-          bx_var = var_expr;
-        } else if (var->name_hint == "by") {
-          stride_y = imm->value;
-          by_var = var_expr;
-        }
-      }
-    }
   });
+
   if (vectorize_size_max != 1) {
     int vectorize_hint = vectorize_size_max;
     AtomicAddVectorizePlanResult res = {1, false, 0};
     AtomicAddVectorizePlanner planner;
-    res = planner.Plan(for_node, thread_var, thread_bounds, vectorize_hint);
+    For simplified_for_node =
+        PartitionLoop(for_node, thread_var, analyzer, loop_layout);
+    res = planner.Plan(simplified_for_node, thread_var, thread_bounds,
+                       vectorize_hint);
     vectorize_hint = res.vector_size;
 
-    if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
-        !bx_var.defined() || !by_var.defined())
+    if (vectorize_hint == 1)
      return for_node;
-    auto rewriter = AtomicAddVectorizeRewriter(
-        res, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x);
-    return Downcast<For>(rewriter(for_node));
+    auto rewriter = AtomicAddVectorizeRewriter(res, thread_var, thread_bounds);
+    return rewriter.run(for_node, loop_layout, analyzer);
   } else {
     return for_node;
   }

src/transform/atomicadd_vectorize.h
Lines changed: 2 additions & 1 deletion

@@ -15,7 +15,8 @@ namespace tl {
 using namespace tir;
 
 For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
-                       const Range &thread_bounds, int compute_capability);
+                       const Range &thread_bounds, int compute_capability,
+                       arith::Analyzer *analyzer, const Fragment &loop_layout);
 
 } // namespace tl
 } // namespace tvm
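
Note: with the two extra parameters, a caller now hands in its arithmetic analyzer and the loop fragment instead of partitioning the loop itself, as the updated src/op/atomic_add.cc above does. A condensed sketch of the call shape (every name other than VectorizeAtomicAdd is a placeholder for the caller's lowering context):

// Placeholder caller: `root_loop`, `thread_var`, `thread_bounds`, `arch`,
// `analyzer`, and `loop_layout` come from the surrounding lowering pass.
For LowerAtomicAddLoop(const For &root_loop, const Var &thread_var,
                       const Range &thread_bounds, int arch,
                       arith::Analyzer *analyzer,
                       const Fragment &loop_layout) {
  // Thread partitioning now happens inside VectorizeAtomicAdd itself.
  return VectorizeAtomicAdd(root_loop, thread_var, thread_bounds, arch,
                            analyzer, loop_layout);
}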

0 commit comments
