@@ -119,6 +119,14 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
119119 Map<Buffer, Layout> layout_map_;
120120};
121121
122+ /**
123+ * @brief Handle a parallel For node during traversal, collecting loop metadata.
124+ *
125+ * Visits a parallel loop, asserts the loop is parallel, records a data-parallel
126+ * IterVar for the loop, binds the loop variable range into the analyzer scope,
127+ * and extracts any reducer information from the loop's annotations into the
128+ * visitor's reducer_info_map_. Continues traversal into the loop body.
129+ */
122130void ParallelLoopNestVisitor::VisitStmt_ (const ForNode *op) {
123131 ICHECK (op->kind == ForKind::kParallel );
124132 p->loop_vars_ .push_back (
@@ -147,19 +155,6 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
147155 StmtExprVisitor::VisitStmt_ (op);
148156}
149157
150- /* *
151- * @brief Visit a BufferLoad node and record/validate index mapping for
152- * fragment-local buffers.
153- *
154- * If the loaded buffer's scope is "local.fragment", this records the load
155- * indices in the visitor's indice_map_ when seen for the first time. If an
156- * entry already exists, the previously recorded indices are asserted
157- * structurally equal to the current indices.
158- *
159- * This ensures all accesses to the same fragment-local buffer within the
160- * parallel loop use a consistent index map. The function then continues
161- * standard expression visitation.
162- */
163158void ParallelLoopNestVisitor::VisitExpr_ (const BufferLoadNode *op) {
164159 if (op->buffer .scope () == " local.fragment" ) {
165160 if (p->indice_map_ .find (op->buffer ) != p->indice_map_ .end ()) {
@@ -173,91 +168,42 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
173168 StmtExprVisitor::VisitExpr_ (op);
174169}
175170
176- /* *
177- * @brief Construct a ParallelOpNode from a parallel loop nest root.
178- *
179- * Initializes the node with the given For loop as the root of the parallel
180- * operator and immediately runs the internal ParallelLoopNestVisitor to collect
181- * loop and buffer access information from the nested body.
182- *
183- * @param root The root For node representing the parallel loop nest to be
184- * analyzed.
185- */
/**
 * @brief Build a ParallelOpNode from the root loop of a loop nest.
 *
 * Stores `root` in root_, constructs the member visitor V with a pointer back
 * to this node, and then immediately runs V over `root`.
 *
 * @param root The For statement stored as root_ and handed to V.VisitStmt.
 */
186171ParallelOpNode::ParallelOpNode (For root) : root_(root), V(this ) {
// Traversal happens at construction time, so whatever V records is
// available as soon as the constructor returns.
187172 V.VisitStmt (root);
188173}
189174
190- /* *
191- * @brief Create a copy of this ParallelOpNode wrapped as a TileOperator.
192- *
193- * Returns a new TileOperator that holds a deep copy of this ParallelOpNode.
194- *
195- * @return TileOperator A TileOperator owning a copy of this node.
196- */
/**
 * @brief Produce a TileOperator that wraps a copy of this node.
 *
 * A new ParallelOpNode is copy-constructed from `*this` via make_object and
 * wrapped in a ParallelOp, which is returned as the TileOperator.
 *
 * @return A ParallelOp referring to the newly created copy.
 */
197175TileOperator ParallelOpNode::Clone () const {
198176 auto op = make_object<ParallelOpNode>(*this );
199177 return ParallelOp (op);
200178}
201179
202- /* *
203- * @brief No-op lowering: return the stored root statement unchanged.
204- *
205- * This implementation does not perform any transformation and returns the
206- * operator's original root For statement as-is.
207- *
208- * @param T Lowering arguments (unused).
209- * @return Stmt The original root statement held by this ParallelOpNode.
210- */
/**
 * @brief Lowering hook that returns the stored root statement unchanged.
 *
 * @param T Lowering arguments; not read by this implementation.
 * @param analyzer Arithmetic analyzer; not read by this implementation.
 * @return root_, exactly as stored at construction.
 */
211180Stmt ParallelOpNode::Lower (const LowerArgs &T,
212181 arith::Analyzer *analyzer) const {
213182 return root_;
214183}
215184
216- /* *
217- * @brief Check whether a buffer is indexed by the loop's canonical (common)
218- * iteration variables.
219- *
220- * Returns true if the recorded index mapping for `buffer` is structurally equal
221- * to the sequence of loop iteration variables for this parallel op (i.e., the
222- * buffer is accessed using the common access indices of the loop nest).
223- *
224- * @param buffer The buffer to check.
225- * @return true if the buffer's index map equals the loop's iteration variables;
226- * false otherwise.
227- */
/**
 * @brief Test whether `buffer` is indexed exactly by the loop variables.
 *
 * Builds the list of `var` fields taken from every entry of loop_vars_ and
 * compares it, using StructuralEqual, against the indices recorded for
 * `buffer` in indice_map_.
 *
 * @param buffer Buffer whose recorded indices are looked up in indice_map_.
 * @return true when the recorded indices are structurally equal to the
 *         loop-variable list, false otherwise.
 *
 * NOTE(review): the lookup is `indice_map_[buffer]`; what happens when
 * `buffer` has no entry is not visible here, so confirm callers only pass
 * buffers that were recorded.
 */
228185bool ParallelOpNode::IsCommonAccessIndice (const Buffer &buffer) const {
// Project each loop IterVar onto its underlying variable.
229186 auto common_indice = loop_vars_.Map ([](const auto &iv) { return iv->var ; });
230187 return StructuralEqual ()(indice_map_[buffer], common_indice);
231188}
232189
233- /* *
234- * @brief Infer buffer layouts for a Parallel operator based on the chosen
235- * inference level.
190+ /*! \brief Infer the layout for parallel operations based on different inference
191+ * levels
236192 *
237- * Attempts to compute a consistent LayoutMap for buffers accessed by a parallel
238- * loop (root_) using explicit input layouts (T.layout_map), thread bounds
239- * (T.thread_bounds), and optional buffer remapping/vectorization information in
240- * T. Behavior depends on the supplied InferLevel:
241- * - kStrict: only accept pre-existing loop_layout_ (no inference).
242- * - kCommon: allow inference from explicit buffer fragments when available.
243- * - kFree: attempt more aggressive inference (derive loop partition from
244- * read/write fragments, plan partitioning from vectorization/thread bounds, and
245- * add predicates to constrain replication when necessary).
193+ * The inference level controls how aggressively we try to infer and optimize
194+ * layouts:
195+ * - kStrict (2): Most conservative level. Only allows explicitly defined
196+ * layouts. Returns empty layout map if loop_layout_ is not already defined.
197+ * Used when exact layout control is required.
246198 *
247- * This method may mutate the node's internal state (sets loop_layout_ when
248- * inferred and registers predicates via AddPredicate) and consults analyzer_
249- * for symbolic proofs .
199+ * - kCommon (1): Intermediate level between strict and free.
200+ * Allows common layout patterns while maintaining some
201+ * constraints.
250202 *
251- * @param T Container of auxiliary inputs used for inference (buffer_remap,
252- * layout_map, and thread_bounds). The function uses T.layout_map for source
253- * fragments and T.thread_bounds to bind thread-range information in inferred
254- * fragments.
255- * @param level Controls inference aggressiveness (kStrict, kCommon, kFree).
256- * @return LayoutMap A map of buffers to inferred Fragment layouts for buffers
257- * that did not already have layouts in T.layout_map. Returns an empty map when
258- * no inference was performed.
259- * @throws LayoutConflictException If a computed loop partition conflicts with
260- * an existing buffer fragment (incompatible thread mappings).
203+ * - kFree (0): Most permissive level. Allows maximum optimization freedom.
204+ * Will attempt layout inference even without source buffers.
205+ * Can generate new layouts based on vectorization and thread
206+ * bounds. Used when maximum performance optimization is desired.
261207 */
262208LayoutMap ParallelOpNode::InferLayout (const LayoutInferArgs &T,
263209 InferLevel level) const {
@@ -446,20 +392,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
446392 return results;
447393}
448394
449- /* *
450- * @brief Retrieve the loop's thread predicate with the thread variable
451- * substituted.
452- *
453- * If a predicate is set for this ParallelOpNode, returns a copy of that
454- * predicate where the placeholder input (InputPlaceholder(0)) is replaced by
455- * the provided thread_var. If no predicate is defined, returns an empty
456- * Optional.
457- *
458- * @param thread_var The thread loop variable to substitute for the predicate's
459- * input placeholder.
460- * @return Optional<PrimExpr> The substituted predicate expression, or
461- * std::nullopt if none is defined.
462- */
463395Optional<PrimExpr> ParallelOpNode::GetPredicate (Var thread_var) const {
464396 if (predicate_.defined ()) {
465397 return Substitute (predicate_.value (), {{InputPlaceholder (0 ), thread_var}});
@@ -468,32 +400,6 @@ Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
468400 }
469401}
470402
471- /* *
472- * @brief Construct the complete fragment layout for a buffer within the
473- * parallel loop.
474- *
475- * Given a buffer referenced inside the parallel loop, return a Fragment that
476- * maps the buffer's logical indices to the loop's thread space and replication
477- * extent.
478- *
479- * Detailed behavior:
480- * - Precondition: a loop layout (loop_layout_) must be defined.
481- * - If the buffer uses the common access indices of the loop, the loop's
482- * fragment is returned directly.
483- * - Otherwise, the function:
484- * - Computes the buffer's bijective index by appending the flattened
485- * replication expression for unused iterators.
486- * - Inverts that bijection to obtain the replication extent of the buffer's
487- * index space and combines it with the loop's replication extent to produce the
488- * destination replication extent.
489- * - Builds forward index placeholders for the buffer elements and maps them
490- * through the inverted layout and the loop layout to derive the thread binding.
491- * - Returns a Fragment with the computed thread binding and combined
492- * replication extent, with replicate variables condensed.
493- *
494- * @return Fragment The completed fragment describing thread binding and
495- * replication extent for `buffer`.
496- */
497403Fragment ParallelOpNode::CompleteBufferFragment (const Buffer &buffer) const {
498404 ICHECK (loop_layout_.defined ());
499405 if (IsCommonAccessIndice (buffer)) {
0 commit comments