Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
Expand Down Expand Up @@ -348,12 +349,13 @@ TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
/*!
* \brief Pass to checks if the size of the allocated vtcm memory satisfies the limit
*
* \param limit The limit to check.
* \param target The target whose VTCM limit should be used for any
* functions not already annotated with `tvm::attr::kTarget`.
*
* \returns The pass.
* \sa tvm::tir::CalculateAllocatedBytes
*/
TVM_DLL Pass VerifyVTCMLimit(const Integer& limit);
TVM_DLL Pass VerifyVTCMLimit(Optional<Target> target = NullOpt);

/*!
* \brief Statically check TIR code for out of bounds array access.
Expand Down
4 changes: 1 addition & 3 deletions src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1408,9 +1408,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
}
if (IsHexagonTask(task)) {
Target target = task->target;
const auto vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value().IntValue();
const auto& optimize =
tir::transform::Sequential({tir::transform::VerifyVTCMLimit(vtcm_capacity)});
const auto& optimize = tir::transform::Sequential({tir::transform::VerifyVTCMLimit(target)});
optimize(mod);
}
const auto& optimize =
Expand Down
11 changes: 1 addition & 10 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -544,22 +544,13 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
return TIRToRuntime(inputs, target_host);
}

int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) {
if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true);
if (target.defined() && target->kind->name == "hexagon") {
auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value;
if (value > 0) return value;
}
return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", Integer(0)).value()->value;
}

transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
transform::PassContext pass_ctx = transform::PassContext::Current();

Array<Pass> mixed_pass_list;

// VerifyVTCMLimit must occur before LowerVtcmAlloc
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(GetVTCMCapacity(target, pass_ctx)));
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());

Expand Down
39 changes: 29 additions & 10 deletions src/tir/analysis/calculate_allocated_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,39 @@ bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
return true;
}

int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) {
if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true);
if (target.defined() && target->kind->name == "hexagon") {
auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value;
if (value > 0) return value;
}
return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", Integer(0)).value()->value;
}

namespace transform {

Pass VerifyVTCMLimit(const Integer& limit) {
Pass VerifyVTCMLimit(Optional<Target> default_target) {
auto pass_func = [=](IRModule mod, PassContext ctx) {
for (auto kv : mod->functions) {
if (auto func = kv.second.as<PrimFunc>()) {
auto sizes = CalculateAllocatedBytes(func.value());
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been "
"exceeded(allocated: "
<< vtcm_allocated << ", limit: " << limit << ").\n"
<< "In function\n"
<< func;
if (auto opt = kv.second.as<PrimFunc>()) {
auto func = opt.value();

std::optional<int64_t> limit = std::nullopt;
if (auto func_target = func->GetAttr<Target>(tvm::attr::kTarget)) {
limit = GetVTCMCapacity(func_target.value(), ctx);
} else if (default_target) {
limit = GetVTCMCapacity(default_target.value(), ctx);
}

if (limit.has_value() && limit.value() > 0) {
auto sizes = CalculateAllocatedBytes(func);
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
if (vtcm_allocated.IntValue() > limit.value()) {
LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been exceeded "
<< "(allocated: " << vtcm_allocated << ", limit: " << limit.value() << ").\n"
<< "In function\n"
<< func;
}
}
}
}
Expand Down