Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace tvm {
namespace tl {

namespace attr {
static constexpr const char *kPaddingMap = "padding_map";
static constexpr const char *kSafeValueMap = "safe_value_map";
static constexpr const char *kWarpSpecializationScope =
"kWarpSpecializationScope";
static constexpr const char *kCustomWarpSpecialization =
Expand Down
126 changes: 73 additions & 53 deletions src/transform/legalize_safe_memory_access.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,37 @@ class LeafForFinder : public StmtVisitor {
bool parent_has_child_for_ = false;
};

// We will create a visitor to check BufferLoad and BufferStore nodes
// within this loop body. This visitor will:
// GlobalMemChecker for a BufferLoad/BufferStore node:
// 1. Identify BufferLoad and BufferStore nodes.
// 2. Check if the buffer is in global scope.
// 3. For each index, compare against the buffer's shape.
// If the index might exceed the shape (upper bound too large),
// log a warning or handle accordingly.
struct GlobalMemChecker : public StmtExprVisitor {

GlobalMemChecker(arith::Analyzer *analyzer) : analyzer_(analyzer) {}
GlobalMemChecker(arith::Analyzer *analyzer, bool recursively_collect_conds)
: analyzer_(analyzer),
recursively_collect_conds_(recursively_collect_conds) {}
void VisitExpr_(const BufferLoadNode *op) final {
// Check if the buffer is in global scope
// This is because we are writing TilePrograms, where out of bounds
// accesses only happen in the global buffer.
if (IsGlobalBuffer(op->buffer)) {
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true);
}
StmtExprVisitor::VisitExpr_(op);
if (recursively_collect_conds_) {
StmtExprVisitor::VisitExpr_(op);
}
}

void VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is in global scope
if (IsGlobalBuffer(op->buffer)) {
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false);
}
StmtExprVisitor::VisitStmt_(op);
if (recursively_collect_conds_) {
StmtExprVisitor::VisitStmt_(op);
}
}

// Helper function to determine if a buffer is global
Expand Down Expand Up @@ -109,6 +116,7 @@ struct GlobalMemChecker : public StmtExprVisitor {
}
});
if (!has_variable) {
// If index is a constant, we can skip the check
continue;
}

Expand All @@ -134,23 +142,48 @@ struct GlobalMemChecker : public StmtExprVisitor {
private:
Array<PrimExpr> _conditions;
arith::Analyzer *analyzer_;
bool recursively_collect_conds_;
};

class SafeMemorysRewriter : public StmtExprMutator {
arith::Analyzer *analyzer_;

public:
explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_padding_map,
explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_safe_value_map,
arith::Analyzer *analyzer)
: annotated_padding_map_(std::move(annotated_padding_map)),
: annotated_safe_value_map_(std::move(annotated_safe_value_map)),
analyzer_(analyzer) {}

private:
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));

// For Load/Store, we only check the current node, not its children.
// Since rewriter will recursively visit children.
GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
checker(load);
Array<PrimExpr> conditions = checker.GetConditions();

if (conditions.empty()) {
return load;
}

// For loading, we can always use safe value if the access is out of
// bounds
PrimExpr value = load;
for (auto cond : conditions) {
ICHECK(cond.dtype() == DataType::Bool(1))
<< "condition is not a boolean: " << cond;
value = if_then_else(cond, value, GetSafeValue(load->buffer));
}
return value;
}

Stmt VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is in global scope
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));

GlobalMemChecker checker(analyzer_);
GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
checker(store);
Array<PrimExpr> conditions = checker.GetConditions();

Expand All @@ -172,49 +205,36 @@ class SafeMemorysRewriter : public StmtExprMutator {
return store;
}

