
Commit d7164ab

[Language][Reshape] Improve variable handling and ensure correctness during Layout Reshape (#1248)

* fix

* Refactor tensor reshaping in fp8_lighting_indexer.py
  - Replaced the separate allocation of `s_reshaped` with a `T.reshape` of `s`, so the scores live in one fragment viewed under two shapes.
  - Updated the score computation to read through the reshaped view instead of re-indexing the flat buffer.

* Refactor analyzer usage in Layout and Fragment reshaping
  - Consolidated the analyzer logic in the `Reshape` methods of `LayoutNode` and `FragmentNode`: a caller-supplied analyzer is used when present, with a local fallback otherwise, removing the duplicated check and any null-dereference path.
  - Bound the fresh index variables and routed simplification through the selected analyzer consistently, making shape validation and index computation more robust.
1 parent: c139855

File tree: 2 files changed (+41, -27 lines)

examples/deepseek_v32/fp8_lighting_indexer.py (2 additions, 2 deletions)

@@ -127,7 +127,7 @@ def mqa_attn_return_logits_kernel(
             index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
             index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
             s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
-            s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
+            s_reshaped = T.reshape(s, (block_N, block_Q, heads))
             logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
             weights = T.alloc_fragment([block_Q, heads], accum_dtype)

@@ -165,7 +165,7 @@ def mqa_attn_return_logits_kernel(

             for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
                 s_reshaped[bn_i, bq_i,
-                           h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
+                           h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) *
                            weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]

             T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
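The kernel change relies on `T.reshape` reinterpreting the existing fragment rather than allocating a second buffer, so `s_reshaped[bn, bq, h]` and `s[bn, bq * heads + h]` address the same element. A minimal NumPy sketch of that row-major view equivalence (illustration only, not TileLang; the shape values are arbitrary):

    import numpy as np

    # Row-major view equivalence assumed by the new T.reshape call:
    # a [block_N, block_Q * heads] buffer viewed as [block_N, block_Q, heads]
    # satisfies view[bn, bq, h] == s[bn, bq * heads + h].
    block_N, block_Q, heads = 4, 2, 8
    s = np.arange(block_N * block_Q * heads, dtype=np.float32)
    s = s.reshape(block_N, block_Q * heads)
    view = s.reshape(block_N, block_Q, heads)  # a view; no new allocation

    bn, bq, h = 2, 1, 3
    assert view[bn, bq, h] == s[bn, bq * heads + h]
    view[bn, bq, h] = -1.0                     # writes through to the flat buffer
    assert s[bn, bq * heads + h] == -1.0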

src/layout/layout.cc (39 additions, 25 deletions)

@@ -313,20 +313,21 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
     shape_product *= dim;
   }

-  if (analyzer) {
-    ICHECK(analyzer->CanProveEqual(input_shape_product, shape_product))
-        << "InputShape() = " << InputShape() << " shape = " << shape;
-  } else {
-    arith::Analyzer local_analyzer;
-    ICHECK(local_analyzer.CanProveEqual(input_shape_product, shape_product))
-        << "InputShape() = " << InputShape() << " shape = " << shape;
-  }
+  // Use provided analyzer if present, otherwise a local fallback to avoid
+  // potential null dereference paths flagged by static analysis.
+  arith::Analyzer fallback_analyzer;
+  arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
+  ICHECK(az->CanProveEqual(input_shape_product, shape_product))
+      << "InputShape() = " << InputShape() << " shape = " << shape;

   // Step 2. Create new forward indices by reshaping
   // For each dimension in the new shape, we create a placeholder variable
   Array<Var> new_vars;
+  new_vars.reserve(shape.size());
   for (size_t i = 0; i < shape.size(); ++i) {
-    new_vars.push_back(InputPlaceholder(i));
+    auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype());
+    az->Bind(var, Range(0, shape[i]));
+    new_vars.push_back(var);
   }
   // Step 3. Compute the flat index from new shape indices
   // flat_index = k0 * (s1 * s2 * ...) + k1 * (s2 * s3 * ...) + ... + kn
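The equal-volume check guarded by the fallback can be sketched against TVM's Python arith.Analyzer to show what the pattern buys: one code path, no null dereference. This is an illustrative transcription under assumptions, not the commit's C++; prove_same_volume is a hypothetical helper:

    from tvm import arith, tir

    # Hypothetical mirror of the hunk above: prefer a caller-supplied analyzer,
    # otherwise fall back to a local instance, then run one shared check.
    def prove_same_volume(input_shape, new_shape, analyzer=None):
        az = analyzer if analyzer is not None else arith.Analyzer()
        prod_in = tir.const(1, "int64")
        for d in input_shape:
            prod_in = prod_in * d
        prod_new = tir.const(1, "int64")
        for d in new_shape:
            prod_new = prod_new * d
        return az.can_prove(prod_in == prod_new)

    n = tir.Var("n", "int64")
    print(prove_same_volume([n, 4], [n * 2, 2]))  # True: n*4 == (n*2)*2
    print(prove_same_volume([n, 4], [n, 3]))      # False: 4n != 3n in general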
@@ -362,7 +363,11 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
       substituted =
           Substitute(substituted, {{InputPlaceholder(i), original_indices[i]}});
     }
-    new_forward_index.push_back(substituted);
+    new_forward_index.push_back(az->Simplify(substituted));
+  }
+  for (size_t i = 0; i < new_vars.size(); ++i) {
+    new_forward_index =
+        Substitute(new_forward_index, {{new_vars[i], InputPlaceholder(i)}});
   }
   return Layout(shape, new_forward_index);
 }
@@ -382,21 +387,25 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
   for (const auto &d : shape)
     shape_prod *= d;

-  if (analyzer) {
-    ICHECK(analyzer->CanProveEqual(input_prod, shape_prod))
-        << "InputShape() = " << InputShape() << " shape = " << shape
-        << " input fragment layout is = " << DebugOutput();
-  } else {
-    arith::Analyzer local_analyzer;
-    ICHECK(local_analyzer.CanProveEqual(input_prod, shape_prod))
-        << "InputShape() = " << InputShape() << " shape = " << shape;
-  }
+  // Use provided analyzer if present, otherwise a local fallback.
+  arith::Analyzer fallback_analyzer;
+  arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
+  ICHECK(az->CanProveEqual(input_prod, shape_prod))
+      << "InputShape() = " << InputShape() << " shape = " << shape
+      << " input fragment layout is = " << DebugOutput();

   // 2) Build flat index from new-shape indices
   Array<Var> new_vars;
   new_vars.reserve(shape.size());
-  for (size_t i = 0; i < shape.size(); ++i)
-    new_vars.push_back(InputPlaceholder(i));
+  for (size_t i = 0; i < shape.size(); ++i) {
+    // Cannot use InputPlaceholder(i) here, because it would cause name capture
+    // (variable capture) with InputPlaceholder(i) in upper scopes. Therefore,
+    // we must create a fresh variable here to avoid confusion when
+    // substituting.
+    auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype());
+    az->Bind(var, Range(0, shape[i]));
+    new_vars.push_back(var);
+  }

   PrimExpr flat = Integer(0);
   for (size_t i = 0; i < shape.size(); ++i) {
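The name-capture hazard that comment describes is easy to reproduce with TVM's Python bindings, since these loops substitute one placeholder at a time. A toy sketch, with hypothetical vars i0/i1 standing in for InputPlaceholder(0)/InputPlaceholder(1):

    from tvm import tir
    from tvm.tir.stmt_functor import substitute

    i0 = tir.Var("i0", "int32")
    i1 = tir.Var("i1", "int32")
    e = i0 + 2 * i1

    # Sequential substitution with reused variables: the second pass rewrites
    # the first pass's output, collapsing the intended swap.
    wrong = substitute(substitute(e, {i0: i1}), {i1: i0})  # i0 + 2*i0

    # Fresh intermediates (the commit's n_0, n_1 vars) block the capture.
    n0, n1 = tir.Var("n_0", "int32"), tir.Var("n_1", "int32")
    tmp = substitute(substitute(e, {i0: n0}), {i1: n1})    # n_0 + 2*n_1
    ok = substitute(substitute(tmp, {n0: i1}), {n1: i0})   # i1 + 2*i0
    print(wrong, ok)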
@@ -405,7 +414,6 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
       stride = stride * shape[j];
     flat = flat + new_vars[i] * stride;
   }
-
   // 3) Recover original indices from flat index
   Array<PrimExpr> orig_indices;
   PrimExpr remain = flat;
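Steps 2 and 3 are ordinary row-major flatten/unflatten arithmetic; a plain-Python sketch using the same stride convention as the loops above:

    # Row-major flatten: flat = k0*(s1*s2*...) + k1*(s2*...) + ... + kn,
    # then recovery of per-dimension indices via floordiv/floormod.
    def flatten(idx, shape):
        flat = 0
        for i in range(len(shape)):
            stride = 1
            for j in range(i + 1, len(shape)):
                stride *= shape[j]
            flat += idx[i] * stride
        return flat

    def unflatten(flat, shape):
        idx, remain = [], flat
        for i in range(len(shape)):
            stride = 1
            for j in range(i + 1, len(shape)):
                stride *= shape[j]
            idx.append(remain // stride)  # floordiv(remain, stride)
            remain = remain % stride      # floormod(remain, stride)
        return idx

    # Round trip; the same arithmetic links s[bn, bq*heads + h]
    # to s_reshaped[bn, bq, h] in the kernel change above.
    assert flatten([2, 1, 3], [4, 2, 8]) == 2 * 16 + 1 * 8 + 3
    assert unflatten(2 * 16 + 1 * 8 + 3, [4, 2, 8]) == [2, 1, 3]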
@@ -416,23 +424,29 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
     orig_indices.push_back(floordiv(remain, stride));
     remain = floormod(remain, stride);
   }
-
   // 4) Substitute old placeholders with expressions of new indices
   Array<PrimExpr> new_forward_index;
   for (const auto &e : forward_index_) {
     PrimExpr cur = e;
     for (size_t i = 0; i < InputShape().size(); ++i) {
       cur = Substitute(cur, {{InputPlaceholder(i), orig_indices[i]}});
     }
+    cur = az->Simplify(cur);
     new_forward_index.push_back(cur);
   }
-
   PrimExpr new_forward_thread = forward_thread_;
   for (size_t i = 0; i < InputShape().size(); ++i) {
     new_forward_thread = Substitute(new_forward_thread,
                                     {{InputPlaceholder(i), orig_indices[i]}});
   }
-
+  new_forward_thread = az->Simplify(new_forward_thread);
+  for (size_t i = 0; i < new_vars.size(); ++i) {
+    auto var = new_vars[i];
+    new_forward_index =
+        Substitute(new_forward_index, {{var, InputPlaceholder(i)}});
+    new_forward_thread =
+        Substitute(new_forward_thread, {{var, InputPlaceholder(i)}});
+  }
   Fragment reshaped(shape, new_forward_index, new_forward_thread,
                     ReplicateExtent(), std::nullopt);
   if (thread_range_.defined()) {
