Skip to content

Commit 12e437c

Browse files
committed
fix
1 parent c317f26 commit 12e437c

File tree

2 files changed

+64
-70
lines changed

2 files changed

+64
-70
lines changed

src/op/atomic_add.cc

Lines changed: 64 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,34 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
226226
}
227227
}
228228

229+
/**
230+
* @brief Build a SIMT-style loop nest that performs element-wise atomic
231+
* additions from src to dst.
232+
*
233+
* Constructs a nested loop (parallelized per iter var) that loads a value from
234+
* the source buffer, optionally casts it to the destination dtype, and performs
235+
* an extern atomic add into the destination buffer address. For scalar
236+
* (zero-dimensional) operations a trivial serial For with a single BufferStore
237+
* is returned.
238+
*
239+
* The method:
240+
* - Creates iter vars for all non-singleton extents and binds them into the
241+
* provided analyzer.
242+
* - Validates loop variable counts against src/dst ranges (ICHECK on mismatch).
243+
* - Computes indexed accesses and emits optional bound predicates;
244+
* out-of-bounds accesses are masked to zero when predicates are uncertain.
245+
* - Emits an extern `call_extern("AtomicAdd", address_of(dst_value),
246+
* src_value)` call wrapped in an Evaluate statement.
247+
* - Wraps the body with a parallel For at each loop level. If `coalesced_width`
248+
* is defined it is attached as the "coalesced_width" annotation on each loop.
249+
*
250+
* Note: This function mutates the analyzer binding state by binding loop
251+
* variables and may fail via ICHECK if internal assumptions about shapes are
252+
* violated.
253+
*
254+
* @return A nested For loop (parallel loops) implementing the atomic-add
255+
* kernel. For scalar cases a serial For of extent 1 is returned.
256+
*/
229257
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
230258
Array<IterVar> loop_vars = MakeIterVars();
231259
bool is_scalar = loop_vars.empty();
@@ -286,70 +314,6 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
286314
return Downcast<For>(body);
287315
}
288316