auto value = store->value;
if (IsGlobalBuffer(store->buffer)) {
Stmt store_with_conditions = store;
for (auto cond : conditions) {
store_with_conditions = IfThenElse(cond, store_with_conditions);
}
return store_with_conditions;
} else if (isSharedBuffer(store->buffer)) {
PrimExpr value = store->value;
for (auto cond : conditions) {
ICHECK(cond.dtype() == DataType::Bool(1))
<< "condition is not a boolean: " << cond;
value = if_then_else(cond, value, GetPadding(store->buffer));
}
store.CopyOnWrite()->value = value;
return store;
} else if (IsLocalBuffer(store->buffer)) {
PrimExpr value = store->value;
for (auto cond : conditions) {
ICHECK(cond.dtype() == DataType::Bool(1))
<< "condition is not a boolean: " << cond;
value = if_then_else(cond, value, GetPadding(store->buffer));
}
store.CopyOnWrite()->value = value;
return store;
} else {
LOG(FATAL) << "Check store buffer: " << store->buffer
<< " is not a global or shared or local buffer";
// If a store is out of bounds, we skip the corresponding stmt directly.
Stmt store_with_conditions = store;
for (auto cond : conditions) {
store_with_conditions = IfThenElse(cond, store_with_conditions);
}

return store;
return store_with_conditions;
}

// Handle Call Nodes
// Recursively check Load/Store in the call arguments.
// For example
// T.call_extern("handle", "atomicAddx2", T.address_of(C),
// T.address_of(C_shared))

