Skip to content

Commit 4370309

Browse files
LeiWang1999 and tilelang-bot authored
[Enhancement] Support Layout/Fragment Reshape (#1241)
* Update layout handling and introduce reshape functionality - Updated the `LayoutNode` class to include a new `Reshape` method, allowing for dynamic reshaping of layouts based on input shapes. - Enhanced the `OutputShape` method to provide better handling of cases where the analyzer cannot form an `IntervalSet`, implementing fallback mechanisms to ensure safe extents. - Refactored the `ReduceOpNode` to utilize `BufferRegion` for improved memory handling during reduction operations. - Added tests for reshaping functionality and layout transformations to ensure correctness and performance in various scenarios. * lint fix * Revert tvm submodule pointer to 1815c3e0b6ec4ead36370bbd1562025d8529017c; keep src unchanged * Update tvm submodule to commit f0bbd3bf741413c35c389ba5dedd5be206000ad1 * Update tvm submodule to commit f0bbd3bf741413c35c389ba5dedd5be206000ad1 * remove useless prove * remove comment --------- Co-authored-by: tilelang-bot <bot@tilelang>
1 parent 02cfc2a commit 4370309

File tree

8 files changed

+611
-36
lines changed

8 files changed

+611
-36
lines changed

3rdparty/tvm

Submodule tvm updated from 1815c3e to 093b2cd

src/layout/layout.cc

Lines changed: 163 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,24 @@ Array<PrimExpr> LayoutNode::OutputShape() const {
102102
for (size_t i = 0; i < ret.size(); i++) {
103103
auto ist = analyzer.int_set(forward_index_[i] + 1);
104104
if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) {
105-
// X-OR Expression
106-
ret.Set(i, input_size_[i]);
105+
// Analyzer couldn't form an IntervalSet (e.g. bitwise ops).
106+
// Fall back to ConstIntBound to derive a safe extent.
107+
auto cib = analyzer.const_int_bound(forward_index_[i]);
108+
if (cib->min_value != arith::ConstIntBound::kNegInf &&
109+
cib->max_value != arith::ConstIntBound::kPosInf &&
110+
cib->min_value >= 0) {
111+
// extent = max - min + 1, using 64-bit integer literal
112+
ret.Set(i, Integer(cib->max_value - cib->min_value + 1));
113+
} else {
114+
// Last-resort conservative fallback to avoid OOB/crash
115+
// Prefer to keep dimension from known input_size_ if available.
116+
if (i < input_size_.size()) {
117+
ret.Set(i, input_size_[i]);
118+
} else {
119+
ret.Set(i, Integer(1));
120+
}
121+
}
107122
} else {
108-
// CHECK(is_one(ist.min())) << ist.min();
109123
ret.Set(i, ist.max());
110124
}
111125
}
@@ -282,10 +296,156 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
282296
return {Layout(outputs_shape, backward_index), level};
283297
}
284298

