@@ -21,6 +21,15 @@ namespace tl {
 
 using namespace tir;
 
+/**
+ * @brief Extracts a numeric architecture identifier from a Target's "arch" attribute.
+ *
+ * Reads the Target's "arch" string (must be defined) and, if it has the form "sm_<N>",
+ * parses and returns N as an integer. For any other arch string, returns 0.
+ *
+ * @param target Target whose "arch" attribute will be inspected (ICHECKs that the attribute is defined).
+ * @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
+ */
 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
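For intuition, the parsing rule reads like the following standalone sketch (plain C++; `ParseArchInt` and the `main` harness are illustrative stand-ins, while the real function obtains the string via `target->GetAttr<String>("arch")` and ICHECKs that it is defined):

#include <cassert>
#include <string>

// Mirrors GetArchInt's rule: "sm_<N>" -> N, anything else -> 0.
static int ParseArchInt(const std::string &arch) {
  if (arch.rfind("sm_", 0) == 0) {
    return std::stoi(arch.substr(3));  // e.g. "sm_90" -> 90
  }
  return 0;
}

int main() {
  assert(ParseArchInt("sm_90") == 90);
  assert(ParseArchInt("gfx942") == 0);  // non-"sm_" arch falls back to 0
  return 0;
}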
@@ -34,6 +43,24 @@ static int GetArchInt(Target target) {
   return arch_int;
 }
 
+/**
+ * @brief Construct an AtomicAdd operator from call arguments and a buffer map.
+ *
+ * Builds the internal AtomicAddNode, extracts the source and destination regions and
+ * their backing Buffers from the first two call-style expressions in `args` (via RegionOp),
+ * and stores them along with their ranges. If a third argument is provided, it is
+ * interpreted as an integer immediate and stored as the node's coalesced width.
+ *
+ * @param args Call-style PrimExprs where:
+ *   - args[0] is the source region call,
+ *   - args[1] is the destination region call,
+ *   - args[2] (optional) is an IntImm specifying the coalesced width.
+ * @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects.
+ *
+ * Notes:
+ * - The constructor checks that args[0] and args[1] are CallNodes.
+ * - The constructed node is stored in this->data_.
+ */
 AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
   Array<Range> rgs[2];
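The argument convention can be pictured with a small standalone sketch (plain C++ with stand-in types; none of these names appear in the patch, and `Arg` merely mimics the call/IntImm distinction):

#include <cassert>
#include <optional>
#include <vector>

struct Arg { bool is_call; int int_imm; };  // stand-in for a PrimExpr

// Mirrors the unpacking order: two region calls, then an optional
// integer immediate that becomes the coalesced width.
static std::optional<int> CoalescedWidth(const std::vector<Arg> &args) {
  assert(args.size() >= 2 && args[0].is_call && args[1].is_call);
  if (args.size() >= 3) return args[2].int_imm;
  return std::nullopt;
}

int main() {
  std::vector<Arg> args = {{true, 0}, {true, 0}, {false, 8}};
  assert(CoalescedWidth(args).value() == 8);
  return 0;
}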
@@ -54,6 +81,15 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
   data_ = std::move(node);
 }
 
+/**
+ * @brief Create a deep copy of this AtomicAdd node wrapped as a TileOperator.
+ *
+ * Produces a new AtomicAddNode object copied from this node. If this node has an
+ * associated ParallelOp (par_op_), the parallel op is cloned and attached to
+ * the new node so the cloned operator preserves its parallelization state.
+ *
+ * @return TileOperator A TileOperator owning the cloned AtomicAddNode.
+ */
 TileOperator AtomicAddNode::Clone() const {
   auto op = make_object<AtomicAddNode>(*this);
   if (par_op_.defined()) {
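The copy-then-reclone shape is the usual way to deep-copy a node that owns another operator; a minimal standalone illustration (stand-in types, not the real TVM classes):

#include <cassert>
#include <memory>

struct ParOp { int state = 0; };
struct Node { std::shared_ptr<ParOp> par_op; };  // owned sub-op, may be null

// Shallow-copy the node, then clone the owned sub-object so the copy
// does not alias the original's ParallelOp (as AtomicAddNode::Clone does).
static Node Clone(const Node &n) {
  Node copy = n;
  if (n.par_op) copy.par_op = std::make_shared<ParOp>(*n.par_op);
  return copy;
}

int main() {
  Node a;
  a.par_op = std::make_shared<ParOp>();
  Node b = Clone(a);
  assert(b.par_op && b.par_op != a.par_op);  // distinct deep copy
  return 0;
}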
@@ -62,6 +98,16 @@ TileOperator AtomicAddNode::Clone() const {
   return AtomicAdd(op);
 }
 
+/**
+ * @brief Create data-parallel iteration variables for non-singleton dimensions of the source.
+ *
+ * Constructs an Array of IterVar corresponding to each dimension in `src_range` whose extent is
+ * not equal to 1. Each IterVar has domain Range(0, extent), a Var named sequentially ("i", "j",
+ * "k", ...) with the same dtype as the extent, and type IterVarType::kDataPar. The ordering of
+ * the returned itervars matches the order of dimensions in `src_range`.
+ *
+ * @return Array<IterVar> Iteration variables for all non-singleton extents in `src_range`.
+ */
 Array<IterVar> AtomicAddNode::MakeIterVars() const {
   Array<IterVar> loop_vars;
   size_t idx = 0;
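The filtering-and-naming behavior can be sketched standalone (plain C++; the real code creates TVM IterVars over Range(0, extent) rather than name strings):

#include <cassert>
#include <string>
#include <vector>

// For extents {1, 64, 1, 32} produce loop vars only for the non-unit
// dims, named "i", "j", "k", ... in dimension order.
static std::vector<std::string> MakeLoopVarNames(const std::vector<int> &extents) {
  std::vector<std::string> names;
  char next = 'i';
  for (int e : extents) {
    if (e == 1) continue;  // singleton dims get no loop var
    names.push_back(std::string(1, next++));
  }
  return names;
}

int main() {
  auto names = MakeLoopVarNames({1, 64, 1, 32});
  assert(names.size() == 2 && names[0] == "i" && names[1] == "j");
  return 0;
}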
@@ -77,7 +123,22 @@ Array<IterVar> AtomicAddNode::MakeIterVars() const {
 }
 
 // ivs: itervars returned by MakeIterVars()
-// src_dst: 0 for src_indices, 1 for dst_indices
+/**
+ * @brief Build index expressions for either source or destination from loop iter vars.
+ *
+ * Given a list of iteration variables that correspond to the non-singleton extents of
+ * the selected region (source when src_dst == 0, destination when src_dst == 1),
+ * return an array of index expressions matching the full rank of that region.
+ * For dimensions with extent == 1, the corresponding index is the range's minimum;
+ * otherwise the index is `min + ivar`.
+ *
+ * @param ivs Iteration variables in order for all non-singleton dimensions of the chosen region.
+ * @param src_dst Selects which region to index: 0 for source (src_range), 1 for destination (dst_range).
+ * @return Array<PrimExpr> Index expressions for every dimension of the selected region, in original dimension order.
+ *
+ * @note The function checks that the number of provided iter vars equals the number of
+ *   non-singleton extents; it will abort (ICHECK) if they differ.
+ */
 Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
                                            int src_dst) const {
   Array<PrimExpr> indices;
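Standalone sketch of the per-dimension rule (plain C++; strings stand in for PrimExprs, and all names here are illustrative):

#include <cassert>
#include <string>
#include <vector>

struct Rng { int min; int extent; };

// For each dimension: extent == 1 -> "min"; otherwise -> "min + iv",
// consuming loop vars in order (mirrors MakeIndices' shape handling).
static std::vector<std::string> MakeIndexStrings(const std::vector<Rng> &ranges,
                                                 const std::vector<std::string> &ivs) {
  std::vector<std::string> idx;
  size_t k = 0;
  for (const Rng &r : ranges) {
    if (r.extent == 1) {
      idx.push_back(std::to_string(r.min));
    } else {
      idx.push_back(std::to_string(r.min) + " + " + ivs[k++]);
    }
  }
  assert(k == ivs.size());  // one loop var per non-singleton dim
  return idx;
}

int main() {
  auto idx = MakeIndexStrings({{2, 1}, {0, 64}}, {"i"});
  assert(idx[0] == "2" && idx[1] == "0 + i");
  return 0;
}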
@@ -97,6 +158,31 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
   return indices;
 }
 
+/**
+ * @brief Build a combined bound-check predicate for indexed access.
+ *
+ * Constructs an AND'd predicate ensuring each non-singleton index (derived from
+ * `ivs`) stays within [0, extent) for the selected operand (source when
+ * `src_dst == 0`, destination otherwise). For each non-unit Range in the chosen
+ * range list this produces two conditions:
+ *   - range.min + iv >= 0
+ *   - range.min + iv < extent
+ *
+ * Conditions that the analyzer can prove (with symbolic bounds) are omitted.
+ * If no uncertain conditions remain, an empty PrimExpr is returned.
+ *
+ * Note: the function ICHECKs that `extents.size()` equals the number of ranges
+ * for the selected operand.
+ *
+ * @param ivs Iteration variables corresponding to non-singleton extents (order
+ *   matches the non-unit ranges of the chosen operand).
+ * @param extents Per-dimension upper bounds to check against; must have the
+ *   same size as the selected range list.
+ * @param src_dst Selects which ranges to validate: 0 => `src_range`, else
+ *   `dst_range`.
+ * @return PrimExpr A conjunction of remaining (non-provable) bounds checks, or
+ *   an empty PrimExpr when no checks are required.
+ */
 PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
                                       const Array<IterVar> &ivs,
                                       Array<PrimExpr> extents,
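The pruning behavior reads like this standalone sketch, where a trivial constant-folding "prover" drops lower-bound checks whose constant min is non-negative (the real code asks arith::Analyzer to prove each condition symbolically):

#include <string>
#include <vector>

struct Rng { int min; int extent; };

// Emit "min + iv >= 0" and "min + iv < bound" for each non-unit range,
// skipping conditions the trivial prover can discharge statically.
static std::vector<std::string>
BoundsChecks(const std::vector<Rng> &ranges,
             const std::vector<std::string> &ivs,
             const std::vector<std::string> &bounds) {  // one bound per dim
  std::vector<std::string> conds;
  size_t k = 0;
  for (size_t d = 0; d < ranges.size(); ++d) {
    if (ranges[d].extent == 1) continue;
    std::string idx = std::to_string(ranges[d].min) + " + " + ivs[k++];
    if (ranges[d].min < 0) conds.push_back(idx + " >= 0");  // not provable
    conds.push_back(idx + " < " + bounds[d]);
  }
  return conds;  // empty => no predicate required (all checks provable)
}

int main() {
  auto c = BoundsChecks({{-1, 64}}, {"i"}, {"N"});
  return c.size() == 2 ? 0 : 1;  // "-1 + i >= 0" and "-1 + i < N"
}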
@@ -128,6 +214,30 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
   }
 }
 
+/**
+ * @brief Build a SIMT-style loop nest that performs element-wise atomic additions from src to dst.
+ *
+ * Constructs a nested loop (parallelized per iter var) that loads a value from the source buffer,
+ * optionally casts it to the destination dtype, and performs an extern atomic add into the destination
+ * buffer address. For scalar (zero-dimensional) operations a trivial serial For with a single
+ * BufferStore is returned.
+ *
+ * The method:
+ * - Creates iter vars for all non-singleton extents and binds them into the provided analyzer.
+ * - Validates loop variable counts against src/dst ranges (ICHECK on mismatch).
+ * - Computes indexed accesses and emits optional bound predicates; out-of-bounds accesses are
+ *   masked to zero when predicates are uncertain.
+ * - Emits an extern `call_extern("AtomicAdd", address_of(dst_value), src_value)` call wrapped in
+ *   an Evaluate statement.
+ * - Wraps the body with a parallel For at each loop level. If `coalesced_width` is defined, it is
+ *   attached as the "coalesced_width" annotation on each loop.
+ *
+ * Note: This function mutates the analyzer binding state by binding loop variables and may fail
+ * via ICHECK if internal assumptions about shapes are violated.
+ *
+ * @return A nested For loop (parallel loops) implementing the atomic-add kernel. For scalar cases
+ *   a serial For of extent 1 is returned.
+ */
 For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   Array<IterVar> loop_vars = MakeIterVars();
   bool is_scalar = loop_vars.size() == 0;
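Semantically, the emitted body behaves like this single-threaded reference for a 1-D region (plain C++; the predicate, dtype cast, and extern call are simplified, and atomicity only matters once the loop actually runs in parallel):

#include <vector>

// Reference semantics of the generated body: read src, cast to the
// destination element type, and add into dst, skipping out-of-bounds
// indices exactly as the emitted predicate would mask them.
static void AtomicAddRef(const std::vector<float> &src, std::vector<double> &dst,
                         int min, int extent, int bound) {
  for (int i = 0; i < extent; ++i) {   // the parallel For in the real code
    int idx = min + i;
    if (idx < 0 || idx >= bound) continue;     // bounds predicate
    double v = static_cast<double>(src[idx]);  // optional dtype cast
    dst[idx] += v;  // stands in for call_extern("AtomicAdd", &dst[idx], v)
  }
}

int main() {
  std::vector<float> src = {1.f, 2.f, 3.f};
  std::vector<double> dst(3, 0.0);
  AtomicAddRef(src, dst, 0, 3, 3);
  return dst[2] == 3.0 ? 0 : 1;
}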
@@ -191,6 +301,31 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   return Downcast<For>(body);
 }
 
+/**
+ * @brief Lower the atomic-add top-level operator into a parallel, vectorized TIR loop.
+ *
+ * Constructs a SIMT-style loop for the atomic-add, fuses parallel loops, runs layout
+ * inference at multiple levels, partitions the root loop by the provided thread variable,
+ * vectorizes the thread loop, and returns the final (optionally predicate-guarded) statement.
+ *
+ * The lowering pipeline:
+ * - Build the SIMT loop via MakeSIMTLoop.
+ * - Fuse parallel loops into a single For and wrap it as a ParallelOp.
+ * - Run layout inference at kCommon, kStrict, and kFree levels using fields from `T`.
+ * - Obtain the loop layout and partition the root loop with PartitionLoop by `T.thread_var`.
+ * - Vectorize the partitioned thread loop via VectorizeLoop.
+ * - If the ParallelOp produced a predicate for `T.thread_var`, return an IfThenElse
+ *   that guards the vectorized loop with that predicate; otherwise return the vectorized loop.
+ *
+ * @param T Lowering context whose fields are used:
+ *   - T.target: target architecture for layout inference and lowering decisions.
+ *   - T.thread_var: the Var used to partition the outer loop for thread-level parallelism.
+ *   - T.thread_bounds: bounds associated with the thread dimension (used during partitioning).
+ *   - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used during InferLayout.
+ * @param analyzer Analyzer used for symbolic reasoning during partitioning and simplification.
+ * @return Stmt A lowered TIR statement representing the parallelized and vectorized atomic-add.
+ */
 Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   Target target = T.target;
   auto simt_loop = MakeSIMTLoop(analyzer);
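The partition-then-vectorize step can be pictured with a standalone sketch: a fused loop of extent N is split across T threads, and each thread's slice is walked in chunks of V. This is one illustrative scheme only; the real PartitionLoop/VectorizeLoop passes operate on TIR and handle remainders and predicates symbolically:

#include <cassert>
#include <vector>

// Each thread t handles indices [t*per_thread, (t+1)*per_thread), and its
// slice is processed in chunks of V (the "vectorized" inner loop).
static void PartitionedRun(int N, int T, int V, std::vector<int> &hits) {
  int per_thread = N / T;                      // assume T divides N
  for (int t = 0; t < T; ++t) {                // thread_var in the real code
    for (int v = 0; v < per_thread; v += V) {  // vectorized thread loop
      for (int lane = 0; lane < V; ++lane) {
        hits[t * per_thread + v + lane] += 1;
      }
    }
  }
}

int main() {
  std::vector<int> hits(64, 0);
  PartitionedRun(64, 8, 4, hits);
  for (int h : hits) assert(h == 1);  // every element covered exactly once
  return 0;
}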
@@ -221,6 +356,22 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   return vectorized_thread_loop;
 }
 
+/**
+ * @brief Infer and return the layout map for the atomic-add operator.
+ *
+ * Constructs a cached ParallelOp (by building the SIMT loop) if one is not already present,
+ * validates that the local.fragment layouts of src and dst match when both are provided,
+ * and then delegates layout inference to the underlying ParallelOp.
+ *
+ * @param T Layout inference inputs, including an optional mapping of buffers to layouts.
+ * @param level Inference strictness level.
+ * @return LayoutMap The inferred layout mapping for buffers used by this operator.
+ *
+ * @note This method mutates the AtomicAddNode by creating and storing a ParallelOp
+ *   on first invocation.
+ * @throws If both src and dst have layouts in `local.fragment` and their fragment
+ *   layouts differ, an ICHECK failure is raised with diagnostic output.
+ */
 LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
                                      InferLevel level) const {
   if (!par_op_.defined()) {
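The cache-on-first-use-then-delegate shape described above can be sketched standalone (stand-in types; the real method also cross-checks that the src/dst `local.fragment` layouts agree before delegating):

#include <cassert>
#include <memory>

struct LayoutMap {};
struct ParallelOp { LayoutMap InferLayout() const { return {}; } };

// Build the delegate once, keep it for later calls, and forward the query,
// mirroring the structure of AtomicAddNode::InferLayout.
struct Op {
  mutable std::unique_ptr<ParallelOp> par_op_;
  LayoutMap InferLayout() const {
    if (!par_op_) par_op_ = std::make_unique<ParallelOp>();  // built once
    return par_op_->InferLayout();
  }
};

int main() {
  Op op;
  op.InferLayout();
  assert(op.par_op_ != nullptr);  // ParallelOp cached after first call
  return 0;
}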