[Unity][Transform] Memory planning for dynamic-shape func return (apache#16111)

This PR enhances the static block memory planning pass.
Prior to this PR, memory planning only handled allocations
that are not externally referenced. In dynamic-shape
settings, the remaining allocations (such as function
return values) are not fully static and may lead to memory
fragmentation.

This PR enhances the behavior so that, for such an
allocation, we first allocate a storage sized to the
allocation's estimated upper bound (when known), and then
allocate the tensor with its actual dynamic shape out of
that storage. This ensures static memory allocation and
avoids memory fragmentation.
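
As an illustration, here is a minimal sketch, condensed from the test added below (module and variable names are illustrative), of how a function opts into this planning through the `relax.memory_plan_dynamic_func_output` and `tir_var_upper_bound` attributes and how the pass is applied:

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T

@I.ir_module
class Module:
    @R.function
    def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
        n = T.int64()
        # Declare n <= 20 and opt in to planning of dynamic-shape outputs.
        R.func_attr({"tir_var_upper_bound": {"n": 20},
                     "relax.force_pure": True,
                     "relax.memory_plan_dynamic_func_output": True})
        alloc: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(
            R.shape([n]), R.dtype("float32"), R.prim_value(0))
        return alloc

# The `builtin.alloc_tensor` above is expected to be rewritten into a static
# `memory.alloc_storage` of 80 bytes (upper bound 20 * 4 bytes for float32)
# plus a `memory.alloc_tensor` that keeps the symbolic shape (n,).
mod = relax.transform.StaticPlanBlockMemory()(Module)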
MasterJH5574 authored Jan 15, 2024
1 parent 98d5153 commit cf14edd
Showing 2 changed files with 190 additions and 55 deletions.
183 changes: 128 additions & 55 deletions src/relax/transform/static_plan_block_memory.cc
@@ -296,6 +296,82 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
std::vector<const BindingBlockNode*> block_stack_;
};

/*!
* \brief Set the upper bound of the TIR variables that appear in
* the input function signature in the analyzer.
* \param func The function to be analyzed.
* \param ana The analyzer which contains the TIR var upper bounds.
*/
void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
// Use the attribute-annotated TIR var upper bounds as the TIR var values for
// memory planning.
// NOTE: we only apply the annotated upper bounds to the TIR variables that
// appear in the **function signature**.
Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
.value_or(Map<ObjectRef, ObjectRef>());
std::unordered_map<String, IntImm> var_upper_bound_attr;
// We manually check the value type to ensure the values are all positive IntImm.
for (auto it : var_upper_bound_attr_raw) {
const auto* key = it.first.as<StringObj>();
const auto* value = it.second.as<IntImmNode>();
CHECK(key != nullptr)
<< "The entry key of attr `tir_var_upper_bound` should be string. However "
<< it.first->GetTypeKey() << " is got.";
CHECK(value != nullptr)
<< "The entry value of attr `tir_var_upper_bound` should be integer. However "
<< it.second->GetTypeKey() << " is got.";
CHECK_GT(value->value, 0)
<< "The entry value of attr `tir_var_upper_bound` should be a positive integer, while "
<< value->value << " is got.";
var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
}
Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
for (const tir::Var& tir_var : var_in_signature) {
auto it = var_upper_bound_attr.find(tir_var->name_hint);
if (it != var_upper_bound_attr.end()) {
ana->Bind(tir_var,
tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
tvm::IntImm(DataType::Int(64), (*it).second->value + 1)));
}
}
}
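
For reference, a rough Python equivalent of this binding, sketched with the `tvm.arith.Analyzer` Python bindings (`bind`, `const_int_bound`), shows why the range extent is `upper_bound + 1`:

import tvm
from tvm import tir

ana = tvm.arith.Analyzer()
n = tir.Var("n", "int64")
# Range::FromMinExtent(0, 20 + 1) is the half-open range [0, 21), so the
# largest value the analyzer assumes for n is the annotated bound 20.
ana.bind(n, tvm.ir.Range.from_min_extent(
    tvm.tir.IntImm("int64", 0), tvm.tir.IntImm("int64", 21)))
print(ana.const_int_bound(n).max_value)  # expected: 20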

/*!
* \brief Use the upper bounds of TIR vars to compute the upper
* bound of a given shape.
* \param shape The input shape to be computed.
* \param ana The arithmetic analyzer that contains the upper bounds
* of TIR variables
* \return The upper-bounded shape. When a dimension's upper bound
* cannot be determined, we keep the dimension unchanged.
*/
Array<PrimExpr> GetUpperBoundShape(Array<PrimExpr> shape, arith::Analyzer* ana) {
// Use the upper bounds of TIR vars as their values.
Array<PrimExpr> upper_bounded_shape;
upper_bounded_shape.reserve(shape.size());
for (const PrimExpr& dim_len : shape) {
int64_t max_bound = ana->const_int_bound(dim_len)->max_value;
if (max_bound == std::numeric_limits<int64_t>::max()) {
upper_bounded_shape.push_back(dim_len);
} else {
upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
}
}
return upper_bounded_shape;
}
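
A small Python sketch of the intended behavior (same analyzer setup as above; `m` is deliberately left without an upper bound):

import tvm
from tvm import tir

ana = tvm.arith.Analyzer()
n = tir.Var("n", "int64")
m = tir.Var("m", "int64")  # no upper bound annotated
ana.bind(n, tvm.ir.Range.from_min_extent(
    tvm.tir.IntImm("int64", 0), tvm.tir.IntImm("int64", 21)))

int64_max = (1 << 63) - 1  # std::numeric_limits<int64_t>::max()
shape = [n, tvm.tir.IntImm("int64", 128), m]
upper_bounded = []
for dim in shape:
    max_bound = ana.const_int_bound(dim).max_value
    if max_bound == int64_max:
        upper_bounded.append(dim)  # bound unknown: keep the dim unchanged
    else:
        upper_bounded.append(tvm.tir.IntImm("int64", max_bound))
# upper_bounded is expected to be [20, 128, m].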

/*! \brief Check if a shape is static (i.e., has no TIR variable). */
bool IsStaticShape(Array<PrimExpr> shape) {
for (const PrimExpr& dim : shape) {
const auto* int_len = dim.as<IntImmNode>();
if (!int_len) {
return false;
}
}
return true;
}

/*!
* \brief The visitor class for storage token initialization.
* \details It goes through the entire function to get the storage tokens
@@ -330,40 +406,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}

void VisitExpr_(const FunctionNode* func) final {
// Use the attribute-annotated TIR var upper bounds as the TIR var values for
// memory planning.
// NOTE: we only apply the annotated upper bounds to the TIR variables that
// appear in the **function signature**.
Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
.value_or(Map<ObjectRef, ObjectRef>());
std::unordered_map<String, IntImm> var_upper_bound_attr;
// We manually check the value type to ensure the values are all positive IntImm.
for (auto it : var_upper_bound_attr_raw) {
const auto* key = it.first.as<StringObj>();
const auto* value = it.second.as<IntImmNode>();
CHECK(key != nullptr)
<< "The entry key of attr `tir_var_upper_bound` should be string. However "
<< it.first->GetTypeKey() << " is got.";
CHECK(value != nullptr)
<< "The entry value of attr `tir_var_upper_bound` should be integer. However "
<< it.second->GetTypeKey() << " is got.";
CHECK_GT(value->value, 0)
<< "The entry value of attr `tir_var_upper_bound` should be a positive integer, while "
<< value->value << " is got.";
var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
}
Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(GetRef<Function>(func)));
var_upper_bound_.clear();
for (const tir::Var& tir_var : var_in_signature) {
auto it = var_upper_bound_attr.find(tir_var->name_hint);
if (it != var_upper_bound_attr.end()) {
ana_.Bind(tir_var, tvm::Range::FromMinExtent(
tvm::IntImm(DataType::Int(64), 0),
tvm::IntImm(DataType::Int(64), (*it).second->value + 1)));
}
}

// Set the upper bound of TIR variables in the analyzer.
SetTIRVarUpperBound(GetRef<Function>(func), &ana_);
// Recurse into the function to get its tokens.
Tokens body_tokens = GetTokens(func->body);
// Discard the tokens used by the function return value, as they are externally referenced.
Expand Down Expand Up @@ -457,32 +501,20 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
// - the tensor has known dtype;
// - no storage token was created for this call before.
const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
const auto* shape = sinfo->shape.as<ShapeExprNode>();
ICHECK_NOTNULL(sinfo);
const auto* shape = sinfo->shape.as<ShapeExprNode>();
ICHECK_NOTNULL(shape);
ICHECK(!sinfo->IsUnknownDtype());
ICHECK(sinfo->dtype == Downcast<DataTypeImm>(call->args[1])->value);
ICHECK(!token_map_.count(call));

// Use the upper bounds of TIR vars as their values.
Array<PrimExpr> upper_bounded_shape;
upper_bounded_shape.reserve(shape->values.size());
for (const PrimExpr& dim_len : shape->values) {
int64_t max_bound = ana_.const_int_bound(dim_len)->max_value;
if (max_bound == std::numeric_limits<int64_t>::max()) {
upper_bounded_shape.push_back(dim_len);
} else {
upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
}
}
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);

// No support for TIR vars that are not bounded.
for (const PrimExpr& dim_len : upper_bounded_shape) {
const auto* int_len = dim_len.as<IntImmNode>();
if (!int_len) {
token_map_[call] = Tokens();
return Tokens();
}
if (!IsStaticShape(upper_bounded_shape)) {
token_map_[call] = Tokens();
return Tokens();
}

// Create and set token.
@@ -558,8 +590,6 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
* a PrimFunc inside the IRModule.
*/
const IRModule& ctx_mod_;
/*! \brief The mapping from TIR variables to their respective upper bound values. */
std::unordered_map<tir::Var, IntImm, ObjectPtrHash, ObjectPtrEqual> var_upper_bound_;
/*! \brief The mapping from each token to the binding block where it is created. */
std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> token2block_;
/*! \brief The mapping from each token to the Exprs that are using this token. */
@@ -729,8 +759,17 @@ class StorageAllocationRewriter : public ExprMutator {
if (func_ == nullptr) {
continue;
}
constexpr static const char* plan_dyn_attr_ = "relax.memory_plan_dynamic_func_output";
plan_dynamic_output_ = static_cast<bool>(
func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value);
if (plan_dynamic_output_) {
SetTIRVarUpperBound(GetRef<Function>(func_), &ana_);
}
token2storage_var_.clear();
Function func = Downcast<Function>(this->VisitExpr_(func_));
if (plan_dynamic_output_) {
func = WithoutAttr(func, plan_dyn_attr_);
}
builder_->UpdateFunction(gv, func);
}
return builder_->GetContextIRModule();
@@ -740,8 +779,13 @@
using ExprMutator::VisitExpr_;

Expr VisitExpr_(const CallNode* call) final {
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
auto it = alloc_tensor2token_.find(call);
if (it != alloc_tensor2token_.end()) {
// Case 1. This `alloc_tensor` is planned for memory reuse.
ICHECK_EQ(call->op, alloc_tensor_op);
const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
ICHECK_NOTNULL(sinfo);
ICHECK_NOTNULL(sinfo->shape.as<ShapeExprNode>());
@@ -753,7 +797,6 @@
Var storage_var{nullptr};
auto it_token = token2storage_var_.find(token.get());
if (it_token == token2storage_var_.end()) {
static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
PrimValue virtual_device_index = runtime_device_index;
std::string storage_scope = "global";
@@ -769,16 +812,46 @@
}

// And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`.
static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
PrimValue offset = PrimValue::Int64(0);
DataType dtype = sinfo->dtype;
return Call(mem_alloc_tensor, {storage_var, offset, sinfo->shape.value(), DataTypeImm(dtype)},
Attrs());
} else if (plan_dynamic_output_ && call->op == alloc_tensor_op) {
// Case 2. For an `alloc_tensor` that is not planned for memory reuse,
// we still want to allocate **static** memory for the tensor.
// So when the tensor shape is dynamic but has an estimated upper
// bound, we allocate a storage of the upper-bound size and then
// allocate the tensor out of it with the actual symbolic shape.

const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
ICHECK_NOTNULL(sinfo);
const auto* shape = sinfo->shape.as<ShapeExprNode>();
ICHECK_NOTNULL(shape);
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);
if (!IsStaticShape(shape->values) && IsStaticShape(upper_bounded_shape)) {
ICHECK(!sinfo->IsUnknownDtype());
ICHECK_EQ(sinfo->dtype, Downcast<DataTypeImm>(call->args[1])->value);
StorageToken token(upper_bounded_shape, sinfo->dtype);
Call alloc_storage(mem_alloc_storage,
{/*size=*/ShapeExpr({tvm::IntImm(DataType::Int(64), token->bytes)}),
/*virtual_device_index=*/Downcast<PrimValue>(call->args[2]),
/*storage_scope=*/StringImm("global"), //
/*dtype=*/DataTypeImm(token->dtype)});
Var storage = builder_->Emit(alloc_storage, "storage");
return Call(mem_alloc_tensor, {storage, //
/*offset=*/PrimValue::Int64(0),
/*shape=*/GetRef<ShapeExpr>(shape), //
/*dtype=*/DataTypeImm(sinfo->dtype)});
}
}

return ExprMutator::VisitExpr_(call);
}

/*! \brief The arithmetic analyzer. */
arith::Analyzer ana_;
/*! \brief A boolean indicating whether to plan dynamic-shape function output tensors. */
bool plan_dynamic_output_;
/*!
* \brief The mapping from each memory-reusable `builtin.alloc_tensor` to
its corresponding underlying storage token that it is using.
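Condensing the effect of Case 2 above (the test added below verifies the full rewrite): for a float32 tensor of shape (n,) with n bounded by 20, the storage is sized to the upper bound while the tensor keeps its symbolic shape. A rough sketch of the sizing arithmetic, with the expected rewritten calls noted in comments:

# Illustrative sizing only; mirrors the 80-byte storage in the test below.
upper_bound = 20          # from the `tir_var_upper_bound` attribute
itemsize = 4              # bytes per float32 element
storage_bytes = upper_bound * itemsize
assert storage_bytes == 80
# The rewritten IR then allocates, roughly:
#   storage = R.memory.alloc_storage(R.shape([80]), ..., R.dtype("float32"))
#   alloc   = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), R.dtype("float32"))
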
62 changes: 62 additions & 0 deletions tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1109,6 +1109,68 @@ def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
tvm.ir.assert_structural_equal(mod, Expected)


def test_call_tir_dyn_plan_dynamic_func_output():
# fmt: off
@I.ir_module
class Module:
@T.prim_func
def tir_full(var_full: T.handle, n: T.int64):
T.evaluate(0)

@T.prim_func
def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
T.evaluate(0)

@R.function
def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
n = T.int64()
R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": True, "relax.memory_plan_dynamic_func_output": True})
cls = Module
alloc: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
_: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
full: R.Tensor((n,), dtype="float32") = alloc
alloc1: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
_1: R.Tuple = cls.tir_exp(full, alloc1)
lv2: R.Tensor((n,), dtype="float32") = alloc1
alloc2: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
_2: R.Tuple = cls.tir_exp(lv2, alloc2)
lv3: R.Tensor((n,), dtype="float32") = alloc2
return lv3

@I.ir_module
class Expected:
@T.prim_func
def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
T.evaluate(0)

@T.prim_func
def tir_full(var_full: T.handle, n: T.int64):
T.evaluate(0)

@R.function
def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
n = T.int64()
R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": True})
cls = Expected
storage: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
alloc: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), R.dtype("float32"))
_: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
full: R.Tensor((n,), dtype="float32") = alloc
storage1: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
alloc1: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n]), R.dtype("float32"))
_1: R.Tuple = cls.tir_exp(full, alloc1)
lv2: R.Tensor((n,), dtype="float32") = alloc1
storage2: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
alloc2: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n]), R.dtype("float32"))
_2: R.Tuple = cls.tir_exp(lv2, alloc2)
lv3: R.Tensor((n,), dtype="float32") = alloc2
return lv3
# fmt: on

mod = relax.transform.StaticPlanBlockMemory()(Module)
tvm.ir.assert_structural_equal(mod, Expected)


def test_function_independence():
# fmt: off
@tvm.script.ir_module
