|  | 
| 25 | 25 | #include <tvm/arith/analyzer.h> | 
| 26 | 26 | #include <tvm/ffi/function.h> | 
| 27 | 27 | #include <tvm/ffi/reflection/registry.h> | 
|  | 28 | +#include <tvm/ir/attrs.h> | 
| 28 | 29 | #include <tvm/ir/type.h> | 
| 29 | 30 | #include <tvm/target/target_info.h> | 
| 30 | 31 | #include <tvm/tir/analysis.h> | 
| @@ -468,8 +469,10 @@ class StoragePlanRewriter : public StmtExprMutator { | 
| 468 | 469 |   using AllocEntry = LinearAccessPatternFinder::AllocEntry; | 
| 469 | 470 | 
 | 
| 470 | 471 |   Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse, | 
| 471 |  | -               bool reuse_require_exact_matched_dtype) { | 
|  | 472 | +               bool reuse_require_exact_matched_dtype, | 
|  | 473 | +               Map<Var, PrimExpr> local_var_init_map = {}) { | 
| 472 | 474 |     detect_inplace_ = detect_inplace; | 
|  | 475 | +    local_var_init_map_ = std::move(local_var_init_map); | 
| 473 | 476 |     // plan the rewrite | 
| 474 | 477 |     LinearAccessPatternFinder finder; | 
| 475 | 478 |     finder(stmt); | 
| @@ -694,6 +697,17 @@ class StoragePlanRewriter : public StmtExprMutator { | 
| 694 | 697 |     } | 
| 695 | 698 |     return body; | 
| 696 | 699 |   } | 
|  | 700 | +  Map<String, ffi::Any> MakeAllocateAnnotations(const Var &buffer_var) const { | 
|  | 701 | +    Map<String, ffi::Any> annotations; | 
|  | 702 | +    if (local_var_init_map_.defined()) { | 
|  | 703 | +      auto it = local_var_init_map_.find(buffer_var); | 
|  | 704 | +      if (it != local_var_init_map_.end()) { | 
|  | 705 | +        const PrimExpr &init = (*it).second; | 
|  | 706 | +        annotations.Set(tl::attr::kLocalVarInit, init); | 
|  | 707 | +      } | 
|  | 708 | +    } | 
|  | 709 | +    return annotations; | 
|  | 710 | +  } | 
| 697 | 711 |   // Remap the index | 
| 698 | 712 |   PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry *e) { | 
| 699 | 713 |     if (e->bits_offset == 0) | 
| @@ -766,9 +780,11 @@ class StoragePlanRewriter : public StmtExprMutator { | 
| 766 | 780 | 
 | 
| 767 | 781 |         if (all_allocs_identical) { | 
| 768 | 782 |           // simply use the original allocation. | 
| 769 |  | -          e->alloc_nest.push_back( | 
| 770 |  | -              Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents, | 
| 771 |  | -                       e->allocs[0]->condition, Evaluate(0))); | 
|  | 783 | +          Map<String, ffi::Any> annotations = | 
|  | 784 | +              MakeAllocateAnnotations(e->alloc_var); | 
|  | 785 | +          e->alloc_nest.push_back(Allocate( | 
|  | 786 | +              e->alloc_var, alloc_type, e->allocs[0]->extents, | 
|  | 787 | +              e->allocs[0]->condition, Evaluate(0), std::move(annotations))); | 
| 772 | 788 |           if (auto ptr = e->allocs[0]->body.as<DeclBufferNode>()) { | 
| 773 | 789 |             e->alloc_nest.push_back(DeclBuffer( | 
| 774 | 790 |                 RemapBuffer(ptr->buffer, e->alloc_var), Evaluate(0))); | 
| @@ -824,9 +840,11 @@ class StoragePlanRewriter : public StmtExprMutator { | 
| 824 | 840 |             combo_size = combo_size + make_const(DataType::Int(32), 1); | 
| 825 | 841 |           } | 
| 826 | 842 |           combo_size = analyzer_.Simplify(combo_size); | 
| 827 |  | -          e->alloc_nest.push_back(Allocate(e->alloc_var, alloc_type, | 
| 828 |  | -                                           {combo_size}, const_true(), | 
| 829 |  | -                                           Evaluate(0))); | 
|  | 843 | +          Map<String, ffi::Any> annotations = | 
|  | 844 | +              MakeAllocateAnnotations(e->alloc_var); | 
|  | 845 | +          e->alloc_nest.push_back( | 
|  | 846 | +              Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), | 
|  | 847 | +                       Evaluate(0), std::move(annotations))); | 
| 830 | 848 |           if (IsSpecialTaggedMemory(e->scope)) { | 
| 831 | 849 |             MemoryInfo info = GetMemoryInfo(e->scope.to_string()); | 
| 832 | 850 |             if (info.defined()) { | 
| @@ -875,8 +893,10 @@ class StoragePlanRewriter : public StmtExprMutator { | 
| 875 | 893 |     uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes(); | 
| 876 | 894 |     PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(), | 
| 877 | 895 |                                      (total_bits + type_bits - 1) / type_bits); | 
|  | 896 | +    Map<String, ffi::Any> annotations = MakeAllocateAnnotations(e->alloc_var); | 
| 878 | 897 |     e->alloc_nest.push_back(Allocate(e->alloc_var, e->elem_type, {alloc_size}, | 
| 879 |  | -                                     const_true(), Evaluate(0))); | 
|  | 898 | +                                     const_true(), Evaluate(0), | 
|  | 899 | +                                     std::move(annotations))); | 
| 880 | 900 |     if (info.defined()) { | 
| 881 | 901 |       ICHECK_LE(total_bits, info->max_num_bits) | 
| 882 | 902 |           << "Allocation exceed bound of memory tag " << e->scope.to_string(); | 
| @@ -1178,6 +1198,8 @@ class StoragePlanRewriter : public StmtExprMutator { | 
| 1178 | 1198 |   // Any buffers that is accessed at some point.  DeclBuffer instances | 
| 1179 | 1199 |   // that do not appear in this list may be removed. | 
| 1180 | 1200 |   std::unordered_set<const BufferNode *> all_buffers_accessed_; | 
|  | 1201 | +  // Initial values for local variable buffers. | 
|  | 1202 | +  Map<Var, PrimExpr> local_var_init_map_; | 
| 1181 | 1203 |   // analyzer | 
| 1182 | 1204 |   arith::Analyzer analyzer_; | 
| 1183 | 1205 | }; | 
| @@ -1795,7 +1817,7 @@ class VectorTypeRewriter : public StmtExprMutator { | 
| 1795 | 1817 |     DLOG(INFO) << "Allocate with " << new_buffer_var << " and " | 
| 1796 | 1818 |                << info.new_element_dtype << " extents: " << extents; | 
| 1797 | 1819 |     return Allocate(new_buffer_var, info.new_element_dtype, extents, | 
| 1798 |  | -                    op->condition, op->body); | 
|  | 1820 | +                    op->condition, op->body, op->annotations); | 
| 1799 | 1821 |   } | 
| 1800 | 1822 | 
 | 
| 1801 | 1823 |   Stmt VisitStmt_(const AllocateConstNode *op) final { | 
| @@ -1941,10 +1963,16 @@ Pass StorageRewrite() { | 
| 1941 | 1963 |       // Require exactly same-dtype matching in smem reuse for Vulkan and WebGPU | 
| 1942 | 1964 |       reuse_require_exact_matched_dtype = true; | 
| 1943 | 1965 |     } | 
|  | 1966 | +    Map<Var, PrimExpr> local_var_init_map; | 
|  | 1967 | +    if (auto init_map = | 
|  | 1968 | +            f->attrs.GetAttr<Map<Var, PrimExpr>>(tl::attr::kLocalVarInit)) { | 
|  | 1969 | +      local_var_init_map = init_map.value(); | 
|  | 1970 | +    } | 
| 1944 | 1971 |     auto *n = f.CopyOnWrite(); | 
| 1945 |  | -    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), detect_inplace, | 
| 1946 |  | -                                            enable_reuse, | 
| 1947 |  | -                                            reuse_require_exact_matched_dtype); | 
|  | 1972 | +    StoragePlanRewriter plan_rewriter; | 
|  | 1973 | +    n->body = plan_rewriter.Rewrite( | 
|  | 1974 | +        std::move(n->body), detect_inplace, enable_reuse, | 
|  | 1975 | +        reuse_require_exact_matched_dtype, std::move(local_var_init_map)); | 
| 1948 | 1976 |     // Parameters may not be rewritten, but internal allocations may. | 
| 1949 | 1977 |     // Vectorization of AllocateConst is currently disabled, as it has | 
| 1950 | 1978 |     // indexing issues for types that include padding (e.g. int8x3 | 
|  | 
0 commit comments