@@ -226,6 +226,34 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
  }
 }

+/**
+ * @brief Build a SIMT-style loop nest that performs element-wise atomic
+ * additions from src to dst.
+ *
+ * Constructs a nested loop (parallelized per iter var) that loads a value from
+ * the source buffer, optionally casts it to the destination dtype, and performs
+ * an extern atomic add into the destination buffer address. For scalar
+ * (zero-dimensional) operations a trivial serial For with a single BufferStore
+ * is returned.
+ *
+ * The method:
+ * - Creates iter vars for all non-singleton extents and binds them into the
+ *   provided analyzer.
+ * - Validates loop variable counts against src/dst ranges (ICHECK on mismatch).
+ * - Computes indexed accesses and emits optional bound predicates;
+ *   out-of-bounds accesses are masked to zero when predicates are uncertain.
+ * - Emits an extern `call_extern("AtomicAdd", address_of(dst_value),
+ *   src_value)` call wrapped in an Evaluate statement.
+ * - Wraps the body with a parallel For at each loop level. If `coalesced_width`
+ *   is defined, it is attached as the "coalesced_width" annotation on each loop.
+ *
+ * Note: This function mutates the analyzer binding state by binding loop
+ * variables and may fail via ICHECK if internal assumptions about shapes are
+ * violated.
+ *
+ * @return A nested For loop (parallel loops) implementing the atomic-add
+ *         kernel. For scalar cases a serial For of extent 1 is returned.
+ */
 For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   Array<IterVar> loop_vars = MakeIterVars();
   bool is_scalar = loop_vars.empty();
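The rest of the loop construction is outside this hunk; below is a rough sketch of the per-element statement described by the doc comment above, built with standard TVM TIR constructors. The names `src`, `dst`, `src_indices`, `dst_indices`, and `coalesced_width` are assumed here to be node members and are illustrative, not the actual implementation.

// Sketch only: load, optional cast, and extern atomic add wrapped in Evaluate,
// then one parallel For per iter var with an optional "coalesced_width"
// annotation, as the doc comment describes.
PrimExpr src_value = BufferLoad(src, src_indices);
if (src_value.dtype() != dst->dtype)
  src_value = Cast(dst->dtype, src_value);  // optional cast to dst dtype
PrimExpr dst_addr = Call(DataType::Handle(), builtin::address_of(),
                         {BufferLoad(dst, dst_indices)});
Stmt body = Evaluate(Call(dst->dtype, builtin::call_extern(),
                          {StringImm("AtomicAdd"), dst_addr, src_value}));
for (int i = static_cast<int>(loop_vars.size()) - 1; i >= 0; --i) {
  Map<String, ObjectRef> ann;
  if (coalesced_width.defined())
    ann.Set("coalesced_width", coalesced_width);
  body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
             ForKind::kParallel, body, NullOpt, ann);
}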
@@ -286,70 +314,6 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   return Downcast<For>(body);
 }

-/**
- * @brief Lower the atomic-add top-level operator into a parallel, vectorized
- * TIR loop.
- *
- * Constructs a SIMT-style loop for the atomic-add, fuses parallel loops, runs
- * layout inference at multiple levels, partitions the root loop by the provided
- * thread variable, vectorizes the thread loop, and returns the final
- * (optionally predicate-guarded) statement.
- *
- * The lowering pipeline:
- * - Build the SIMT loop via MakeSIMTLoop.
- * - Fuse parallel loops into a single For and wrap as a ParallelOp.
- * - Run layout inference at kCommon, kStrict, and kFree levels using fields
- *   from `T`.
- * - Obtain the loop layout, partition the root loop with PartitionLoop by
- *   `T.thread_var`.
- * - Vectorize the partitioned thread loop via VectorizeLoop.
- * - If the ParallelOp produced a predicate for `T.thread_var`, return an
- *   IfThenElse that guards the vectorized loop with that predicate; otherwise
- *   return the vectorized loop.
- *
- * @param T Lowering context whose fields are used:
- * - T.target: target architecture for layout inference and lowering
- *   decisions.
- * - T.thread_var: the Var used to partition the outer loop for thread-level
- *   parallelism.
- * - T.thread_bounds: bounds associated with the thread dimension (used during
- *   partitioning).
- * - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used
- *   during InferLayout.
- * @param analyzer Analyzer used for symbolic reasoning during partitioning and
- *   folding (omitted from detailed param docs as a common analysis utility).
- * @return Stmt A lowered TIR statement representing the parallelized and
- *         vectorized atomic-add.
- */
-Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
-  Target target = T.target;
-  auto simt_loop = MakeSIMTLoop(analyzer);
-  auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
-  auto par_op = ParallelOp(fused_loop);
-
-  std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
-                                    InferLevel::kFree};
-  for (auto level : levels) {
-    (par_op)->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                           false, T.buffer_remap},
-                          level);
-  }
-  auto loop_layout = par_op->GetLoopLayout();
-  Var thread_var = T.thread_var;
-  Range thread_bounds = T.thread_bounds;
-  auto thread_loop =
-      PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
-  auto vectorized_thread_loop = VectorizeAtomicAdd(
-      thread_loop, thread_var, thread_bounds, GetArchInt(target));
-
-  if (par_op->GetPredicate(T.thread_var).defined()) {
-    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
-                      vectorized_thread_loop);
-  }
-
-  return vectorized_thread_loop;
-}
-
 /**
  * @brief Infer and return the layout map for the atomic add operator.
  *
@@ -391,6 +355,41 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
   return par_op_->InferLayout(T, level);
 }
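Callers drive this delegation across the three inference levels; a minimal sketch of that calling pattern follows, with `par_op` standing in for the cached ParallelOp and `T`/`analyzer` assumed to come from the surrounding lowering context:

// Sketch: run layout inference from the most conservative to the most
// permissive level; each call forwards to the cached ParallelOp, just as
// InferLayout above does.
for (auto level : {InferLevel::kCommon, InferLevel::kStrict, InferLevel::kFree}) {
  par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
                       false, T.buffer_remap},
                      level);
}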

+/**
+ * @brief Lower the atomic-add top-level operator into a parallel, vectorized
+ * TIR loop.
+ *
+ * Constructs a SIMT-style loop for the atomic-add, fuses parallel loops, runs
+ * layout inference at multiple levels, partitions the root loop by the provided
+ * thread variable, vectorizes the thread loop, and returns the final
+ * (optionally predicate-guarded) statement.
+ *
+ * The lowering pipeline:
+ * - Build the SIMT loop via MakeSIMTLoop.
+ * - Fuse parallel loops into a single For and wrap as a ParallelOp.
+ * - Run layout inference at kCommon, kStrict, and kFree levels using fields
+ *   from `T`.
+ * - Obtain the loop layout, partition the root loop with PartitionLoop by
+ *   `T.thread_var`.
+ * - Vectorize the partitioned thread loop via VectorizeAtomicAdd.
+ * - If the ParallelOp produced a predicate for `T.thread_var`, return an
+ *   IfThenElse that guards the vectorized loop with that predicate; otherwise
+ *   return the vectorized loop.
+ *
+ * @param T Lowering context whose fields are used:
+ * - T.target: target architecture for layout inference and lowering
+ *   decisions.
+ * - T.thread_var: the Var used to partition the outer loop for thread-level
+ *   parallelism.
+ * - T.thread_bounds: bounds associated with the thread dimension (used during
+ *   partitioning).
+ * - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used
+ *   during InferLayout.
+ * @param analyzer Analyzer used for symbolic reasoning during partitioning and
+ *   folding (omitted from detailed param docs as a common analysis utility).
+ * @return Stmt A lowered TIR statement representing the parallelized and
+ *         vectorized atomic-add.
+ */
 Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   Target target = T.target;
   auto simt_loop = MakeSIMTLoop(analyzer);
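A hedged usage sketch of this entry point, with `atomic_add` (an AtomicAdd ref) and the population of `T` as placeholders for whatever the enclosing lowering pass provides:

// Sketch only: the LowerArgs fields named in the doc comment (target,
// thread_var, thread_bounds, layout_map, buffer_remap) must be filled in by
// the enclosing pass before calling Lower.
arith::Analyzer analyzer;
Stmt lowered = atomic_add->Lower(T, &analyzer);
// Per the doc comment, `lowered` is the vectorized thread loop, optionally
// wrapped in an IfThenElse guarded by the per-thread predicate.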
@@ -510,7 +509,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
        read_src.value(), C.indice_map[read_src.value()], args.layout_map,
        args.thread_bounds, C.loop_vars);
   } else {
-    For remapped = loop; // simplified handling
+    For remapped = loop;
     loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
   }

@@ -527,13 +526,10 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
        {T.target, T.thread_bounds, T.layout_map,
         analyzer, false, T.buffer_remap});
   Fragment loop_layout = ret.loop_layout;
-  LOG(INFO) << loop_layout->DebugOutput();
   auto thread_loop =
       PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
-  LOG(INFO) << thread_loop;
   auto vectorized_thread_loop =
       VectorizeAtomicAdd(thread_loop, GetArchInt(target));
-  LOG(INFO) << vectorized_thread_loop;
   return vectorized_thread_loop;
 }
