Skip to content

Commit a7fcb81

Browse files
📝 Add docstrings to pytile_0826
Docstrings generation was requested by @LeiWang1999. * #763 (comment) The following files were modified: * `src/op/atomic_add.cc` * `src/op/atomic_add.h` * `src/op/copy.cc` * `src/op/copy.h` * `src/op/elem.cc` * `src/op/elem.h` * `src/op/gemm.cc` * `src/op/gemm.h` * `src/op/gemm_sp.cc` * `src/op/gemm_sp.h` * `src/op/operator.cc` * `src/op/operator.h` * `src/op/parallel.cc` * `src/op/parallel.h` * `src/op/reduce.cc` * `src/op/reduce.h` * `src/op/region.cc` * `src/op/region.h` * `src/transform/layout_inference.cc` * `src/transform/lower_tile_op.cc`
1 parent b38bd69 commit a7fcb81

File tree

20 files changed

+2042
-142
lines changed

20 files changed

+2042
-142
lines changed

src/op/atomic_add.cc

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ namespace tl {
2121

2222
using namespace tir;
2323

24+
/**
25+
* @brief Extracts a numeric architecture identifier from a Target's "arch" attribute.
26+
*
27+
* Reads the Target's "arch" string (must be defined) and, if it has the form "sm_<N>",
28+
* parses and returns N as an integer. For any other arch string, returns 0.
29+
*
30+
* @param target Target whose "arch" attribute will be inspected (ICHECKs that the attribute is defined).
31+
* @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
32+
*/
2433
static int GetArchInt(Target target) {
2534
int arch_int = 0;
2635
auto s = target->GetAttr<String>("arch");
@@ -34,6 +43,24 @@ static int GetArchInt(Target target) {
3443
return arch_int;
3544
}
3645

46+
/**
47+
* @brief Construct an AtomicAdd operator from call arguments and a buffer map.
48+
*
49+
* Builds the internal AtomicAddNode, extracts the source and destination regions and
50+
* their backing Buffers from the first two call-style expressions in `args` (via RegionOp),
51+
* and stores them along with their ranges. If a third argument is provided, it is
52+
* interpreted as an integer immediate and stored as the node's coalesced width.
53+
*
54+
* @param args Call-style PrimExprs where:
55+
* - args[0] is the source region call,
56+
* - args[1] is the destination region call,
57+
* - args[2] (optional) is an IntImm specifying coalesced width.
58+
* @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects.
59+
*
60+
* Notes:
61+
* - The constructor checks that args[0] and args[1] are CallNodes.
62+
* - The constructed node is stored in this->data_.
63+
*/
3764
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
3865
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
3966
Array<Range> rgs[2];
@@ -54,6 +81,15 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
5481
data_ = std::move(node);
5582
}
5683

84+
/**
85+
* @brief Create a copy of this AtomicAdd node wrapped as a TileOperator.
86+
*
87+
* Produces a new AtomicAddNode object copied from this node. If this node has an
88+
* associated ParallelOp (par_op_), the parallel op is cloned and attached to
89+
* the new node so the cloned operator preserves parallelization state.
90+
*
91+
* @return TileOperator A TileOperator owning the cloned AtomicAddNode.
92+
*/
5793
TileOperator AtomicAddNode::Clone() const {
5894
auto op = make_object<AtomicAddNode>(*this);
5995
if (par_op_.defined()) {
@@ -62,6 +98,16 @@ TileOperator AtomicAddNode::Clone() const {
6298
return AtomicAdd(op);
6399
}
64100

101+
/**
102+
* @brief Create data-parallel iteration variables for non-singleton dimensions of the source.
103+
*
104+
* Constructs an Array of IterVar corresponding to each dimension in `src_range` whose extent is
105+
* not equal to 1. Each IterVar has domain Range(0, extent), a Var named sequentially ("i", "j",
106+
* "k", ...) with the same dtype as the extent, and type IterVarType::kDataPar. The ordering of
107+
* returned itervars matches the order of dimensions in `src_range`.
108+
*
109+
* @return Array<IterVar> Iteration variables for all non-singleton extents in `src_range`.
110+
*/
65111
Array<IterVar> AtomicAddNode::MakeIterVars() const {
66112
Array<IterVar> loop_vars;
67113
size_t idx = 0;
@@ -77,7 +123,22 @@ Array<IterVar> AtomicAddNode::MakeIterVars() const {
77123
}
78124

79125
// ivs: itervars returned by MakeIterVars()
80-
// src_dst: 0 for src_indices, 1 for dst_indices
126+
/**
127+
* @brief Build index expressions for either source or destination from loop iter vars.
128+
*
129+
* Given a list of iteration variables that correspond to the non-singleton extents of
130+
* the selected region (source when src_dst == 0, destination when src_dst == 1),
131+
* return an array of index expressions matching the full rank of that region.
132+
* For dimensions with extent == 1, the corresponding index is the range's minimum;
133+
* otherwise the index is `min + ivar`.
134+
*
135+
* @param ivs Iteration variables in order for all non-singleton dimensions of the chosen region.
136+
* @param src_dst Selects which region to index: 0 for source (src_range), 1 for destination (dst_range).
137+
* @return Array<PrimExpr> Index expressions for every dimension of the selected region, in original dimension order.
138+
*
139+
* @note The function checks that the number of provided iter vars equals the number of non-singleton
140+
* extents; it will abort (ICHECK) if they differ.
141+
*/
81142
Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
82143
int src_dst) const {
83144
Array<PrimExpr> indices;
@@ -97,6 +158,31 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
97158
return indices;
98159
}
99160

161+
/**
162+
* @brief Build a combined bound-check predicate for indexed access.
163+
*
164+
* Constructs an AND'd predicate ensuring each non-singleton index (derived from
165+
* `ivs`) stays within [0, extent) for the selected operand (source when
166+
* `src_dst==0`, destination otherwise). For each non-unit Range in the chosen
167+
* range list this produces two conditions:
168+
* - range.min + iv >= 0
169+
* - range.min + iv < extent
170+
*
171+
* Conditions that the analyzer can prove (with symbolic bounds) are omitted.
172+
* If no uncertain conditions remain, an empty PrimExpr is returned.
173+
*
174+
* Note: the function ICHECKs that `extents.size()` equals the number of ranges
175+
* for the selected operand.
176+
*
177+
* @param ivs Iteration variables corresponding to non-singleton extents (order
178+
* matches the non-unit ranges of the chosen operand).
179+
* @param extents Per-dimension upper bounds to check against; must have the
180+
* same size as the selected range list.
181+
* @param src_dst Selects which ranges to validate: 0 => `src_range`, else
182+
* `dst_range`.
183+
* @return PrimExpr A conjunction of remaining (non-provable) bounds checks, or
184+
* an empty PrimExpr when no checks are required.
185+
*/
100186
PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
101187
const Array<IterVar> &ivs,
102188
Array<PrimExpr> extents,
@@ -128,6 +214,30 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
128214
}
129215
}
130216

217+
/**
218+
* @brief Build a SIMT-style loop nest that performs element-wise atomic additions from src to dst.
219+
*
220+
* Constructs a nested loop (parallelized per iter var) that loads a value from the source buffer,
221+
* optionally casts it to the destination dtype, and performs an extern atomic add into the destination
222+
* buffer address. For scalar (zero-dimensional) operations a trivial serial For with a single
223+
* BufferStore is returned.
224+
*
225+
* The method:
226+
* - Creates iter vars for all non-singleton extents and binds them into the provided analyzer.
227+
* - Validates loop variable counts against src/dst ranges (ICHECK on mismatch).
228+
* - Computes indexed accesses and emits optional bound predicates; out-of-bounds accesses are
229+
* masked to zero when predicates are uncertain.
230+
* - Emits an extern `call_extern("AtomicAdd", address_of(dst_value), src_value)` call wrapped in
231+
* an Evaluate statement.
232+
* - Wraps the body with a parallel For at each loop level. If `coalesced_width` is defined it is
233+
* attached as the "coalesced_width" annotation on each loop.
234+
*
235+
* Note: This function mutates the analyzer binding state by binding loop variables and may fail
236+
* via ICHECK if internal assumptions about shapes are violated.
237+
*
238+
* @return A nested For loop (parallel loops) implementing the atomic-add kernel. For scalar cases
239+
* a serial For of extent 1 is returned.
240+
*/
131241
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
132242
Array<IterVar> loop_vars = MakeIterVars();
133243
bool is_scalar = loop_vars.size() == 0;
@@ -191,6 +301,31 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
191301
return Downcast<For>(body);
192302
}
193303

