diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 4ac41d92c001..e8e65c6eccd4 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -296,6 +296,82 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
   std::vector<const BindingBlockNode*> block_stack_;
 };
 
+/*!
+ * \brief Set the upper bound of the TIR variables that appear in
+ * the input function signature in the analyzer.
+ * \param func The function to be analyzed.
+ * \param ana The analyzer which contains the TIR var upper bounds.
+ */
+void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
+  // Use the attribute-annotated TIR var upper bounds as the TIR var values for
+  // memory planning.
+  // NOTE: we only apply the annotated upper bounds to the TIR variables that
+  // appear in the **function signature**.
+  Map<String, ObjectRef> var_upper_bound_attr_raw =
+      func->GetAttr<Map<String, ObjectRef>>("tir_var_upper_bound")
+          .value_or(Map<String, ObjectRef>());
+  std::unordered_map<String, IntImm> var_upper_bound_attr;
+  // We manually check the value type to ensure the values are all positive IntImm.
+  for (auto it : var_upper_bound_attr_raw) {
+    const auto* key = it.first.as<StringObj>();
+    const auto* value = it.second.as<IntImmNode>();
+    CHECK(key != nullptr)
+        << "The entry key of attr `tir_var_upper_bound` should be string. However "
+        << it.first->GetTypeKey() << " is got.";
+    CHECK(value != nullptr)
+        << "The entry value of attr `tir_var_upper_bound` should be integer. However "
+        << it.second->GetTypeKey() << " is got.";
+    CHECK_GT(value->value, 0)
+        << "The entry value of attr `tir_var_upper_bound` should be a positive integer, while "
+        << value->value << " is got.";
+    var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
+  }
+  Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
+  for (const tir::Var& tir_var : var_in_signature) {
+    auto it = var_upper_bound_attr.find(tir_var->name_hint);
+    if (it != var_upper_bound_attr.end()) {
+      ana->Bind(tir_var,
+                tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
+                                          tvm::IntImm(DataType::Int(64), (*it).second->value + 1)));
+    }
+  }
+}
+
+/*!
+ * \brief Use the upper bounds of TIR vars to compute the upper
+ * bound of a given shape.
+ * \param shape The input shape to be computed.
+ * \param ana The arithmetic analyzer that contains the upper bounds
+ * of TIR variables.
+ * \return The upper-bounded shape. When a dimension's upper bound
+ * cannot be determined, we keep the dimension unchanged.
+ */
+Array<PrimExpr> GetUpperBoundShape(Array<PrimExpr> shape, arith::Analyzer* ana) {
+  // Use the upper bounds of TIR vars as their values.
+  Array<PrimExpr> upper_bounded_shape;
+  upper_bounded_shape.reserve(shape.size());
+  for (const PrimExpr& dim_len : shape) {
+    int64_t max_bound = ana->const_int_bound(dim_len)->max_value;
+    if (max_bound == std::numeric_limits<int64_t>::max()) {
+      upper_bounded_shape.push_back(dim_len);
+    } else {
+      upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
+    }
+  }
+  return upper_bounded_shape;
+}
+
+/*! \brief Check if a shape is static (i.e., has no TIR variable). */
+bool IsStaticShape(Array<PrimExpr> shape) {
+  for (const PrimExpr& dim : shape) {
+    const auto* int_len = dim.as<IntImmNode>();
+    if (!int_len) {
+      return false;
+    }
+  }
+  return true;
+}
+
 /*!
  * \brief The visitor class for storage token initialization.
  * \details It goes through the entire function to get the storage tokens
@@ -330,40 +406,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
   explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
 
   void VisitExpr_(const FunctionNode* func) final {
-    // Use the attribute-annotated TIR var upper bounds as the TIR var values for
-    // memory planning.
-    // NOTE: we only apply the annotated upper bounds to the TIR variables that
-    // appear in the **function signature**.
-    Map<String, ObjectRef> var_upper_bound_attr_raw =
-        func->GetAttr<Map<String, ObjectRef>>("tir_var_upper_bound")
-            .value_or(Map<String, ObjectRef>());
-    std::unordered_map<String, IntImm> var_upper_bound_attr;
-    // We manually check the value type to ensure the values are all positive IntImm.
-    for (auto it : var_upper_bound_attr_raw) {
-      const auto* key = it.first.as<StringObj>();
-      const auto* value = it.second.as<IntImmNode>();
-      CHECK(key != nullptr)
-          << "The entry key of attr `tir_var_upper_bound` should be string. However "
-          << it.first->GetTypeKey() << " is got.";
-      CHECK(value != nullptr)
-          << "The entry value of attr `tir_var_upper_bound` should be integer. However "
-          << it.second->GetTypeKey() << " is got.";
-      CHECK_GT(value->value, 0)
-          << "The entry value of attr `tir_var_upper_bound` should be a positive integer, while "
-          << value->value << " is got.";
-      var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
-    }
-    Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(GetRef<Function>(func)));
-    var_upper_bound_.clear();
-    for (const tir::Var& tir_var : var_in_signature) {
-      auto it = var_upper_bound_attr.find(tir_var->name_hint);
-      if (it != var_upper_bound_attr.end()) {
-        ana_.Bind(tir_var, tvm::Range::FromMinExtent(
-                               tvm::IntImm(DataType::Int(64), 0),
-                               tvm::IntImm(DataType::Int(64), (*it).second->value + 1)));
-      }
-    }
-
+    // Set the upper bound of TIR variables in the analyzer.
+    SetTIRVarUpperBound(GetRef<Function>(func), &ana_);
     // Recurse into the function to get its tokens.
     Tokens body_tokens = GetTokens(func->body);
     // Discard the tokens used by the function return value, as they are external referenced.
@@ -457,32 +501,20 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
     // - the tensor has known dtype;
     // - no storage token was created for this call before.
     const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
-    const auto* shape = sinfo->shape.as<ShapeExprNode>();
     ICHECK_NOTNULL(sinfo);
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
     ICHECK_NOTNULL(shape);
     ICHECK(!sinfo->IsUnknownDtype());
     ICHECK(sinfo->dtype == Downcast<DataTypeImm>(call->args[1])->value);
     ICHECK(!token_map_.count(call));
 
     // Use the upper bounds of TIR vars as their values.
-    Array<PrimExpr> upper_bounded_shape;
-    upper_bounded_shape.reserve(shape->values.size());
-    for (const PrimExpr& dim_len : shape->values) {
-      int64_t max_bound = ana_.const_int_bound(dim_len)->max_value;
-      if (max_bound == std::numeric_limits<int64_t>::max()) {
-        upper_bounded_shape.push_back(dim_len);
-      } else {
-        upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
-      }
-    }
+    Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);
 
     // No support for TIR vars that are not bounded.
-    for (const PrimExpr& dim_len : upper_bounded_shape) {
-      const auto* int_len = dim_len.as<IntImmNode>();
-      if (!int_len) {
-        token_map_[call] = Tokens();
-        return Tokens();
-      }
+    if (!IsStaticShape(upper_bounded_shape)) {
+      token_map_[call] = Tokens();
+      return Tokens();
     }
 
     // Create and set token.
@@ -558,8 +590,6 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
    * a PrimFunc inside the IRModule.
    */
   const IRModule& ctx_mod_;
