Skip to content

Commit bddb125

Browse files
authored
[Language] Support tilelang alloc_var(dtype, init=x) (#1092)
* - carry existing local-var initializer map into OpaqueBlockLower, reattach it to generated Allocates and the PrimFunc attrs - thread the map through FlattenBuffer and StorageRewrite so flattened/merged allocations keep their tl.local_var_init annotations - teach annotation handling to accept scalar initializers, resolve buffers, and merge with existing stat * lint fix * enhance * lint fix * lint fix
1 parent cdc67fc commit bddb125

File tree

7 files changed

+260
-32
lines changed

7 files changed

+260
-32
lines changed

src/op/builtin.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ static constexpr const char *kWarpSpecializationScope =
2727
"kWarpSpecializationScope";
2828
static constexpr const char *kCustomWarpSpecialization =
2929
"kCustomWarpSpecialization";
30+
static constexpr const char *kLocalVarInit = "tl.local_var_init";
3031
} // namespace attr
3132

3233
static constexpr const char *kDebugMergeSharedMemoryAllocations =

src/target/codegen_cuda.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,8 +2201,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
22012201
} else if (scope == "local") {
22022202
stream << ' ' << vid << '[' << constant_size << "];\n";
22032203
} else if (scope == "local.var") {
2204-
stream << ' ' << vid << " = " << PrintExpr(tir::make_const(op->dtype, 0))
2205-
<< ";\n";
2204+
PrimExpr init = tir::make_const(op->dtype, 0);
2205+
auto init_it = op->annotations.find(tl::attr::kLocalVarInit);
2206+
if (init_it != op->annotations.end()) {
2207+
PrimExpr user_init = Downcast<PrimExpr>((*init_it).second);
2208+
if (!user_init.dtype().is_void() && user_init.dtype() != op->dtype) {
2209+
user_init = tir::Cast(op->dtype, user_init);
2210+
}
2211+
init = user_init;
2212+
}
2213+
stream << ' ' << vid << " = " << PrintExpr(init) << ";\n";
22062214
} else if (scope != "local.descriptor") {
22072215
ICHECK(false) << "Unsupported scope: " << scope;
22082216
}

src/transform/flatten_buffer.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,16 @@
2525
#include "tir/transforms/ir_utils.h"
2626
#include <tvm/arith/iter_affine_map.h>
2727
#include <tvm/ffi/reflection/registry.h>
28+
#include <tvm/ir/attrs.h>
2829
#include <tvm/tir/analysis.h>
2930
#include <tvm/tir/data_type_rewriter.h>
3031
#include <tvm/tir/stmt_functor.h>
3132
#include <tvm/tir/transform.h>
3233

3334
#include <utility>
3435