304+
/**
305+
* @brief Lower the atomic-add top-level operator into a parallel, vectorized TIR loop.
306+
*
307+
* Constructs a SIMT-style loop for the atomic-add, fuses parallel loops, runs layout
308+
* inference at multiple levels, partitions the root loop by the provided thread variable,
309+
* vectorizes the thread loop, and returns the final (optionally predicate-guarded) statement.
310+
*
311+
* The lowering pipeline:
312+
* - Build the SIMT loop via MakeSIMTLoop.
313+
* - Fuse parallel loops into a single For and wrap as a ParallelOp.
314+
* - Run layout inference at kCommon, kStrict, and kFree levels using fields from `T`.
315+
* - Obtain the loop layout, partition the root loop with PartitionLoop by `T.thread_var`.
316+
* - Vectorize the partitioned thread loop via VectorizeLoop.
317+
* - If the ParallelOp produced a predicate for `T.thread_var`, return an IfThenElse
318+
* that guards the vectorized loop with that predicate; otherwise return the vectorized loop.
319+
*
320+
* @param T Lowering context whose fields are used:
321+
* - T.target: target architecture for layout inference and lowering decisions.
322+
* - T.thread_var: the Var used to partition the outer loop for thread-level parallelism.
323+
* - T.thread_bounds: bounds associated with the thread dimension (used during partitioning).
324+
* - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used during InferLayout.
325+
* @param analyzer Arithmetic analyzer used for symbolic reasoning during loop construction,
326+
* partitioning and folding; it is forwarded to MakeSIMTLoop, which binds the loop variables in it.
327+
* @return Stmt A lowered TIR statement representing the parallelized and vectorized atomic-add.
328+
*/
194329
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
195330
Target target = T.target;
196331
auto simt_loop = MakeSIMTLoop(analyzer);
@@ -221,6 +356,22 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
221356
return vectorized_thread_loop;
222357
}
223358

