2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 1815c3 to 093b2c
166 changes: 163 additions & 3 deletions src/layout/layout.cc
@@ -102,10 +102,24 @@ Array<PrimExpr> LayoutNode::OutputShape() const {
for (size_t i = 0; i < ret.size(); i++) {
auto ist = analyzer.int_set(forward_index_[i] + 1);
if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) {
// X-OR Expression
ret.Set(i, input_size_[i]);
// Analyzer couldn't form an IntervalSet (e.g. bitwise ops).
// Fall back to ConstIntBound to derive a safe extent.
auto cib = analyzer.const_int_bound(forward_index_[i]);
if (cib->min_value != arith::ConstIntBound::kNegInf &&
cib->max_value != arith::ConstIntBound::kPosInf &&
cib->min_value >= 0) {
// extent = max - min + 1, using 64-bit integer literal
ret.Set(i, Integer(cib->max_value - cib->min_value + 1));
} else {
// Last-resort conservative fallback to avoid OOB/crash
// Prefer to keep dimension from known input_size_ if available.
if (i < input_size_.size()) {
ret.Set(i, input_size_[i]);
} else {
ret.Set(i, Integer(1));
}
}
} else {
Comment on lines 104 to 122
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Improved fallback logic prevents crashes but may be overly conservative.

The updated OutputShape() fallback (lines 104-122) improves robustness when the analyzer cannot form an IntervalSet:

Improvements:

  1. Uses const_int_bound to derive safe extents when possible (lines 107-112)
  2. Falls back to input_size_[i] or 1 as last resort (lines 115-120)
  3. Preserves the CHECK(is_one(ist.min())) assertion for the normal path (line 123)

Concern: The fallback to input_size_[i] or 1 (lines 116-119) is conservative but may produce incorrect shapes for complex layouts. Consider whether a runtime error or more sophisticated analysis would be preferable to silently using potentially incorrect dimensions.

Verify that the fallback behavior is acceptable for all layout types, especially when forward_index_[i] involves bitwise operations or other constructs the analyzer cannot handle.
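
For concreteness, a minimal sketch of a layout that would take this fallback path — it assumes InputPlaceholder(i) yields the i-th input iterator Var and uses the Layout(shape, forward_index) constructor, both as they appear elsewhere in this diff; the swizzle itself is illustrative:

Array<PrimExpr> shape = {Integer(8), Integer(8)};
PrimExpr r = InputPlaceholder(0);
PrimExpr c = InputPlaceholder(1);
// An XOR swizzle is the kind of forward index the IntervalSet analysis
// cannot bound, so OutputShape() would take the ConstIntBound fallback
// for the second output dimension.
Layout swizzled(shape, {r, c ^ r});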


🏁 Script executed:

#!/bin/bash
# Search for layouts that might trigger the fallback path
rg -n -C3 'OutputShape|forward_index_' --type=cpp | grep -E '(bitwise|&|\||\^)'

Length of output: 2812


🏁 Script executed:

#!/bin/bash
# Search for forward_index_ definitions and assignments to see what expressions are stored
rg -n 'forward_index_\s*=' --type=cpp -A 2 | head -100

Length of output: 539


🏁 Script executed:

#!/bin/bash
# Look for actual bitwise operations in layout files
rg -nP '\s+\d+\s*(&|\||\^|<<|>>)\s*\d+' --type=cpp src/layout/

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Find test files related to layout or OutputShape
fd -e cc -e h 'test|layout' --type f | grep -i test | head -20

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Search for calls to OutputShape to see how result is used
rg -n 'OutputShape\(\)' --type=cpp -A 2 -B 1 | head -80

Length of output: 4699


🏁 Script executed:

#!/bin/bash
# Find input_size_ definition and its type
rg -n 'input_size_' --type=cpp -B 2 -A 2 | grep -E '(Array|vector|input_size_|PrimExpr)' | head -40

Length of output: 2304


🏁 Script executed:

#!/bin/bash
# Check context around the ICHECK that uses OutputShape
rg -n 'as_const_int.*OutputShape' --type=cpp -B 5 -A 5

Length of output: 632


🏁 Script executed:

#!/bin/bash
# Search for input_size_ member variable declaration in header files
rg -n 'input_size_' --type=h -B 1 -A 1

Length of output: 402


🏁 Script executed:

#!/bin/bash
# Look for other places where OutputShape result is used without as_const_int conversion
rg -n 'OutputShape\(\)\[' --type=cpp -B 2 -A 2

Length of output: 1455


🏁 Script executed:

#!/bin/bash
# Search for Layout/Fragment construction to see what input_size values are passed
rg -n 'Layout\(|Fragment\(' --type=cpp -A 2 | head -60

Length of output: 3885


🏁 Script executed:

#!/bin/bash
# Check if input_size_ elements are checked/validated to be const before being returned
rg -n 'input_size_\[' --type=cpp -B 3 | head -50

Length of output: 848


🏁 Script executed:

#!/bin/bash
# Look at the full context of layout_inference.cc lines 836-842 to understand the crash scenario
rg -n 'frag_reg_num' --type=cpp -B 8 -A 5

Length of output: 1497


Fix the OutputShape() fallback to ensure it always returns const integers, or validate input_size_ at construction time.

The fallback logic (lines 104-122) has a critical bug: it returns input_size_[i] (line 117), which may be a non-const PrimExpr since input_size_ is populated from sources like buffer->shape (see layout_reducer.cc:216-217). This violates the implicit contract that OutputShape() returns const dimensions.

Specifically, layout_inference.cc:839 will crash with an ICHECK failure:

for (auto i : frag.value()->OutputShape()) {
  auto pci = as_const_int(i);
  ICHECK(pci != nullptr);  // crashes if OutputShape() returns symbolic expr
  frag_reg_num *= *pci;
}

Fix: Either (1) validate that input_size_ elements are const at construction time, (2) ensure the fallback path only returns proven const dimensions (e.g., always Integer(1) instead of input_size_[i]), or (3) add a precondition guard preventing this fallback when input_size_ contains symbolic expressions.
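
A minimal sketch of what (2)/(3) could look like inside the fallback branch — names follow the diff above, the as_const_int guard is the substantive change, and the fatal message is illustrative:

if (cib->min_value != arith::ConstIntBound::kNegInf &&
    cib->max_value != arith::ConstIntBound::kPosInf && cib->min_value >= 0) {
  ret.Set(i, Integer(cib->max_value - cib->min_value + 1));
} else if (i < input_size_.size() && as_const_int(input_size_[i]) != nullptr) {
  // Reuse input_size_[i] only when it is already constant, so the implicit
  // "OutputShape() returns const dimensions" contract still holds.
  ret.Set(i, input_size_[i]);
} else {
  // Fail loudly instead of silently returning a possibly wrong extent.
  LOG(FATAL) << "Cannot derive a constant extent for output dimension " << i
             << "; forward index = " << forward_index_[i];
}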

// CHECK(is_one(ist.min())) << ist.min();
ret.Set(i, ist.max());
}
}
@@ -282,10 +296,156 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
return {Layout(outputs_shape, backward_index), level};
}

Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
// Fast path: if shape is the same, return the original layout
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Layout>(this);
}

// Step 1. Prove the product of InputShape is equal to the product of shape
PrimExpr input_shape_product = Integer(1);
for (const auto &dim : InputShape()) {
input_shape_product *= dim;
}
PrimExpr shape_product = Integer(1);
for (const auto &dim : shape) {
shape_product *= dim;
}

if (analyzer) {
ICHECK(analyzer->CanProveEqual(input_shape_product, shape_product))
<< "InputShape() = " << InputShape() << " shape = " << shape;
} else {
arith::Analyzer local_analyzer;
ICHECK(local_analyzer.CanProveEqual(input_shape_product, shape_product))
<< "InputShape() = " << InputShape() << " shape = " << shape;
}

// Step 2. Create new forward indices by reshaping
// For each dimension in the new shape, we create a placeholder variable
Array<Var> new_vars;
for (size_t i = 0; i < shape.size(); ++i) {
new_vars.push_back(InputPlaceholder(i));
}
// Step 3. Compute the flat index from new shape indices
// flat_index = k0 * (s1 * s2 * ...) + k1 * (s2 * s3 * ...) + ... + kn
PrimExpr flat_index = Integer(0);
for (size_t i = 0; i < shape.size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < shape.size(); ++j) {
stride = stride * shape[j];
}
flat_index = flat_index + new_vars[i] * stride;
}
// Step 4. Convert flat index back to original shape indices
// For original shape [s0, s1, ..., sm]:
// i0 = flat_index // (s1 * s2 * ... * sm)
// i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm)
// ...
Array<PrimExpr> original_indices;
PrimExpr remaining = flat_index;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j) {
stride = stride * InputShape()[j];
}
original_indices.push_back(floordiv(remaining, stride));
remaining = floormod(remaining, stride);
}
// Step 5. Substitute original indices into forward_index_
Array<PrimExpr> new_forward_index;
for (const auto &fwd_expr : forward_index_) {
PrimExpr substituted = fwd_expr;
// Replace each InputPlaceholder(i) with original_indices[i]
for (size_t i = 0; i < InputShape().size(); ++i) {
substituted =
Substitute(substituted, {{InputPlaceholder(i), original_indices[i]}});
}
new_forward_index.push_back(substituted);
}
return Layout(shape, new_forward_index);
}

Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
// Fast path: identical input shape, return self
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Fragment>(this);
}

// 1) Prove total number of elements remains the same
PrimExpr input_prod = Integer(1);
for (const auto &d : InputShape())
input_prod *= d;
PrimExpr shape_prod = Integer(1);
for (const auto &d : shape)
shape_prod *= d;

if (analyzer) {
ICHECK(analyzer->CanProveEqual(input_prod, shape_prod))
<< "InputShape() = " << InputShape() << " shape = " << shape
<< " input fragment layout is = " << DebugOutput();
} else {
arith::Analyzer local_analyzer;
ICHECK(local_analyzer.CanProveEqual(input_prod, shape_prod))
<< "InputShape() = " << InputShape() << " shape = " << shape;
}

// 2) Build flat index from new-shape indices
Array<Var> new_vars;
new_vars.reserve(shape.size());
for (size_t i = 0; i < shape.size(); ++i)
new_vars.push_back(InputPlaceholder(i));

PrimExpr flat = Integer(0);
for (size_t i = 0; i < shape.size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < shape.size(); ++j)
stride = stride * shape[j];
flat = flat + new_vars[i] * stride;
}

// 3) Recover original indices from flat index
Array<PrimExpr> orig_indices;
PrimExpr remain = flat;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j)
stride = stride * InputShape()[j];
orig_indices.push_back(floordiv(remain, stride));
remain = floormod(remain, stride);
}

// 4) Substitute old placeholders with expressions of new indices
Array<PrimExpr> new_forward_index;
for (const auto &e : forward_index_) {
PrimExpr cur = e;
for (size_t i = 0; i < InputShape().size(); ++i) {
cur = Substitute(cur, {{InputPlaceholder(i), orig_indices[i]}});
}
new_forward_index.push_back(cur);
}

PrimExpr new_forward_thread = forward_thread_;
for (size_t i = 0; i < InputShape().size(); ++i) {
new_forward_thread = Substitute(new_forward_thread,
{{InputPlaceholder(i), orig_indices[i]}});
}

Fragment reshaped(shape, new_forward_index, new_forward_thread,
ReplicateExtent(), std::nullopt);
if (thread_range_.defined()) {
reshaped = reshaped->BindThreadRange(thread_range_);
}
return reshaped;
}

Layout LayoutNode::Inverse() const {
auto inverse_result = InverseWithLevel();
return std::move(inverse_result.first);
}

PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
const PrimExpr &forward_thread,
arith::Analyzer *analyzer) {
7 changes: 7 additions & 0 deletions src/layout/layout.h
@@ -41,6 +41,10 @@ class LayoutNode : public Object {
virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;

virtual Layout Inverse() const;

virtual Layout Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const;

virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;

virtual std::string DebugOutput() const;
@@ -81,6 +85,9 @@ class FragmentNode : public LayoutNode {
Array<PrimExpr> GetForwardVars() const final;

Layout Inverse() const final;

Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer) const;

std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;

PrimExpr ThreadExtent() const;
51 changes: 49 additions & 2 deletions src/op/reduce.cc
@@ -14,17 +14,62 @@
#include "../op/parallel.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "region.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

// Normalize an argument (BufferRegion/BufferLoad/tl.region)
// to BufferRegion so Reduce can uniformly consume regions.
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}

// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}

// Case 3: Call nodes (only tl.region)
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
}

LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg;
throw; // Unreachable
}

ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
// Accept BufferRegion/BufferLoad/tl.region for src/dst
node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->src = node->srcRegion_->buffer;
node->dst = node->dstRegion_->buffer;
std::string reduce_type = args[2].as<StringImm>().value()->value;
node->dim = args[3].as<IntImm>().value()->value;
node->type = ReduceType(reduce_type);
@@ -369,6 +414,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (level >= InferLevel::kStrict)
return {};

if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
T.layout_map.count(src)) {
auto src_layout = T.layout_map[src].as<Fragment>().value();
@@ -422,6 +468,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
->CondenseReplicateVar()
->BindThreadRange(T.thread_bounds);

if (!T.layout_map.count(dst))
return {{dst, dst_layout}};
else {
10 changes: 7 additions & 3 deletions src/op/reduce.h
@@ -82,9 +82,11 @@ class ReduceType : public ObjectRef {
class ReduceOpNode : public TileOperatorNode {
public:
tir::Buffer src, dst; ///< Source and destination buffers
int dim; ///< Dimension to reduce along
ReduceType type; ///< Type of reduction operation
bool clear; ///< Whether to clear destination before reduction
// Optional: keep the original regions used to construct this op
BufferRegion srcRegion_, dstRegion_;
int dim; ///< Dimension to reduce along
ReduceType type; ///< Type of reduction operation
bool clear; ///< Whether to clear destination before reduction

TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode,
TileOperatorNode);
@@ -94,6 +96,8 @@ class ReduceOpNode : public TileOperatorNode {
refl::ObjectDef<ReduceOpNode>()
.def_ro("src", &ReduceOpNode::src)
.def_ro("dst", &ReduceOpNode::dst)
.def_ro("srcRegion", &ReduceOpNode::srcRegion_)
.def_ro("dstRegion", &ReduceOpNode::dstRegion_)
.def_ro("dim", &ReduceOpNode::dim)
.def_ro("type", &ReduceOpNode::type)
.def_ro("clear", &ReduceOpNode::clear);