299+
Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
300+
arith::Analyzer *analyzer) const {
301+
// Fast path: if shape is the same, return the original layout
302+
if (StructuralEqual()(InputShape(), shape)) {
303+
return ffi::GetRef<Layout>(this);
304+
}
305+
306+
// Step 1. Prove the product of InputShape is equal to the product of shape
307+
PrimExpr input_shape_product = Integer(1);
308+
for (const auto &dim : InputShape()) {
309+
input_shape_product *= dim;
310+
}
311+
PrimExpr shape_product = Integer(1);
312+
for (const auto &dim : shape) {
313+
shape_product *= dim;
314+
}
315+
316+
if (analyzer) {
317+
ICHECK(analyzer->CanProveEqual(input_shape_product, shape_product))
318+
<< "InputShape() = " << InputShape() << " shape = " << shape;
319+
} else {
320+
arith::Analyzer local_analyzer;
321+
ICHECK(local_analyzer.CanProveEqual(input_shape_product, shape_product))
322+
<< "InputShape() = " << InputShape() << " shape = " << shape;
323+
}
324+
325+
// Step 2. Create new forward indices by reshaping
326+
// For each dimension in the new shape, we create a placeholder variable
327+
Array<Var> new_vars;
328+
for (size_t i = 0; i < shape.size(); ++i) {
329+
new_vars.push_back(InputPlaceholder(i));
330+
}
331+
// Step 3. Compute the flat index from new shape indices
332+
// flat_index = k0 * (s1 * s2 * ...) + k1 * (s2 * s3 * ...) + ... + kn
333+
PrimExpr flat_index = Integer(0);
334+
for (size_t i = 0; i < shape.size(); ++i) {
335+
PrimExpr stride = Integer(1);
336+
for (size_t j = i + 1; j < shape.size(); ++j) {
337+
stride = stride * shape[j];
338+
}
339+
flat_index = flat_index + new_vars[i] * stride;
340+
}
341+
// Step 4. Convert flat index back to original shape indices
342+
// For original shape [s0, s1, ..., sm]:
343+
// i0 = flat_index // (s1 * s2 * ... * sm)
344+
// i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm)
345+
// ...
346+
Array<PrimExpr> original_indices;
347+
PrimExpr remaining = flat_index;
348+
for (size_t i = 0; i < InputShape().size(); ++i) {
349+
PrimExpr stride = Integer(1);
350+
for (size_t j = i + 1; j < InputShape().size(); ++j) {
351+
stride = stride * InputShape()[j];
352+
}
353+
original_indices.push_back(floordiv(remaining, stride));
354+
remaining = floormod(remaining, stride);
355+
}
356+
// Step 5. Substitute original indices into forward_index_
357+
Array<PrimExpr> new_forward_index;
358+
for (const auto &fwd_expr : forward_index_) {
359+
PrimExpr substituted = fwd_expr;
360+
// Replace each InputPlaceholder(i) with original_indices[i]
361+
for (size_t i = 0; i < InputShape().size(); ++i) {
362+
substituted =
363+
Substitute(substituted, {{InputPlaceholder(i), original_indices[i]}});
364+
}
365+
new_forward_index.push_back(substituted);
366+
}
367+
return Layout(shape, new_forward_index);
368+
}
369+
370+
Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
371+
arith::Analyzer *analyzer) const {
372+
// Fast path: identical input shape, return self
373+
if (StructuralEqual()(InputShape(), shape)) {
374+
return ffi::GetRef<Fragment>(this);
375+
}
376+
377+
// 1) Prove total number of elements remains the same
378+
PrimExpr input_prod = Integer(1);
379+
for (const auto &d : InputShape())
380+
input_prod *= d;
381+
PrimExpr shape_prod = Integer(1);
382+
for (const auto &d : shape)
383+
shape_prod *= d;
384+
385+
if (analyzer) {
386+
ICHECK(analyzer->CanProveEqual(input_prod, shape_prod))
387+
<< "InputShape() = " << InputShape() << " shape = " << shape
388+
<< " input fragment layout is = " << DebugOutput();
389+
} else {
390+
arith::Analyzer local_analyzer;
391+
ICHECK(local_analyzer.CanProveEqual(input_prod, shape_prod))
392+
<< "InputShape() = " << InputShape() << " shape = " << shape;
393+
}
394+
395+
// 2) Build flat index from new-shape indices
396+
Array<Var> new_vars;
397+
new_vars.reserve(shape.size());
398+
for (size_t i = 0; i < shape.size(); ++i)
399+
new_vars.push_back(InputPlaceholder(i));
400+
401+
PrimExpr flat = Integer(0);
402+
for (size_t i = 0; i < shape.size(); ++i) {
403+
PrimExpr stride = Integer(1);
404+
for (size_t j = i + 1; j < shape.size(); ++j)
405+
stride = stride * shape[j];
406+
flat = flat + new_vars[i] * stride;
407+
}
408+
409+
// 3) Recover original indices from flat index
410+
Array<PrimExpr> orig_indices;
411+
PrimExpr remain = flat;
412+
for (size_t i = 0; i < InputShape().size(); ++i) {
413+
PrimExpr stride = Integer(1);
414+
for (size_t j = i + 1; j < InputShape().size(); ++j)
415+
stride = stride * InputShape()[j];
416+
orig_indices.push_back(floordiv(remain, stride));
417+
remain = floormod(remain, stride);
418+
}
419+
420+
// 4) Substitute old placeholders with expressions of new indices
421+
Array<PrimExpr> new_forward_index;
422+
for (const auto &e : forward_index_) {
423+
PrimExpr cur = e;
424+
for (size_t i = 0; i < InputShape().size(); ++i) {
425+
cur = Substitute(cur, {{InputPlaceholder(i), orig_indices[i]}});
426+
}
427+
new_forward_index.push_back(cur);
428+
}
429+
430+
PrimExpr new_forward_thread = forward_thread_;
431+
for (size_t i = 0; i < InputShape().size(); ++i) {
432+
new_forward_thread = Substitute(new_forward_thread,
433+
{{InputPlaceholder(i), orig_indices[i]}});
434+
}
435+
436+
Fragment reshaped(shape, new_forward_index, new_forward_thread,
437+
ReplicateExtent(), std::nullopt);
438+
if (thread_range_.defined()) {
439+
reshaped = reshaped->BindThreadRange(thread_range_);
440+
}
441+
return reshaped;
442+
}
443+
285444
Layout LayoutNode::Inverse() const {
286445
auto inverse_result = InverseWithLevel();
287446
return std::move(inverse_result.first);
288447
}
448+
289449
PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
290450
const PrimExpr &forward_thread,
291451
arith::Analyzer *analyzer) {

src/layout/layout.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ class LayoutNode : public Object {
4141
virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;
4242

4343
virtual Layout Inverse() const;
44+
45+
virtual Layout Reshape(const Array<PrimExpr> &shape,
46+
arith::Analyzer *analyzer) const;
47+
4448
virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;
4549

4650
virtual std::string DebugOutput() const;
@@ -81,6 +85,9 @@ class FragmentNode : public LayoutNode {
8185
Array<PrimExpr> GetForwardVars() const final;
8286

8387
Layout Inverse() const final;
88+
89+
Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer) const;
90+
8491
std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;
8592

8693
PrimExpr ThreadExtent() const;

src/op/reduce.cc

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,62 @@
1414
#include "../op/parallel.h"
1515
#include "../target/utils.h"
1616
#include "../transform/loop_partition.h"
17+
#include "region.h"
1718
#include "tir/transforms/ir_utils.h"
1819

1920
namespace tvm {
2021
namespace tl {
2122

2223
using namespace tir;
2324

25+
// Normalize an argument (BufferRegion/BufferLoad/tl.region)
26+
// to BufferRegion so Reduce can uniformly consume regions.
27+
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
28+
const BufferMap &vmap) {
29+
// Case 1: Already a BufferRegion
30+
if (arg->IsInstance<BufferRegionNode>()) {
31+
return Downcast<BufferRegion>(arg);
32+
}
33+
34+
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
35+
// extent=1)
36+
if (const auto *load = arg.as<BufferLoadNode>()) {
37+
Array<Range> ranges;
38+
for (const PrimExpr &index : load->indices) {
39+
if (const auto *ramp = index.as<RampNode>()) {
40+
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
41+
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
42+
<< "Only stride-1 Ramp is supported in region conversion";
43+
ICHECK(ramp->lanes.as<IntImmNode>())
44+
<< "Scalable vector lanes not supported in region conversion";
45+
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
46+
} else {
47+
ranges.push_back(Range::FromMinExtent(index, 1));
48+
}
49+
}
50+
return BufferRegion(load->buffer, ranges);
51+
}
52+
53+
// Case 3: Call nodes (only tl.region)
54+
if (const auto *call = arg.as<CallNode>()) {
55+
// tl.region(...) — reconstruct via RegionOp
56+
if (call->op.same_as(RegionOp::Get())) {
57+
RegionOp region(call->args, vmap);
58+
return BufferRegion(region->GetBuffer(), region->GetRanges());
59+
}
60+
}
61+
62+
LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg;
63+
throw; // Unreachable
64+
}
65+
2466
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
2567
ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
26-
node->src = vmap[GetVarFromAccessPtr(args[0])];
27-
node->dst = vmap[GetVarFromAccessPtr(args[1])];
68+
// Accept BufferRegion/BufferLoad/tl.region for src/dst
69+
node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
70+
node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
71+
node->src = node->srcRegion_->buffer;
72+
node->dst = node->dstRegion_->buffer;
2873
std::string reduce_type = args[2].as<StringImm>().value()->value;
2974
node->dim = args[3].as<IntImm>().value()->value;
3075
node->type = ReduceType(reduce_type);
@@ -369,6 +414,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
369414
InferLevel level) const {
370415
if (level >= InferLevel::kStrict)
371416
return {};
417+
372418
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
373419
T.layout_map.count(src)) {
374420
auto src_layout = T.layout_map[src].as<Fragment>().value();
@@ -422,6 +468,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
422468
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
423469
->CondenseReplicateVar()
424470
->BindThreadRange(T.thread_bounds);
471+
425472
if (!T.layout_map.count(dst))
426473
return {{dst, dst_layout}};
427474
else {

src/op/reduce.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,11 @@ class ReduceType : public ObjectRef {
8282
class ReduceOpNode : public TileOperatorNode {
8383
public:
8484
tir::Buffer src, dst; ///< Source and destination buffers
85-
int dim; ///< Dimension to reduce along
86-
ReduceType type; ///< Type of reduction operation
87-
bool clear; ///< Whether to clear destination before reduction
85+
// Optional: keep the original regions used to construct this op
86+
BufferRegion srcRegion_, dstRegion_;
87+
int dim; ///< Dimension to reduce along
88+
ReduceType type; ///< Type of reduction operation
89+
bool clear; ///< Whether to clear destination before reduction
8890

8991
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode,
9092
TileOperatorNode);
@@ -94,6 +96,8 @@ class ReduceOpNode : public TileOperatorNode {
9496
refl::ObjectDef<ReduceOpNode>()
9597
.def_ro("src", &ReduceOpNode::src)
9698
.def_ro("dst", &ReduceOpNode::dst)
99+
.def_ro("srcRegion", &ReduceOpNode::srcRegion_)
100+
.def_ro("dstRegion", &ReduceOpNode::dstRegion_)
97101
.def_ro("dim", &ReduceOpNode::dim)
98102
.def_ro("type", &ReduceOpNode::type)
99103
.def_ro("clear", &ReduceOpNode::clear);

0 commit comments

Comments (0)