Skip to content

Commit 031b834

Browse files
📝 Add docstrings to reducer_0825
Docstrings generation was requested by @LeiWang1999. * #757 (comment) The following files were modified: * `setup.py` * `src/op/builtin.h` * `src/op/finalize_reducer.cc` * `src/op/finalize_reducer.h` * `src/op/parallel.cc` * `src/op/parallel.h` * `src/op/reduce.cc` * `src/target/codegen_cuda.cc` * `src/tl_templates/cuda/common.h` * `src/transform/layout_inference.cc` * `src/transform/layout_reducer.cc` * `src/transform/layout_reducer.h` * `src/transform/merge_shared_memory_allocations.cc` * `src/transform/storage_access.cc` * `src/transform/warp_specialized_rewriter.cc` * `testing/python/autotune/test_tilelang_autotune_with_inputs.py` * `tilelang/engine/phase.py` * `tilelang/language/customize.py` * `tilelang/language/reduce.py` * `tilelang/transform/__init__.py`
1 parent 2af3f22 commit 031b834

20 files changed

+749
-632
lines changed

setup.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -749,9 +749,20 @@ def build_cython(self, ext):
749749

750750
def build_cmake(self, ext):
751751
"""
752-
Build a single CMake-based extension.
753-
754-
:param ext: The extension (an instance of CMakeExtension).
752+
Build a single CMake-based extension by generating a CMake config and invoking CMake/Ninja.
753+
754+
Generates or updates a config.cmake in the build directory (based on the extension's sourcedir),
755+
injecting LLVM/CUDA/ROCm and Python settings, then runs CMake to configure and build the target.
756+
When running an in-place build the resulting library is placed under ./tilelang/lib; otherwise the
757+
standard extension output directory is used.
758+
759+
Parameters:
760+
ext: The CMakeExtension to build; its `sourcedir` should contain the TVM/CMake `config.cmake`
761+
template under `3rdparty/tvm/cmake/`.
762+
763+
Raises:
764+
subprocess.CalledProcessError: If the CMake configuration or build commands fail.
765+
OSError: If filesystem operations (read/write) fail.
755766
"""
756767
# Only setup LLVM if it's enabled
757768
llvm_config_path = "OFF"

src/op/builtin.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111
#include <tvm/ir/transform.h>
1212

