@@ -102,10 +102,24 @@ Array<PrimExpr> LayoutNode::OutputShape() const {
102102 for (size_t i = 0 ; i < ret.size (); i++) {
103103 auto ist = analyzer.int_set (forward_index_[i] + 1 );
104104 if (arith::is_neg_inf (ist.min ()) && arith::is_pos_inf (ist.max ())) {
105- // X-OR Expression
106- ret.Set (i, input_size_[i]);
105+ // Analyzer couldn't form an IntervalSet (e.g. bitwise ops).
106+ // Fall back to ConstIntBound to derive a safe extent.
107+ auto cib = analyzer.const_int_bound (forward_index_[i]);
108+ if (cib->min_value != arith::ConstIntBound::kNegInf &&
109+ cib->max_value != arith::ConstIntBound::kPosInf &&
110+ cib->min_value >= 0 ) {
111+ // extent = max - min + 1, using 64-bit integer literal
112+ ret.Set (i, Integer (cib->max_value - cib->min_value + 1 ));
113+ } else {
114+ // Last-resort conservative fallback to avoid OOB/crash
115+ // Prefer to keep dimension from known input_size_ if available.
116+ if (i < input_size_.size ()) {
117+ ret.Set (i, input_size_[i]);
118+ } else {
119+ ret.Set (i, Integer (1 ));
120+ }
121+ }
107122 } else {
108- // CHECK(is_one(ist.min())) << ist.min();
109123 ret.Set (i, ist.max ());
110124 }
111125 }
@@ -282,10 +296,156 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
282296 return {Layout (outputs_shape, backward_index), level};
283297}
284298
299+ Layout LayoutNode::Reshape (const Array<PrimExpr> &shape,
300+ arith::Analyzer *analyzer) const {
301+ // Fast path: if shape is the same, return the original layout
302+ if (StructuralEqual ()(InputShape (), shape)) {
303+ return ffi::GetRef<Layout>(this );
304+ }
305+
306+ // Step 1. Prove the product of InputShape is equal to the product of shape
307+ PrimExpr input_shape_product = Integer (1 );
308+ for (const auto &dim : InputShape ()) {
309+ input_shape_product *= dim;
310+ }
311+ PrimExpr shape_product = Integer (1 );
312+ for (const auto &dim : shape) {
313+ shape_product *= dim;
314+ }
315+
316+ if (analyzer) {
317+ ICHECK (analyzer->CanProveEqual (input_shape_product, shape_product))
318+ << " InputShape() = " << InputShape () << " shape = " << shape;
319+ } else {
320+ arith::Analyzer local_analyzer;
321+ ICHECK (local_analyzer.CanProveEqual (input_shape_product, shape_product))
322+ << " InputShape() = " << InputShape () << " shape = " << shape;
323+ }
324+
325+ // Step 2. Create new forward indices by reshaping
326+ // For each dimension in the new shape, we create a placeholder variable
327+ Array<Var> new_vars;
328+ for (size_t i = 0 ; i < shape.size (); ++i) {
329+ new_vars.push_back (InputPlaceholder (i));
330+ }
331+ // Step 3. Compute the flat index from new shape indices
332+ // flat_index = k0 * (s1 * s2 * ...) + k1 * (s2 * s3 * ...) + ... + kn
333+ PrimExpr flat_index = Integer (0 );
334+ for (size_t i = 0 ; i < shape.size (); ++i) {
335+ PrimExpr stride = Integer (1 );
336+ for (size_t j = i + 1 ; j < shape.size (); ++j) {
337+ stride = stride * shape[j];
338+ }
339+ flat_index = flat_index + new_vars[i] * stride;
340+ }
341+ // Step 4. Convert flat index back to original shape indices
342+ // For original shape [s0, s1, ..., sm]:
343+ // i0 = flat_index // (s1 * s2 * ... * sm)
344+ // i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm)
345+ // ...
346+ Array<PrimExpr> original_indices;
347+ PrimExpr remaining = flat_index;
348+ for (size_t i = 0 ; i < InputShape ().size (); ++i) {
349+ PrimExpr stride = Integer (1 );
350+ for (size_t j = i + 1 ; j < InputShape ().size (); ++j) {
351+ stride = stride * InputShape ()[j];
352+ }
353+ original_indices.push_back (floordiv (remaining, stride));
354+ remaining = floormod (remaining, stride);
355+ }
356+ // Step 5. Substitute original indices into forward_index_
357+ Array<PrimExpr> new_forward_index;
358+ for (const auto &fwd_expr : forward_index_) {
359+ PrimExpr substituted = fwd_expr;
360+ // Replace each InputPlaceholder(i) with original_indices[i]
361+ for (size_t i = 0 ; i < InputShape ().size (); ++i) {
362+ substituted =
363+ Substitute (substituted, {{InputPlaceholder (i), original_indices[i]}});
364+ }
365+ new_forward_index.push_back (substituted);
366+ }
367+ return Layout (shape, new_forward_index);
368+ }
369+
370+ Layout FragmentNode::Reshape (const Array<PrimExpr> &shape,
371+ arith::Analyzer *analyzer) const {
372+ // Fast path: identical input shape, return self
373+ if (StructuralEqual ()(InputShape (), shape)) {
374+ return ffi::GetRef<Fragment>(this );
375+ }
376+
377+ // 1) Prove total number of elements remains the same
378+ PrimExpr input_prod = Integer (1 );
379+ for (const auto &d : InputShape ())
380+ input_prod *= d;
381+ PrimExpr shape_prod = Integer (1 );
382+ for (const auto &d : shape)
383+ shape_prod *= d;
384+
385+ if (analyzer) {
386+ ICHECK (analyzer->CanProveEqual (input_prod, shape_prod))
387+ << " InputShape() = " << InputShape () << " shape = " << shape
388+ << " input fragment layout is = " << DebugOutput ();
389+ } else {
390+ arith::Analyzer local_analyzer;
391+ ICHECK (local_analyzer.CanProveEqual (input_prod, shape_prod))
392+ << " InputShape() = " << InputShape () << " shape = " << shape;
393+ }
394+
395+ // 2) Build flat index from new-shape indices
396+ Array<Var> new_vars;
397+ new_vars.reserve (shape.size ());
398+ for (size_t i = 0 ; i < shape.size (); ++i)
399+ new_vars.push_back (InputPlaceholder (i));
400+
401+ PrimExpr flat = Integer (0 );
402+ for (size_t i = 0 ; i < shape.size (); ++i) {
403+ PrimExpr stride = Integer (1 );
404+ for (size_t j = i + 1 ; j < shape.size (); ++j)
405+ stride = stride * shape[j];
406+ flat = flat + new_vars[i] * stride;
407+ }
408+
409+ // 3) Recover original indices from flat index
410+ Array<PrimExpr> orig_indices;
411+ PrimExpr remain = flat;
412+ for (size_t i = 0 ; i < InputShape ().size (); ++i) {
413+ PrimExpr stride = Integer (1 );
414+ for (size_t j = i + 1 ; j < InputShape ().size (); ++j)
415+ stride = stride * InputShape ()[j];
416+ orig_indices.push_back (floordiv (remain, stride));
417+ remain = floormod (remain, stride);
418+ }
419+
420+ // 4) Substitute old placeholders with expressions of new indices
421+ Array<PrimExpr> new_forward_index;
422+ for (const auto &e : forward_index_) {
423+ PrimExpr cur = e;
424+ for (size_t i = 0 ; i < InputShape ().size (); ++i) {
425+ cur = Substitute (cur, {{InputPlaceholder (i), orig_indices[i]}});
426+ }
427+ new_forward_index.push_back (cur);
428+ }
429+
430+ PrimExpr new_forward_thread = forward_thread_;
431+ for (size_t i = 0 ; i < InputShape ().size (); ++i) {
432+ new_forward_thread = Substitute (new_forward_thread,
433+ {{InputPlaceholder (i), orig_indices[i]}});
434+ }
435+
436+ Fragment reshaped (shape, new_forward_index, new_forward_thread,
437+ ReplicateExtent (), std::nullopt );
438+ if (thread_range_.defined ()) {
439+ reshaped = reshaped->BindThreadRange (thread_range_);
440+ }
441+ return reshaped;
442+ }
443+
285444Layout LayoutNode::Inverse () const {
286445 auto inverse_result = InverseWithLevel ();
287446 return std::move (inverse_result.first );
288447}
448+
289449PrimExpr infer_fragment_index (const Map<Var, Range> &input_iters,
290450 const PrimExpr &forward_thread,
291451 arith::Analyzer *analyzer) {
0 commit comments