36+
#include "../op/builtin.h"
37+
3538
namespace tvm {
3639
namespace tl {
3740

@@ -46,6 +49,10 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
4649
static PrimFunc Flatten(PrimFunc func) {
4750
arith::Analyzer ana;
4851
auto pass = BufferFlattener(&ana);
52+
if (auto init_map =
53+
func->attrs.GetAttr<Map<Var, PrimExpr>>(tl::attr::kLocalVarInit)) {
54+
pass.local_var_init_map_ = init_map.value();
55+
}
4956
auto writer = func.CopyOnWrite();
5057
pass.MarkBufferMapShapes(func);
5158
writer->body = pass.VisitStmt(func->body);
@@ -198,6 +205,13 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
198205
if (!new_extents.same_as(alloc->extents)) {
199206
alloc.CopyOnWrite()->extents = new_extents;
200207
}
208+
if (!local_var_init_map_.empty()) {
209+
auto init_it = local_var_init_map_.find(alloc->buffer_var);
210+
if (init_it != local_var_init_map_.end()) {
211+
const PrimExpr &init = (*init_it).second;
212+
alloc.CopyOnWrite()->annotations.Set(tl::attr::kLocalVarInit, init);
213+
}
214+
}
201215

202216
return std::move(alloc);
203217
}
@@ -354,6 +368,9 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
354368

355369
/*! \brief The updated external buffer map. */
356370
Map<Var, Buffer> updated_extern_buffer_map_;
371+
372+
/*! \brief Local var initializers preserved from block annotations. */
373+
Map<Var, PrimExpr> local_var_init_map_;
357374
};
358375

359376
PrimFunc FlattenBufferRewriter(PrimFunc f) {

src/transform/lower_opaque_block.cc

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,14 @@
2222
*/
2323

2424
#include <tvm/ffi/reflection/registry.h>
25+
#include <tvm/ir/attrs.h>
2526
#include <tvm/tir/stmt_functor.h>
2627
#include <tvm/tir/transform.h>
2728

29+
#include <string>
2830
#include <utility>
2931

32+
#include "../op/builtin.h"
3033
#include "tir/transforms/ir_utils.h"
3134

3235
namespace tvm {
@@ -39,10 +42,20 @@ using namespace tir::attr;
3942
*/
4043
class OpaqueBlockLower : public StmtExprMutator {
4144
public:
42-
static Stmt Rewrite(Stmt body) {
45+
static PrimFunc Rewrite(PrimFunc f) {
46+
auto fptr = f.CopyOnWrite();
4347
OpaqueBlockLower lower;
44-
lower.storage_align_ = CollectStorageAlignAnnotation(body);
45-
return lower(std::move(body));
48+
if (auto existing =
49+
fptr->attrs.GetAttr<Map<Var, PrimExpr>>(tl::attr::kLocalVarInit)) {
50+
lower.local_var_init_map_ = existing.value();
51+
}
52+
lower.storage_align_ = CollectStorageAlignAnnotation(fptr->body);
53+
fptr->body = lower(std::move(fptr->body));
54+
if (!lower.local_var_init_map_.empty()) {
55+
f = WithAttr(std::move(f), tl::attr::kLocalVarInit,
56+
lower.local_var_init_map_);
57+
}
58+
return f;
4659
}
4760

4861
private:
@@ -59,7 +72,13 @@ class OpaqueBlockLower : public StmtExprMutator {
5972
if (!is_one(predicate)) {
6073
body = IfThenElse(predicate, std::move(body));
6174
}
62-
// Step 3. Handle allocations in reverse order
75+
// Step 3. Handle annotations, block annotations are not preserved by
76+
// default.
77+
std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
78+
HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true,
79+
new_block->alloc_buffers);
80+
81+
// Step 4. Handle allocations in reverse order
6382
for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
6483
const Buffer &buffer = new_block->alloc_buffers[i - 1];
6584
Array<PrimExpr> allocation_shape = GetBufferAllocationShape(buffer);
@@ -74,14 +93,15 @@ class OpaqueBlockLower : public StmtExprMutator {
7493
}
7594
allocate_annotations.Set(tir::attr::buffer_dim_align, allocate_aligns);
7695
}
77-
96+
auto init_it = local_var_init_map_.find(buffer->data);
97+
if (init_it != local_var_init_map_.end()) {
98+
const PrimExpr &init = (*init_it).second;
99+
allocate_annotations.Set(tl::attr::kLocalVarInit, init);
100+
}
78101
body = Allocate(buffer->data, buffer->dtype, allocation_shape,
79102
const_true(), std::move(body), allocate_annotations);
80103
}
81-
// Step 4. Handle annotations, block annotations are not preserved by
82-
// default.
83-
std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
84-
HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true);
104+
// Step 5. Insert attribute statements converted from pragmas
85105
for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
86106
body = AttrStmt(Integer(0), it->first, it->second, std::move(body));
87107
}
@@ -188,13 +208,34 @@ class OpaqueBlockLower : public StmtExprMutator {
188208
Map<String, ffi::Any>
189209
HandleAnnotations(const Map<String, ffi::Any> &annotations,
190210
std::vector<std::pair<std::string, PrimExpr>> *pragma_attrs,
191-
bool is_block) {
211+
bool is_block,
212+
const Array<Buffer> &alloc_buffers = Array<Buffer>()) {
192213
Map<String, ffi::Any> preserved_annotations;
193214
pragma_attrs->clear();
194215
for (const auto &kv : annotations) {
195216
const String &key = kv.first;
196217
if (tir::attr::IsPragmaKey(key)) {
197218
pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second));
219+
} else if (key == tl::attr::kLocalVarInit) {
220+
if (auto local_init_map = kv.second.try_cast<Map<Var, PrimExpr>>()) {
221+
for (const auto &pair : local_init_map.value()) {
222+
local_var_init_map_.Set(pair.first, pair.second);
223+
}
224+
} else if (auto init_expr = kv.second.try_cast<PrimExpr>()) {
225+
ICHECK(is_block) << "`" << tl::attr::kLocalVarInit
226+
<< "` on non-block annotations is not supported";
227+
Buffer target = ResolveLocalVarBuffer(alloc_buffers);
228+
if (!target.defined()) {
229+
LOG(WARNING) << "Failed to resolve buffer for `"
230+
<< tl::attr::kLocalVarInit << "` annotation";
231+
continue;
232+
}
233+
local_var_init_map_.Set(target->data, init_expr.value());
234+
} else {
235+
LOG(FATAL) << "Expected `" << tl::attr::kLocalVarInit
236+
<< "` to be a PrimExpr or Map<Var, PrimExpr>, but got "
237+
<< kv.second.GetTypeKey();
238+
}
198239
} else if (!is_block) {
199240
// the loop annotation is preserved
200241
preserved_annotations.Set(key, kv.second);
@@ -206,6 +247,19 @@ class OpaqueBlockLower : public StmtExprMutator {
206247
return preserved_annotations;
207248
}
208249

250+
Buffer ResolveLocalVarBuffer(const Array<Buffer> &alloc_buffers) const {
251+
for (const Buffer &buffer : alloc_buffers) {
252+
std::string scope = buffer.scope();
253+
if (scope.find("local.var") != std::string::npos) {
254+
return buffer;
255+
}
256+
}
257+
if (!alloc_buffers.empty()) {
258+
return alloc_buffers.back();
259+
}
260+
return Buffer();
261+
}
262+
209263
/*! \brief Record the loop_var and loop start value of unit loops, whose
210264
* extent is one. */
211265
std::unordered_map<Var, PrimExpr> unit_loop_vars_;
@@ -215,12 +269,13 @@ class OpaqueBlockLower : public StmtExprMutator {
215269

216270
/*! \brief The map from buffer var to its storage alignment information. */
217271
std::unordered_map<Var, StorageAlignAnnotation> storage_align_;
272+
273+
/*! \brief Local var initializers collected from block annotations. */
274+
Map<Var, PrimExpr> local_var_init_map_;
218275
};
219276

220277
PrimFunc TLLowerOpaqueBlock(PrimFunc f) {
221-
auto fptr = f.CopyOnWrite();
222-
fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body));
223-
return f;
278+
return OpaqueBlockLower::Rewrite(std::move(f));
224279
}
225280

226281
tir::transform::Pass LowerOpaqueBlock() {

src/transform/storage_rewrite.cc

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <tvm/arith/analyzer.h>
2626
#include <tvm/ffi/function.h>
2727
#include <tvm/ffi/reflection/registry.h>
28+
#include <tvm/ir/attrs.h>
2829
#include <tvm/ir/type.h>
2930
#include <tvm/target/target_info.h>
3031
#include <tvm/tir/analysis.h>
@@ -468,8 +469,10 @@ class StoragePlanRewriter : public StmtExprMutator {
468469
using AllocEntry = LinearAccessPatternFinder::AllocEntry;
469470

470471
Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse,
471-
bool reuse_require_exact_matched_dtype) {
472+
bool reuse_require_exact_matched_dtype,
473+
Map<Var, PrimExpr> local_var_init_map = {}) {
472474
detect_inplace_ = detect_inplace;
475+
local_var_init_map_ = std::move(local_var_init_map);
473476
// plan the rewrite
474477
LinearAccessPatternFinder finder;
475478
finder(stmt);
@@ -694,6 +697,17 @@ class StoragePlanRewriter : public StmtExprMutator {
694697
}
695698
return body;
696699
}
700+
Map<String, ffi::Any> MakeAllocateAnnotations(const Var &buffer_var) const {
701+
Map<String, ffi::Any> annotations;
702+
if (local_var_init_map_.defined()) {
703+
auto it = local_var_init_map_.find(buffer_var);
704+
if (it != local_var_init_map_.end()) {
705+
const PrimExpr &init = (*it).second;
706+
annotations.Set(tl::attr::kLocalVarInit, init);
707+
}
708+
}
709+
return annotations;
710+
}
697711
// Remap the index
698712
PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry *e) {
699713
if (e->bits_offset == 0)
@@ -766,9 +780,11 @@ class StoragePlanRewriter : public StmtExprMutator {
766780

767781
if (all_allocs_identical) {
768782
// simply use the original allocation.
769-
e->alloc_nest.push_back(
770-
Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
771-
e->allocs[0]->condition, Evaluate(0)));
783+
Map<String, ffi::Any> annotations =
784+
MakeAllocateAnnotations(e->alloc_var);
785+
e->alloc_nest.push_back(Allocate(
786+
e->alloc_var, alloc_type, e->allocs[0]->extents,
787+
e->allocs[0]->condition, Evaluate(0), std::move(annotations)));
772788
if (auto ptr = e->allocs[0]->body.as<DeclBufferNode>()) {
773789
e->alloc_nest.push_back(DeclBuffer(
774790
RemapBuffer(ptr->buffer, e->alloc_var), Evaluate(0)));
@@ -824,9 +840,11 @@ class StoragePlanRewriter : public StmtExprMutator {
824840
combo_size = combo_size + make_const(DataType::Int(32), 1);
825841
}
826842
combo_size = analyzer_.Simplify(combo_size);
827-
e->alloc_nest.push_back(Allocate(e->alloc_var, alloc_type,
828-
{combo_size}, const_true(),
829-
Evaluate(0)));
843+
Map<String, ffi::Any> annotations =
844+
MakeAllocateAnnotations(e->alloc_var);
845+
e->alloc_nest.push_back(
846+
Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(),
847+
Evaluate(0), std::move(annotations)));
830848
if (IsSpecialTaggedMemory(e->scope)) {
831849
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
832850
if (info.defined()) {
@@ -875,8 +893,10 @@ class StoragePlanRewriter : public StmtExprMutator {
875893
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
876894
PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
877895
(total_bits + type_bits - 1) / type_bits);
896+
Map<String, ffi::Any> annotations = MakeAllocateAnnotations(e->alloc_var);
878897
e->alloc_nest.push_back(Allocate(e->alloc_var, e->elem_type, {alloc_size},
879-
const_true(), Evaluate(0)));
898+
const_true(), Evaluate(0),
899+
std::move(annotations)));
880900
if (info.defined()) {
881901
ICHECK_LE(total_bits, info->max_num_bits)
882902
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
@@ -1178,6 +1198,8 @@ class StoragePlanRewriter : public StmtExprMutator {
11781198
// Any buffers that is accessed at some point. DeclBuffer instances
11791199
// that do not appear in this list may be removed.
11801200
std::unordered_set<const BufferNode *> all_buffers_accessed_;
1201+
// Initial values for local variable buffers.
1202+
Map<Var, PrimExpr> local_var_init_map_;
11811203
// analyzer
11821204
arith::Analyzer analyzer_;
11831205
};
@@ -1795,7 +1817,7 @@ class VectorTypeRewriter : public StmtExprMutator {
17951817
DLOG(INFO) << "Allocate with " << new_buffer_var << " and "
17961818
<< info.new_element_dtype << " extents: " << extents;
17971819
return Allocate(new_buffer_var, info.new_element_dtype, extents,
1798-
op->condition, op->body);
1820+
op->condition, op->body, op->annotations);
17991821
}
18001822

18011823
Stmt VisitStmt_(const AllocateConstNode *op) final {
@@ -1941,10 +1963,16 @@ Pass StorageRewrite() {
19411963
// Require exactly same-dtype matching in smem reuse for Vulkan and WebGPU
19421964
reuse_require_exact_matched_dtype = true;
19431965
}
1966+
Map<Var, PrimExpr> local_var_init_map;
1967+
if (auto init_map =
1968+
f->attrs.GetAttr<Map<Var, PrimExpr>>(tl::attr::kLocalVarInit)) {
1969+
local_var_init_map = init_map.value();
1970+
}
19441971
auto *n = f.CopyOnWrite();
1945-
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), detect_inplace,
1946-
enable_reuse,
1947-
reuse_require_exact_matched_dtype);
1972+
StoragePlanRewriter plan_rewriter;
1973+
n->body = plan_rewriter.Rewrite(
1974+
std::move(n->body), detect_inplace, enable_reuse,
1975+
reuse_require_exact_matched_dtype, std::move(local_var_init_map));
19481976
// Parameters may not be rewritten, but internal allocations may.
19491977
// Vectorization of AllocateConst is currently disabled, as it has
19501978
// indexing issues for types that include padding (e.g. int8x3

0 commit comments

Comments
 (0)