1313
namespace tvm {
14+
/*!
15+
* \brief Create the TVM intrinsic that initializes a PTX fence barrier.
16+
*
17+
* Initializes a PTX fence-style barrier used to coordinate asynchronous memory
18+
* operations (for example, TMA/TMA_STORE). Returns the Op representing this
19+
* intrinsic for use in TIR lowering and code generation.
20+
*
21+
*/
1422
namespace tl {
1523

1624
namespace attr {

src/op/finalize_reducer.cc

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,50 @@ namespace tl {
1818

1919
using namespace tir;
2020

21+
/**
22+
* @brief Construct a FinalizeReducerOp from TL operator arguments and a buffer map.
23+
*
24+
* Extracts the reducer Buffer from `vmap` using the variable referenced by `args[0]`
25+
* and sets the reduction operation type from the integer code in `args[1]`.
26+
*
27+
* @param args TL operator arguments: expects at least two elements where
28+
* `args[0]` is an access pointer identifying the reducer variable and
29+
* `args[1]` is an integer encoding a `ReducerOpType` (e.g., Sum/Max/Min).
30+
* @param vmap Mapping from variables to Buffers used to look up the reducer Buffer.
31+
*/
2132
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
2233
auto node = make_object<FinalizeReducerOpNode>();
2334
node->reducer = vmap[GetVarFromAccessPtr(args[0])];
2435
node->op = (ReducerOpType)*as_const_int(args[1]);
2536
data_ = std::move(node);
2637
}
2738

39+
/**
40+
* @brief Lower the finalize_reducer TL operator to a TIR statement.
41+
*
42+
* Lowers the operator that finalizes a reducer by performing a thread-wide AllReduce
43+
* across the reducer's output elements and writing the reduced value back into the
44+
* reducer buffer. The function:
45+
* - Fetches the reducer buffer and expects its layout to be a Fragment.
46+
* - Builds index Vars for each output dimension.
47+
* - Reads the layout's ReplicateExtent and:
48+
* - if extent == 1, emits a no-op Evaluate(0);
49+
* - otherwise constructs an AllReduce extern call (uses `run_hopper` when the
50+
* compilation target is Hopper) with an optional workspace (allocated via
51+
* T.AddWorkspace when reducing_threads >= 32) and stores the result via
52+
* BufferStore.
53+
* - Wraps the store in parallel outer For loops over each output dimension.
54+
*
55+
* @param T Lowering context containing buffer remapping, layout map, thread bounds,
56+
* target, and helper methods (e.g., AddWorkspace).
57+
* @param analyzer Arithmetic analyzer (unused by this implementation but provided
58+
* for consistency with lowering API).
59+
* @return Stmt The lowered TIR statement representing the AllReduce and surrounding loops.
60+
*
61+
* @note The function ICHECKs that the reducer layout is present and a Fragment,
62+
* and that ReplicateExtent is either 1 or equal to the thread block extent;
63+
* violations cause a fatal check failure.
64+
*/
2865
Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T,
2966
arith::Analyzer *analyzer) const {
3067
auto buffer = T.buffer_remap[reducer];
@@ -81,13 +118,32 @@ Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T,
81118
return body;
82119
}
83120

121+
/**
122+
* @brief Infer and return the layout mapping for the reducer buffer.
123+
*
124+
* Copies the existing layout for the reducer from the provided LayoutInferArgs into
125+
* a new LayoutMap and returns it. The inference does not modify the layout; it
126+
* preserves the reducer's current layout.
127+
*
128+
* @param T Provides the input layout map from which the reducer's layout is copied.
129+
* @param level Unused by this operator; present for API compatibility.
130+
* @return LayoutMap A map that contains the reducer buffer mapped to its original layout.
131+
*/
84132
LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T,
85133
InferLevel level) const {
86134
LayoutMap layout_map;
87135
layout_map.Set(reducer, T.layout_map.Get(reducer).value());
88136
return layout_map;
89137
}
90138

139+
/**
140+
* @brief Create a deep copy of this FinalizeReducerOpNode and wrap it as a TileOperator.
141+
*
142+
* Constructs a new FinalizeReducerOpNode by copying the current node state and returns
143+
* a TileOperator that owns the copied node.
144+
*
145+
* @return TileOperator A TileOperator that contains a deep copy of this node.
146+
*/
91147
TileOperator FinalizeReducerOpNode::Clone() const {
92148
auto node = make_object<FinalizeReducerOpNode>(*this);
93149
return TileOperator(node);

src/op/finalize_reducer.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,69 @@
1212
#include "../transform/layout_reducer.h"
1313
#include "./operator.h"
1414

15+
/**
16+
* FinalizeReducer operator node for Tile IR.
17+
*
18+
* Represents a TL-level operator that finalizes a reducer buffer into a
19+
* result using a specified reducer operation.
20+
*
21+
* Public members:
22+
* - reducer: the tir::Buffer that holds the intermediate reduction values.
23+
* - op: the reducer operation to apply when finalizing values.
24+
*/
25+
26+
/**
27+
* Lower this operator to a TIR statement.
28+
*
29+
* @param T Lowering arguments (buffers, indices, and other lowering context).
30+
* @param analyzer Arithmetic analyzer used to simplify expressions during lowering.
31+
* @return A tir::Stmt that implements the finalize-reducer semantics for the provided
32+
* lowering context.
33+
*/
34+
35+
/**
36+
* Infer layout mapping for this operator.
37+
*
38+
* Determines how input and output buffer layouts relate for the finalize-reducer
39+
* operator at the given inference level.
40+
*
41+
* @param T Layout inference arguments (including operand layouts and shapes).
42+
* @param level Inference precision level.
43+
* @return A LayoutMap describing the inferred layouts.
44+
*/
45+
46+
/**
47+
* Get the singleton Op object representing this operator.
48+
*
49+
* @return A reference to the Op describing FinalizeReducer.
50+
*/
51+
52+
/**
53+
* Create a deep copy of this operator node as a TileOperator.
54+
*
55+
* @return A TileOperator handle that is an independent clone of this node.
56+
*/
57+
58+
/**
59+
* Public wrapper for FinalizeReducerOpNode.
60+
*
61+
* Provides the reference semantics and construction API used by callers.
62+
*/
63+
64+
/**
65+
* Construct a FinalizeReducerOp from TL-level arguments.
66+
*
67+
* @param args Positional primitive expressions that parameterize the operator
68+
* (e.g., shapes, axis indices). Their meaning is documented at call sites
69+
* where it is not obvious from the name or type.
70+
* @param vmap Mapping from operand names to tir::Buffer instances used by this operator.
71+
*/
72+
73+
/**
74+
* Get the Op singleton for the public FinalizeReducerOp handle.
75+
*
76+
* @return A reference to the Op describing FinalizeReducer.
77+
*/
1578
namespace tvm {
1679
namespace tl {
1780

src/op/parallel.cc

Lines changed: 22 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
119119
Map<Buffer, Layout> layout_map_;
120120
};
121121

122+
/**
123+
* @brief Handle a parallel For node during traversal, collecting loop metadata.
124+
*
125+
* Visits a parallel loop, asserts the loop is parallel, records a data-parallel
126+
* IterVar for the loop, binds the loop variable range into the analyzer scope,
127+
* and extracts any reducer information from the loop's annotations into the
128+
* visitor's reducer_info_map_. Continues traversal into the loop body.
129+
*/
122130
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
123131
ICHECK(op->kind == ForKind::kParallel);
124132
p->loop_vars_.push_back(
@@ -147,19 +155,6 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
147155
StmtExprVisitor::VisitStmt_(op);
148156
}
149157

150-
/**
151-
* @brief Visit a BufferLoad node and record/validate index mapping for
152-
* fragment-local buffers.
153-
*
154-
* If the loaded buffer's scope is "local.fragment", this records the load
155-
* indices in the visitor's indice_map_ when seen for the first time. If an
156-
* entry already exists, the previously recorded indices are asserted
157-
* structurally equal to the current indices.
158-
*
159-
* This ensures all accesses to the same fragment-local buffer within the
160-
* parallel loop use a consistent index map. The function then continues
161-
* standard expression visitation.
162-
*/
163158
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
164159
if (op->buffer.scope() == "local.fragment") {
165160
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
@@ -173,91 +168,42 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
173168
StmtExprVisitor::VisitExpr_(op);
174169
}
175170

176-
/**
177-
* @brief Construct a ParallelOpNode from a parallel loop nest root.
178-
*
179-
* Initializes the node with the given For loop as the root of the parallel
180-
* operator and immediately runs the internal ParallelLoopNestVisitor to collect
181-
* loop and buffer access information from the nested body.
182-
*
183-
* @param root The root For node representing the parallel loop nest to be
184-
* analyzed.
185-
*/
186171
ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
187172
V.VisitStmt(root);
188173
}
189174

190-
/**
191-
* @brief Create a copy of this ParallelOpNode wrapped as a TileOperator.
192-
*
193-
* Returns a new TileOperator that holds a deep copy of this ParallelOpNode.
194-
*
195-
* @return TileOperator A TileOperator owning a copy of this node.
196-
*/
197175
TileOperator ParallelOpNode::Clone() const {
198176
auto op = make_object<ParallelOpNode>(*this);
199177
return ParallelOp(op);
200178
}
201179

202-
/**
203-
* @brief No-op lowering: return the stored root statement unchanged.
204-
*
205-
* This implementation does not perform any transformation and returns the
206-
* operator's original root For statement as-is.
207-
*
208-
* @param T Lowering arguments (unused).
209-
* @return Stmt The original root statement held by this ParallelOpNode.
210-
*/
211180
Stmt ParallelOpNode::Lower(const LowerArgs &T,
212181
arith::Analyzer *analyzer) const {
213182
return root_;
214183
}
215184

216-
/**
217-
* @brief Check whether a buffer is indexed by the loop's canonical (common)
218-
* iteration variables.
219-
*
220-
* Returns true if the recorded index mapping for `buffer` is structurally equal
221-
* to the sequence of loop iteration variables for this parallel op (i.e., the
222-
* buffer is accessed using the common access indices of the loop nest).
223-
*
224-
* @param buffer The buffer to check.
225-
* @return true if the buffer's index map equals the loop's iteration variables;
226-
* false otherwise.
227-
*/
228185
bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const {
229186
auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
230187
return StructuralEqual()(indice_map_[buffer], common_indice);
231188
}
232189

233-
/**
234-
* @brief Infer buffer layouts for a Parallel operator based on the chosen
235-
* inference level.
190+
/*! \brief Infer the layout for parallel operations based on different inference
191+
* levels
236192
*
237-
* Attempts to compute a consistent LayoutMap for buffers accessed by a parallel
238-
* loop (root_) using explicit input layouts (T.layout_map), thread bounds
239-
* (T.thread_bounds), and optional buffer remapping/vectorization information in
240-
* T. Behavior depends on the supplied InferLevel:
241-
* - kStrict: only accept pre-existing loop_layout_ (no inference).
242-
* - kCommon: allow inference from explicit buffer fragments when available.
243-
* - kFree: attempt more aggressive inference (derive loop partition from
244-
* read/write fragments, plan partitioning from vectorization/thread bounds, and
245-
* add predicates to constrain replication when necessary).
193+
* The inference level controls how aggressively we try to infer and optimize
194+
* layouts:
195+
* - kStrict (2): Most conservative level. Only allows explicitly defined
196+
* layouts. Returns an empty layout map if loop_layout_ is not already defined.
197+
* Used when exact layout control is required.
246198
*
247-
* This method may mutate the node's internal state (sets loop_layout_ when
248-
* inferred and registers predicates via AddPredicate) and consults analyzer_
249-
* for symbolic proofs.
199+
* - kCommon (1): Intermediate level between strict and free.
200+
* Allows common layout patterns while maintaining some
201+
* constraints.
250202
*
251-
* @param T Container of auxiliary inputs used for inference (buffer_remap,
252-
* layout_map, and thread_bounds). The function uses T.layout_map for source
253-
* fragments and T.thread_bounds to bind thread-range information in inferred
254-
* fragments.
255-
* @param level Controls inference aggressiveness (kStrict, kCommon, kFree).
256-
* @return LayoutMap A map of buffers to inferred Fragment layouts for buffers
257-
* that did not already have layouts in T.layout_map. Returns an empty map when
258-
* no inference was performed.
259-
* @throws LayoutConflictException If a computed loop partition conflicts with
260-
* an existing buffer fragment (incompatible thread mappings).
203+
* - kFree (0): Most permissive level. Allows maximum optimization freedom.
204+
* Will attempt layout inference even without source buffers.
205+
* Can generate new layouts based on vectorization and thread
206+
* bounds. Used when maximum performance optimization is desired.
261207
*/
262208
LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
263209
InferLevel level) const {
@@ -446,20 +392,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
446392
return results;
447393
}
448394

449-
/**
450-
* @brief Retrieve the loop's thread predicate with the thread variable
451-
* substituted.
452-
*
453-
* If a predicate is set for this ParallelOpNode, returns a copy of that
454-
* predicate where the placeholder input (InputPlaceholder(0)) is replaced by
455-
* the provided thread_var. If no predicate is defined, returns an empty
456-
* Optional.
457-
*
458-
* @param thread_var The thread loop variable to substitute for the predicate's
459-
* input placeholder.
460-
* @return Optional<PrimExpr> The substituted predicate expression, or
461-
* std::nullopt if none is defined.
462-
*/
463395
Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
464396
if (predicate_.defined()) {
465397
return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
@@ -468,32 +400,6 @@ Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
468400
}
469401
}
470402