289-
/**
290-
* @brief Lower the atomic-add top-level operator into a parallel, vectorized
291-
* TIR loop.
292-
*
293-
* Constructs a SIMT-style loop for the atomic-add, fuses parallel loops, runs
294-
* layout inference at multiple levels, partitions the root loop by the provided
295-
* thread variable, vectorizes the thread loop, and returns the final
296-
* (optionally predicate-guarded) statement.
297-
*
298-
* The lowering pipeline:
299-
* - Build the SIMT loop via MakeSIMTLoop.
300-
* - Fuse parallel loops into a single For and wrap as a ParallelOp.
301-
* - Run layout inference at kCommon, kStrict, and kFree levels using fields
302-
* from `T`.
303-
* - Obtain the loop layout, partition the root loop with PartitionLoop by
304-
* `T.thread_var`.
305-
* - Vectorize the partitioned thread loop via VectorizeLoop.
306-
* - If the ParallelOp produced a predicate for `T.thread_var`, return an
307-
* IfThenElse that guards the vectorized loop with that predicate; otherwise
308-
* return the vectorized loop.
309-
*
310-
* @param T Lowering context whose fields are used:
311-
* - T.target: target architecture for layout inference and lowering
312-
* decisions.
313-
* - T.thread_var: the Var used to partition the outer loop for thread-level
314-
* parallelism.
315-
* - T.thread_bounds: bounds associated with the thread dimension (used during
316-
* partitioning).
317-
* - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used
318-
* during InferLayout.
319-
* @param analyzer Analyzer used for symbolic reasoning during partitioning and
320-
* folding (omitted from detailed param docs as a common analysis utility).
321-
* @return Stmt A lowered TIR statement representing the parallelized and
322-
* vectorized atomic-add.
323-
*/
324-
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
325-
Target target = T.target;
326-
auto simt_loop = MakeSIMTLoop(analyzer);
327-
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
328-
auto par_op = ParallelOp(fused_loop);
329-
330-
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
331-
InferLevel::kFree};
332-
for (auto level : levels) {
333-
(par_op)->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
334-
false, T.buffer_remap},
335-
level);
336-
}
337-
auto loop_layout = par_op->GetLoopLayout();
338-
Var thread_var = T.thread_var;
339-
Range thread_bounds = T.thread_bounds;
340-
auto thread_loop =
341-
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
342-
auto vectorized_thread_loop = VectorizeAtomicAdd(
343-
thread_loop, thread_var, thread_bounds, GetArchInt(target));
344-
345-
if (par_op->GetPredicate(T.thread_var).defined()) {
346-
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
347-
vectorized_thread_loop);
348-
}
349-
350-
return vectorized_thread_loop;
351-
}
352-
353317
/**
354318
* @brief Infer and return the layout map for the atomic add operator.
355319
*
@@ -391,6 +355,41 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
391355
return par_op_->InferLayout(T, level);
392356
}
393357

358+
/**
359+
* @brief Lower the atomic-add top-level operator into a parallel, vectorized
360+
* TIR loop.
361+
*
362+
* Constructs a SIMT-style loop for the atomic-add, fuses parallel loops, runs
363+
* layout inference at multiple levels, partitions the root loop by the provided
364+
* thread variable, vectorizes the thread loop, and returns the final
365+
* (optionally predicate-guarded) statement.
366+
*
367+
* The lowering pipeline:
368+
* - Build the SIMT loop via MakeSIMTLoop.
369+
* - Fuse parallel loops into a single For and wrap as a ParallelOp.
370+
* - Run layout inference at kCommon, kStrict, and kFree levels using fields
371+
* from `T`.
372+
* - Obtain the loop layout, partition the root loop with PartitionLoop by
373+
* `T.thread_var`.
374+
* - Vectorize the partitioned thread loop via VectorizeLoop.
375+
* - If the ParallelOp produced a predicate for `T.thread_var`, return an
376+
* IfThenElse that guards the vectorized loop with that predicate; otherwise
377+
* return the vectorized loop.
378+
*
379+
* @param T Lowering context whose fields are used:
380+
* - T.target: target architecture for layout inference and lowering
381+
* decisions.
382+
* - T.thread_var: the Var used to partition the outer loop for thread-level
383+
* parallelism.
384+
* - T.thread_bounds: bounds associated with the thread dimension (used during
385+
* partitioning).
386+
* - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used
387+
* during InferLayout.
388+
* @param analyzer Analyzer used for symbolic reasoning during partitioning and
389+
* folding (omitted from detailed param docs as a common analysis utility).
390+
* @return Stmt A lowered TIR statement representing the parallelized and
391+
* vectorized atomic-add.
392+
*/
394393
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
395394
Target target = T.target;
396395
auto simt_loop = MakeSIMTLoop(analyzer);
@@ -510,7 +509,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
510509
read_src.value(), C.indice_map[read_src.value()], args.layout_map,
511510
args.thread_bounds, C.loop_vars);
512511
} else {
513-
For remapped = loop; // 简化处理
512+
For remapped = loop;
514513
loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
515514
}
516515

@@ -527,13 +526,10 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
527526
{T.target, T.thread_bounds, T.layout_map,
528527
analyzer, false, T.buffer_remap});
529528
Fragment loop_layout = ret.loop_layout;
530-
LOG(INFO) << loop_layout->DebugOutput();
531529
auto thread_loop =
532530
PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
533-
LOG(INFO) << thread_loop;
534531
auto vectorized_thread_loop =
535532
VectorizeAtomicAdd(thread_loop, GetArchInt(target));
536-
LOG(INFO) << vectorized_thread_loop;
537533
return vectorized_thread_loop;
538534
}
539535

src/transform/atomicadd_vectorize.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
3333
}
3434
});
3535

36-
LOG(INFO) << vectorize_size_max; // 4
3736
if (vectorize_size_max <= 1) {
3837
return {1, dynamic_, condition_};
3938
}
@@ -238,7 +237,6 @@ For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
238237
AtomicAddVectorizePlanner planner;
239238
res = planner.Plan(for_node, compute_capability);
240239
int vectorize_hint = res.vector_size;
241-
LOG(INFO) << vectorize_hint;
242240
if (vectorize_hint == 1)
243241
return for_node;
244242
auto rewriter = AtomicAddVectorizeRewriter(res);

0 commit comments

Comments (0)