@@ -313,20 +313,21 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
313313 shape_product *= dim;
314314 }
315315
316- if (analyzer) {
317- ICHECK (analyzer->CanProveEqual (input_shape_product, shape_product))
318- << " InputShape() = " << InputShape () << " shape = " << shape;
319- } else {
320- arith::Analyzer local_analyzer;
321- ICHECK (local_analyzer.CanProveEqual (input_shape_product, shape_product))
322- << " InputShape() = " << InputShape () << " shape = " << shape;
323- }
316+ // Use provided analyzer if present, otherwise a local fallback to avoid
317+ // potential null dereference paths flagged by static analysis.
318+ arith::Analyzer fallback_analyzer;
319+ arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
320+ ICHECK (az->CanProveEqual (input_shape_product, shape_product))
321+ << " InputShape() = " << InputShape () << " shape = " << shape;
324322
325323 // Step 2. Create new forward indices by reshaping
326324 // For each dimension in the new shape, we create a placeholder variable
327325 Array<Var> new_vars;
326+ new_vars.reserve (shape.size ());
328327 for (size_t i = 0 ; i < shape.size (); ++i) {
329- new_vars.push_back (InputPlaceholder (i));
328+ auto var = Var (std::string (" n_" ) + std::to_string (i), shape[i].dtype ());
329+ az->Bind (var, Range (0 , shape[i]));
330+ new_vars.push_back (var);
330331 }
331332 // Step 3. Compute the flat index from new shape indices
332333 // flat_index = k0 * (s1 * s2 * ...) + k1 * (s2 * s3 * ...) + ... + kn
@@ -362,7 +363,11 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
362363 substituted =
363364 Substitute (substituted, {{InputPlaceholder (i), original_indices[i]}});
364365 }
365- new_forward_index.push_back (substituted);
366+ new_forward_index.push_back (az->Simplify (substituted));
367+ }
368+ for (size_t i = 0 ; i < new_vars.size (); ++i) {
369+ new_forward_index =
370+ Substitute (new_forward_index, {{new_vars[i], InputPlaceholder (i)}});
366371 }
367372 return Layout (shape, new_forward_index);
368373}
@@ -382,21 +387,25 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
382387 for (const auto &d : shape)
383388 shape_prod *= d;
384389
385- if (analyzer) {
386- ICHECK (analyzer->CanProveEqual (input_prod, shape_prod))
387- << " InputShape() = " << InputShape () << " shape = " << shape
388- << " input fragment layout is = " << DebugOutput ();
389- } else {
390- arith::Analyzer local_analyzer;
391- ICHECK (local_analyzer.CanProveEqual (input_prod, shape_prod))
392- << " InputShape() = " << InputShape () << " shape = " << shape;
393- }
390+ // Use provided analyzer if present, otherwise a local fallback.
391+ arith::Analyzer fallback_analyzer;
392+ arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
393+ ICHECK (az->CanProveEqual (input_prod, shape_prod))
394+ << " InputShape() = " << InputShape () << " shape = " << shape
395+ << " input fragment layout is = " << DebugOutput ();
394396
395397 // 2) Build flat index from new-shape indices
396398 Array<Var> new_vars;
397399 new_vars.reserve (shape.size ());
398- for (size_t i = 0 ; i < shape.size (); ++i)
399- new_vars.push_back (InputPlaceholder (i));
400+ for (size_t i = 0 ; i < shape.size (); ++i) {
401+ // Cannot use InputPlaceholder(i) here, because it would cause name capture
402+ // (variable capture) with InputPlaceholder(i) in upper scopes. Therefore,
403+ // we must create a fresh variable here to avoid confusion when
404+ // substituting.
405+ auto var = Var (std::string (" n_" ) + std::to_string (i), shape[i].dtype ());
406+ az->Bind (var, Range (0 , shape[i]));
407+ new_vars.push_back (var);
408+ }
400409
401410 PrimExpr flat = Integer (0 );
402411 for (size_t i = 0 ; i < shape.size (); ++i) {
@@ -405,7 +414,6 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
405414 stride = stride * shape[j];
406415 flat = flat + new_vars[i] * stride;
407416 }
408-
409417 // 3) Recover original indices from flat index
410418 Array<PrimExpr> orig_indices;
411419 PrimExpr remain = flat;
@@ -416,23 +424,29 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
416424 orig_indices.push_back (floordiv (remain, stride));
417425 remain = floormod (remain, stride);
418426 }
419-
420427 // 4) Substitute old placeholders with expressions of new indices
421428 Array<PrimExpr> new_forward_index;
422429 for (const auto &e : forward_index_) {
423430 PrimExpr cur = e;
424431 for (size_t i = 0 ; i < InputShape ().size (); ++i) {
425432 cur = Substitute (cur, {{InputPlaceholder (i), orig_indices[i]}});
426433 }
434+ cur = az->Simplify (cur);
427435 new_forward_index.push_back (cur);
428436 }
429-
430437 PrimExpr new_forward_thread = forward_thread_;
431438 for (size_t i = 0 ; i < InputShape ().size (); ++i) {
432439 new_forward_thread = Substitute (new_forward_thread,
433440 {{InputPlaceholder (i), orig_indices[i]}});
434441 }
435-
442+ new_forward_thread = az->Simplify (new_forward_thread);
443+ for (size_t i = 0 ; i < new_vars.size (); ++i) {
444+ auto var = new_vars[i];
445+ new_forward_index =
446+ Substitute (new_forward_index, {{var, InputPlaceholder (i)}});
447+ new_forward_thread =
448+ Substitute (new_forward_thread, {{var, InputPlaceholder (i)}});
449+ }
436450 Fragment reshaped (shape, new_forward_index, new_forward_thread,
437451 ReplicateExtent (), std::nullopt );
438452 if (thread_range_.defined ()) {
0 commit comments