471-
/**
472-
* @brief Construct the complete fragment layout for a buffer within the
473-
* parallel loop.
474-
*
475-
* Given a buffer referenced inside the parallel loop, return a Fragment that
476-
* maps the buffer's logical indices to the loop's thread space and replication
477-
* extent.
478-
*
479-
* Detailed behavior:
480-
* - Precondition: a loop layout (loop_layout_) must be defined.
481-
* - If the buffer uses the common access indices of the loop, the loop's
482-
* fragment is returned directly.
483-
* - Otherwise, the function:
484-
* - Computes the buffer's bijective index by appending the flattened
485-
* replication expression for unused iterators.
486-
* - Inverts that bijection to obtain the replication extent of the buffer's
487-
* index space and combines it with the loop's replication extent to produce the
488-
* destination replication extent.
489-
* - Builds forward index placeholders for the buffer elements and maps them
490-
* through the inverted layout and the loop layout to derive the thread binding.
491-
* - Returns a Fragment with the computed thread binding and combined
492-
* replication extent, with replicate variables condensed.
493-
*
494-
* @return Fragment The completed fragment describing thread binding and
495-
* replication extent for `buffer`.
496-
*/
497403
Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
498404
ICHECK(loop_layout_.defined());
499405
if (IsCommonAccessIndice(buffer)) {

0 commit comments

Comments
 (0)