// NOTE(chaofan): This is currently not the most rigorous solution.
// The check here is primarily intended to handle extern functions like
// atomicAdd, which may involve memory access. Due to their special nature,
// the BufferLoad in their parameters might be used for boundary checks of the
// current statement. The current solution adopts a simplified approach:
// directly applying the boundary constraints of all parameters to the
// statement. While not entirely precise, it addresses most common scenarios.
Stmt VisitStmt_(const EvaluateNode *op) final {
auto evaluate = Downcast<Evaluate>(StmtExprMutator::VisitStmt_(op));
auto evaluate = Downcast<Evaluate>(op);

if (const CallNode *call_op = op->value.as<CallNode>()) {
auto call = Downcast<Call>(evaluate->value);
auto call = Downcast<Call>(op->value);
if (call->op == builtin::call_extern()) {
GlobalMemChecker checker(analyzer_);
// For CallExtern, we recursively collect conditions from all children.
// Since we cannot rewrite any BufferLoad in its children (Rewrite will
// cause potential Nullptr exception).
GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/true);
checker(call);
Array<PrimExpr> conditions = checker.GetConditions();

Expand Down Expand Up @@ -248,15 +268,15 @@ class SafeMemorysRewriter : public StmtExprMutator {
String scope = buffer.scope();
return scope == "global";
}
// Get the padding of the buffer
PrimExpr GetPadding(const Buffer &buffer) {
if (annotated_padding_map_.count(buffer)) {
return annotated_padding_map_[buffer];
// Get the safe value of the buffer
PrimExpr GetSafeValue(const Buffer &buffer) {
if (annotated_safe_value_map_.count(buffer)) {
return annotated_safe_value_map_[buffer];
}
return make_zero(buffer->dtype);
}

Map<Buffer, PrimExpr> annotated_padding_map_;
Map<Buffer, PrimExpr> annotated_safe_value_map_;
};

// Class to legalize safe memory access by transforming them appropriately
Expand Down Expand Up @@ -288,7 +308,7 @@ class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto has_inner_loop = HasInnerLoop(for_node->body);
if (!has_inner_loop) {
SafeMemorysRewriter rewriter(annotated_padding_map_, analyzer_);
SafeMemorysRewriter rewriter(annotated_safe_value_map_, analyzer_);
for_node.CopyOnWrite()->body = rewriter(for_node->body);
// // Detect Buffer Load Node in the loop body, collect the indices and
// buffer size
Expand Down Expand Up @@ -316,16 +336,16 @@ class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
if (op->annotations.count(attr::kPaddingMap)) {
auto map = op->annotations.Get(attr::kPaddingMap)
if (op->annotations.count(attr::kSafeValueMap)) {
auto map = op->annotations.Get(attr::kSafeValueMap)
->as<Map<Var, PrimExpr>>()
.value();
for (const auto &[var, padding] : map) {
for (const auto &[var, safe_value] : map) {
ICHECK(buffer_data_to_buffer_.count(var))
<< "buffer " << var << " is not found in the block "
<< buffer_data_to_buffer_;
auto buffer = buffer_data_to_buffer_[var];
annotated_padding_map_.Set(buffer, padding);
annotated_safe_value_map_.Set(buffer, safe_value);
}
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
Expand All @@ -338,7 +358,7 @@ class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
}

Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, PrimExpr> annotated_padding_map_;
Map<Buffer, PrimExpr> annotated_safe_value_map_;
};

// Create a pass that legalizes vectorized loops in the IRModule
Expand Down
26 changes: 13 additions & 13 deletions src/transform/lower_tile_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class RemapBufferRewriter : public arith::IRMutatorWithAnalyzer {
using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

Stmt VisitStmt_(const BlockNode *op) final {
if (op->annotations.count(attr::kPaddingMap)) {
if (op->annotations.count(attr::kSafeValueMap)) {
return RewritePaddingMap(op);
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
Expand All @@ -191,18 +191,18 @@ class RemapBufferRewriter : public arith::IRMutatorWithAnalyzer {
* \return The rewritten block.
*/
Stmt RewritePaddingMap(const BlockNode *op) {
auto padding_map = op->annotations.Get(attr::kPaddingMap);
if (!padding_map) {
auto safe_value_map = op->annotations.Get(attr::kSafeValueMap);
if (!safe_value_map) {
LOG(FATAL) << "Padding map annotation is missing";
}

Map<Var, Var> var_remap = CreateVarRemap();
Map<Var, PrimExpr> new_padding_map = RemapPaddingMap(
Downcast<Map<Var, PrimExpr>>(padding_map.value()), var_remap);
Map<Var, PrimExpr> new_safe_value_map = RemapPaddingMap(
Downcast<Map<Var, PrimExpr>>(safe_value_map.value()), var_remap);

auto block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto block_ptr = block.CopyOnWrite();
block_ptr->annotations.Set(attr::kPaddingMap, new_padding_map);
block_ptr->annotations.Set(attr::kSafeValueMap, new_safe_value_map);
return block;
}

Expand All @@ -220,21 +220,21 @@ class RemapBufferRewriter : public arith::IRMutatorWithAnalyzer {

/*!
* \brief Remap the padding map using the variable remapping.
* \param padding_map The original padding map.
* \param safe_value_map The original padding map.
* \param var_remap The variable remapping.
* \return The remapped padding map.
*/
Map<Var, PrimExpr> RemapPaddingMap(const Map<Var, PrimExpr> &padding_map,
Map<Var, PrimExpr> RemapPaddingMap(const Map<Var, PrimExpr> &safe_value_map,
const Map<Var, Var> &var_remap) const {
Map<Var, PrimExpr> new_padding_map;
for (const auto &[var, padding] : padding_map) {
Map<Var, PrimExpr> new_safe_value_map;
for (const auto &[var, padding] : safe_value_map) {
if (var_remap.count(var)) {
new_padding_map.Set(var_remap.at(var), padding);
new_safe_value_map.Set(var_remap.at(var), padding);
} else {
new_padding_map.Set(var, padding);
new_safe_value_map.Set(var, padding);
}
}
return new_padding_map;
return new_safe_value_map;
}

Map<Buffer, Buffer> buffer_remap_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def main(
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)

T.annotate_padding({A_shared: pad_value})
T.annotate_safe_value({A: pad_value})
for i, j in T.Parallel(block_M, block_N):
A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]

Expand Down
Loading