-  /*! \brief The mapping from TIR variables to their respective upper bound values. */
-  std::unordered_map<String, IntImm> var_upper_bound_;
   /*! \brief The mapping from each token to the binding block where it is created. */
   std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> token2block_;
   /*! \brief The mapping from each token to the Exprs that are using this token. */
@@ -729,8 +759,17 @@ class StorageAllocationRewriter : public ExprMutator {
       if (func_ == nullptr) {
         continue;
       }
+      constexpr static const char* plan_dyn_attr_ = "relax.memory_plan_dynamic_func_output";
+      plan_dynamic_output_ = static_cast<bool>(
+          func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value);
+      if (plan_dynamic_output_) {
+        SetTIRVarUpperBound(GetRef<Function>(func_), &ana_);
+      }
       token2storage_var_.clear();
       Function func = Downcast<Function>(this->VisitExpr_(func_));
+      if (plan_dynamic_output_) {
+        func = WithoutAttr(func, plan_dyn_attr_);
+      }
       builder_->UpdateFunction(gv, func);
     }
     return builder_->GetContextIRModule();
@@ -740,8 +779,13 @@
   using ExprMutator::VisitExpr_;
 
   Expr VisitExpr_(const CallNode* call) final {
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
+    static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
     auto it = alloc_tensor2token_.find(call);
     if (it != alloc_tensor2token_.end()) {
+      // Case 1. This `alloc_tensor` is planned for memory reuse.
+      ICHECK_EQ(call->op, alloc_tensor_op);
       const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
       ICHECK_NOTNULL(sinfo);
       ICHECK_NOTNULL(sinfo->shape.as<ShapeExprNode>());
@@ -753,7 +797,6 @@
       Var storage_var{nullptr};
       auto it_token = token2storage_var_.find(token.get());
       if (it_token == token2storage_var_.end()) {
-        static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
         ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
         PrimValue virtual_device_index = runtime_device_index;
         std::string storage_scope = "global";
@@ -769,16 +812,46 @@
       }
 
       // And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`.
-      static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
       PrimValue offset = PrimValue::Int64(0);
       DataType dtype = sinfo->dtype;
       return Call(mem_alloc_tensor, {storage_var, offset, sinfo->shape.value(), DataTypeImm(dtype)},
                   Attrs());
+    } else if (plan_dynamic_output_ && call->op == alloc_tensor_op) {
+      // Case 2. For an `alloc_tensor` that is not planned for memory reuse,
+      // we would still like to allocate **static** memory for the tensor.
+      // So in case the tensor shape is dynamic but has an upper bound
+      // estimation, we allocate a storage to its upper bound size, and
+      // allocate a tensor out from it with the actual symbolic shape.
+
+      const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
+      ICHECK_NOTNULL(sinfo);
+      const auto* shape = sinfo->shape.as<ShapeExprNode>();
+      ICHECK_NOTNULL(shape);
+      Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);
+      if (!IsStaticShape(shape->values) && IsStaticShape(upper_bounded_shape)) {
+        ICHECK(!sinfo->IsUnknownDtype());
+        ICHECK_EQ(sinfo->dtype, Downcast<DataTypeImm>(call->args[1])->value);
+        StorageToken token(upper_bounded_shape, sinfo->dtype);
+        Call alloc_storage(mem_alloc_storage,
+                           {/*size=*/ShapeExpr({tvm::IntImm(DataType::Int(64), token->bytes)}),
+                            /*virtual_device_index=*/Downcast<PrimValue>(call->args[2]),
+                            /*storage_scope=*/StringImm("global"),  //
+                            /*dtype=*/DataTypeImm(token->dtype)});
+        Var storage = builder_->Emit(alloc_storage, "storage");
+        return Call(mem_alloc_tensor, {storage,  //
+                                       /*offset=*/PrimValue::Int64(0),
+                                       /*shape=*/GetRef<ShapeExpr>(shape),  //
+                                       /*dtype=*/DataTypeImm(sinfo->dtype)});
+      }
     }
 
     return ExprMutator::VisitExpr_(call);
   }
 
