Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/op/atomic_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,9 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
(par_op)->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
(par_op)->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
false, T.buffer_remap},
level);
}
auto loop_layout = par_op->GetLoopLayout();
Var thread_var = T.thread_var;
Expand Down
364 changes: 233 additions & 131 deletions src/op/copy.cc

Large diffs are not rendered by default.

60 changes: 47 additions & 13 deletions src/op/copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ using namespace tir;

/// Copy instruction types for different memory access patterns
enum class CopyInst : uint8_t {
kNormal = 0, ///< Standard memory copy (ldg/stg/cpasync)
kLDSM = 1, ///< Load matrix instruction
kSTSM = 2, ///< Store matrix instruction
kBulkLoad = 3, ///< Tensor Memory Access load
kBulkStore = 4, ///< Tensor Memory Access store
kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy
kLDSM = 1, // ldmatrix memory copy
kSTSM = 2, // stmatrix memory copy
kBulkLoad = 3, // utilize tma load
kBulkStore = 4, // utilize tma store
// we should separate the bulk load and store for 1d and multi-dim
// as they have different memory access patterns
kBulkLoad1D = 5, // utilize tma load 1d
kBulkStore1D = 6, // utilize tma store 1d
};

/// Descriptor for Tensor Memory Access (TMA) copy operations
Expand Down Expand Up @@ -137,17 +141,41 @@ class CopyNode : public TileOperatorNode {
* \param T Arguments for layout inference.
* \param level Level of inference (basic or detailed).
*/
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;

/*!
* \brief Check if bulk copy is supported.
*/
bool CheckBulkLoad(Target target) const;
bool CheckBulkLoad(Target target, arith::Analyzer *analyzer,
bool check_last_dim = true) const;

/*!
* \brief Check if bulk store is supported.
*/
bool CheckBulkStore(Target target) const;
bool CheckBulkStore(Target target, arith::Analyzer *analyzer,
bool check_last_dim = true) const;

/*!
* \brief Check if bulk copy 1d load is supported.
*/
bool CheckBulkLoad1D(Target target, const LayoutMap &layout_map,
arith::Analyzer *analyzer) const;

/*!
* \brief Check if bulk copy 1d store is supported.
*/
bool CheckBulkStore1D(Target target, const LayoutMap &layout_map,
arith::Analyzer *analyzer) const;

/*!
* \brief Check if bulk copy 1d is supported.
*/
bool CheckBulkCopy1D(const Buffer &global_tensor, const Buffer &shared_tensor,
const Array<Range> &global_range,
const Array<Range> &shared_range,
const LayoutMap &layout_map,
arith::Analyzer *analyzer) const;

/*!
* \brief Check if lds memory copy is supported.
Expand All @@ -162,18 +190,23 @@ class CopyNode : public TileOperatorNode {
/*!
* \brief Get the copy instruction type.
*/
CopyInst GetCopyInst(Target target, bool disable_tma_lower) const;
CopyInst GetCopyInst(Target target, bool disable_tma_lower,
const LayoutMap &layout_map, arith::Analyzer *analyzer,
bool buffer_oob) const;

/*!
* \brief Clone this copy operator.
*/
protected:
/*!
* \brief Generate lowering for bulk/global-to-shared copy.
*/
Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const;

/*!
* \brief Generate lowering for bulk copy 1d.
*/
Stmt LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const;

/*!
* \brief Generate lowering for LDS Memory Copy (shared memory to shared
* memory or smem usage).
Expand Down Expand Up @@ -316,7 +349,8 @@ class Conv2DIm2ColOpNode : public TileOperatorNode {
/*!
* \brief Infer layout for this operator.
*/
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;

/*!
* \brief Get TVM Op handle.
Expand Down
12 changes: 5 additions & 7 deletions src/op/fill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (dst.scope() == "local.fragment") {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
InferLevel::kFree);
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
false, T.buffer_remap},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
Expand All @@ -189,7 +188,8 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
} else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
dst.scope() == "global") {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
false, T.buffer_remap},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
Expand Down Expand Up @@ -225,9 +225,7 @@ TIR_REGISTER_TL_OP(Fill, fill)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TVM_FFI_STATIC_INIT_BLOCK({
FillNode::RegisterReflection();
});
TVM_FFI_STATIC_INIT_BLOCK({ FillNode::RegisterReflection(); });

} // namespace tl
} // namespace tvm
2 changes: 2 additions & 0 deletions src/op/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ struct LayoutInferArgs {
Target target;
Range thread_bounds;
LayoutMap layout_map;
arith::Analyzer *analyzer;
bool buffer_oob = false;
Map<Buffer, Buffer> buffer_remap;
};

Expand Down
54 changes: 50 additions & 4 deletions src/transform/layout_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
#include <queue>

#include "../layout/utils.h"
#include "../op/copy.h"
#include "../op/parallel.h"
#include "../op/region.h"

#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
Expand Down Expand Up @@ -64,6 +66,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
BufferUseDefCollector(bool skip_thread_partition)
: skip_thread_partition_(skip_thread_partition) {}

using arith::IRVisitorWithAnalyzer::IRVisitorWithAnalyzer;

void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
LayoutMap &layout_map, const LayoutMap &strict_layout_map,
std::queue<int> &q, std::vector<bool> &in_queue) {
Expand All @@ -80,6 +84,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
auto &next = infer_list_[cur_infer_id];
auto iter_var = thread_var_vec_[cur_infer_id];
auto thread_bounds = thread_bounds_vec_[cur_infer_id];
auto buffer_oob = buffer_oob_vec_[cur_infer_id];
// Double-check that 'next' is valid
ICHECK(next.defined()) << "infer_list_[" << cur_infer_id
<< "] is null inside run_infer_step.";
Expand All @@ -100,8 +105,10 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
"required for layout inference.";

// Run InferLayout
auto updates = next->InferLayout(
LayoutInferArgs{target_, thread_bounds, layout_map}, level);
auto updates =
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
&analyzer_, buffer_oob},
level);

// Process the returned updates
for (const auto &[buffer, layout] : updates) {
Expand Down Expand Up @@ -199,6 +206,9 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
<< "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
"length.";
ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size())
<< "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
"length.";

// If needed, you can also check that annotated_layout_map_ is not empty, or
// anything else relevant to your setup.
Expand Down Expand Up @@ -306,8 +316,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
addToUseList(buffer.value());
}
}
infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(p));
// Compute thread_var_ and thread_bounds_
thread_var_vec_.push_back(thread_var_);
if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
auto const_int_bound = analyzer_.const_int_bound(thread_var_);
Expand All @@ -320,6 +329,39 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
} else {
thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
}

// Compute buffer oob for each buffer in the op
if (const auto *copy = p.as<CopyNode>()) {
auto src_tensor = copy->src;
auto dst_tensor = copy->dst;
auto src_range = copy->src_range;
auto dst_range = copy->dst_range;
bool src_oob = false;
bool dst_oob = false;
for (size_t i = 0; i < src_range.size(); i++) {
if (!analyzer_.CanProve(src_range[i]->min + src_range[i]->extent <=
src_tensor->shape[i],
arith::ProofStrength::kSymbolicBound)) {
src_oob = true;
break;
}
}
for (size_t i = 0; i < dst_range.size(); i++) {
if (!analyzer_.CanProve(dst_range[i]->min + dst_range[i]->extent <=
dst_tensor->shape[i],
arith::ProofStrength::kSymbolicBound)) {
dst_oob = true;
break;
}
}
buffer_oob_vec_.push_back(src_oob || dst_oob);
} else {
buffer_oob_vec_.push_back(false);
}

// Add the tile operator to infer_list_
infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(p));
}
}

Expand Down Expand Up @@ -365,6 +407,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
} else {
thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
}
buffer_oob_vec_.push_back(false);
} else {
IRVisitorWithAnalyzer::VisitStmt(op->body);
}
Expand Down Expand Up @@ -411,6 +454,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
IterVarType::kDataPar);
std::vector<IterVar> thread_var_vec_;
std::vector<Range> thread_bounds_vec_;
std::vector<bool> buffer_oob_vec_;
Target target_;
LayoutMap annotated_layout_map_;
bool skip_thread_partition_{false};
Expand Down Expand Up @@ -556,6 +600,8 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
: arith::IRMutatorWithAnalyzer(analyzer), result_(result),
skip_thread_partition_(skip_thread_partition){};

using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

/**
* @brief Visit and mutate a Block node to attach inferred layout information.
*
Expand Down
Loading
Loading