359+
/**
360+
* @brief Infer and return the layout map for the atomic add operator.
361+
*
362+
* Constructs a cached ParallelOp (by building the SIMT loop) if not already present,
363+
* validates that local.fragment layouts for src and dst match when both are provided,
364+
* and then delegates layout inference to the underlying ParallelOp.
365+
*
366+
* @param T Layout inference inputs, including an optional mapping of buffers to layouts.
367+
* @param level Inference strictness level.
368+
* @return LayoutMap The inferred layout mapping for buffers used by this operator.
369+
*
370+
* @note This method mutates the AtomicAddNode by creating and storing a ParallelOp
371+
* on first invocation.
372+
* @note Fails with an ICHECK (including diagnostic output) if both src and dst have
373+
* layouts in `local.fragment` and their fragment layouts differ.
374+
*/
224375
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
225376
InferLevel level) const {
226377
if (!par_op_.defined()) {

src/op/atomic_add.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,63 @@
1010
#include "operator.h"
1111
#include "parallel.h"
1212

13+
/**
14+
* Lower this tile operator into a TIR statement for the given lowering context.
15+
*
16+
* @param T Lowering context containing mapped buffers and iteration information.
17+
* @param analyzer Arithmetic analyzer used to simplify and reason about expressions.
18+
* @return A TIR Stmt that implements the atomic-add tile operation for the provided context.
19+
*/
20+
/**
21+
* Infer memory/layout mapping for tensors and buffers used by this operator.
22+
*
23+
* @param T Layout inference context providing buffer and shape information.
24+
* @param level Inference level (kCommon, kStrict, or kFree) selecting how strict the layout inference is.
25+
* @return A LayoutMap describing inferred layouts for the operator's inputs and outputs.
26+
*/
27+
/**
28+
* Get the Op registration that identifies this tile operator.
29+
*
30+
* @return A reference to the registered Op representing this operator.
31+
*/
32+
/**
33+
* Create a copy of this tile operator node wrapped as a TileOperator.
34+
*
35+
* @return A TileOperator handle owning a cloned AtomicAddNode.
36+
*/
37+
/**
38+
* Construct a SIMT-style nest of parallel For loops that performs the element-wise atomic add from source to destination.
39+
*
40+
* @param analyzer Arithmetic analyzer used to simplify loop bounds and predicates.
41+
* @return A For loop node representing the SIMT-parallel loop structure.
42+
*/
43+
/**
44+
* Create iteration variables used by this operator's loop nest.
45+
*
46+
* @return An array of IterVar objects describing the loop iteration axes.
47+
*/
48+
/**
49+
* Produce index expressions for either source or destination buffer access based on iteration vars.
50+
*
51+
* @param ivs IterVars created by MakeIterVars().
52+
* @param src_dst Selects which indices to produce: 0 for source indices, 1 for destination indices.
53+
* @return An array of PrimExpr index expressions suitable for indexing the selected buffer.
54+
*/
55+
/**
56+
* Build a predicate expression that guards out-of-bounds or conditional accesses for src or dst.
57+
*
58+
* @param analyzer Arithmetic analyzer used to simplify the predicate.
59+
* @param ivs IterVars created by MakeIterVars().
60+
* @param extents Per-dimension upper bounds to check against; must have one entry per range of the selected operand.
61+
* @param src_dst Selects which side the predicate is for: 0 for source, 1 for destination.
62+
* @return A PrimExpr boolean predicate that evaluates to true for valid iterations.
63+
*/
64+
/**
65+
* Construct an AtomicAdd tile operator from operation arguments and a buffer mapping.
66+
*
67+
* @param args Call arguments: args[0] is the source region call, args[1] is the destination region call, and the optional args[2] is an IntImm giving the coalesced width.
68+
* @param vmap Buffer map used (via RegionOp) to resolve the region arguments to concrete Buffer objects.
69+
*/
1370
namespace tvm {
1471
namespace tl {
1572

0 commit comments

Comments
 (0)