+  /*! \brief The arithmetic analyzer. */
+  arith::Analyzer ana_;
+  /*! \brief A boolean indicating whether to plan dynamic-shape function output tensors. */
+  bool plan_dynamic_output_;
   /*!
    * \brief The mapping from each memory-reusable `builtin.alloc_tensor` to its
    * corresponding underlying storage token that it is using.
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 9d1fe4fd40a4..783f18ee9806 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1109,6 +1109,68 @@ def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_call_tir_dyn_plan_dynamic_func_output():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def tir_full(var_full: T.handle, n: T.int64):
+            T.evaluate(0)
+
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": True, "relax.memory_plan_dynamic_func_output": True})
+            cls = Module
+            alloc: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+            _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
+            full: R.Tensor((n,), dtype="float32") = alloc
+            alloc1: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+            _1: R.Tuple = cls.tir_exp(full, alloc1)
+            lv2: R.Tensor((n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+            _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+            lv3: R.Tensor((n,), dtype="float32") = alloc2
+            return lv3
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def tir_full(var_full: T.handle, n: T.int64):
+            T.evaluate(0)
+
+        @R.function
+        def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": True})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), R.dtype("float32"))
+            _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
+            full: R.Tensor((n,), dtype="float32") = alloc
+            storage1: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n]), R.dtype("float32"))
+            _1: R.Tuple = cls.tir_exp(full, alloc1)
+            lv2: R.Tensor((n,), dtype="float32") = alloc1
+            storage2: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc2: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n]), R.dtype("float32"))
+            _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+            lv3: R.Tensor((n,), dtype="float32") = alloc2
+            return lv3
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_function_independence():
     # fmt: off
     @tvm.script.ir_module