Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[AOT] Add CreateFunctionMetadata analysis pass (apache#13095)
Browse files Browse the repository at this point in the history
AOT requires FunctionInfo to be defined for all the functions
in the module. This stores information on how much memory the
functions use. This commit adds a separate analysis pass to
create all the FunctionInfos + some tests for the new pass.
  • Loading branch information
mbaret authored and xinetzone committed Nov 25, 2022
1 parent e4a60c7 commit 0b54cd2
Show file tree
Hide file tree
Showing 7 changed files with 564 additions and 3 deletions.
7 changes: 7 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ TVM_DLL Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
*/
TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);

/*!
 * \brief Calculate the size in bytes of the constants needed by the TIR allocates inside the
 *        TIR PrimFunc.
 * \param func The TIR PrimFunc for which the constants size is to be calculated.
 * \param constant_byte_alignment The byte alignment required for each constant allocated.
 * \return The constants size in bytes.
 */
TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& constant_byte_alignment);

/*!
* \brief Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc
* \param func The TIR PrimFunc for which the workspace size to be calculated
Expand Down
52 changes: 52 additions & 0 deletions python/tvm/ir/memory_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.runtime import NDArray
from . import _ffi_api


Expand Down Expand Up @@ -101,6 +102,34 @@ def __init__(
)


@register_object("ir.ConstantInfo")
class ConstantInfo(Object):
    """Python handle for an ``ir.ConstantInfo`` object describing one constant.

    Parameters
    ----------
    name_hint : str
        Name of the constant.

    byte_offset : int
        The byte_offset of the constant.

    data : NDArray
        The data of the constant.
    """

    def __init__(
        self,
        name_hint: str,
        byte_offset: int,
        data: NDArray,
    ):
        # Look the FFI constructor up once, then hand it the fields in declaration order.
        constructor = _ffi_api.ConstantInfo  # type: ignore # pylint: disable=no-member
        self.__init_handle_by_constructor__(constructor, name_hint, byte_offset, data)


@register_object("ir.WorkspacePoolInfo")
class WorkspacePoolInfo(PoolInfo):
"""WorkspacePoolInfo object holds information related to RW memory pools
Expand Down Expand Up @@ -214,3 +243,26 @@ def __init__(
self.__init_handle_by_constructor__(
_ffi_api.ConstantMemoryPools, pools # type: ignore # pylint: disable=no-member
)


# The registration key must be the type key of the C++ node this class wraps.
# "ir.ConstantMemoryPools" is the key of the ConstantMemoryPools class; reusing it here
# would replace that class's Python wrapper mapping with AllocatedPoolInfo.
# NOTE(review): "tir.usmp.AllocatedPoolInfo" is expected to be AllocatedPoolInfoNode's
# _type_key -- confirm against the C++ declaration.
@register_object("tir.usmp.AllocatedPoolInfo")
class AllocatedPoolInfo(Object):
    """Allocate memory in a given pool.

    Parameters
    ----------
    pool : PoolInfo
        The pool in which to allocate memory.

    allocated_size : int
        The size of memory to allocate.

    pool_var_idx : int
        Passed through to the ``AllocatedPoolInfo`` FFI constructor; defaults to 0.
        NOTE(review): presumably the index of the pool variable associated with this
        allocation -- confirm against the C++ constructor.
    """

    def __init__(
        self,
        pool: PoolInfo,
        allocated_size: int,
        pool_var_idx: int = 0,
    ):
        self.__init_handle_by_constructor__(
            _ffi_api.AllocatedPoolInfo,  # type: ignore # pylint: disable=no-member
            pool,
            allocated_size,
            pool_var_idx,
        )
26 changes: 26 additions & 0 deletions python/tvm/relay/backend/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
# under the License.
# pylint: disable=invalid-name
"""AOT passes"""
from typing import Dict

from tvm import IRModule
from tvm.ir.transform import Pass
from .utils import CallType

Expand All @@ -41,3 +44,26 @@ def AOTLowerMain(mod_name: str, config: object, call_type: CallType) -> Pass:
"""
return _aot.AOTLowerMain(mod_name, config, call_type.value)


def CreateFunctionMetadata(
    mod: IRModule, workspace_byte_alignment: int, constant_byte_alignment: int
) -> Dict[str, object]:
    """Build the per-function metadata (FunctionInfos) for an AOT module.

    Parameters
    ----------
    mod : IRModule
        The IRModule.

    workspace_byte_alignment : int
        The alignment of the workspace buffer in bytes.

    constant_byte_alignment : int
        The alignment of the constant buffer in bytes.

    Returns
    -------
    Dict[str, FunctionInfo]
        A map between function names and FunctionInfos.

    """
    # The work is delegated entirely to the function of the same name in `_aot`.
    function_infos = _aot.CreateFunctionMetadata(
        mod, workspace_byte_alignment, constant_byte_alignment
    )
    return function_infos
125 changes: 125 additions & 0 deletions src/relay/backend/aot/create_function_metadata.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/backend/aot/create_function_metadata.cc
* \brief Create FunctionInfo metadata from a lowered TIR module.
*/
#include "./create_function_metadata.h"

#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target_kind.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/usmp/utils.h>

#include "../utils.h"

namespace tvm {
namespace relay {
namespace backend {
namespace aot {

/*!
 * \brief Calculate FunctionInfo for all the PrimFuncs in a module.
 *
 * Non-PrimFunc entries of the module are ignored. Every PrimFunc must carry a target
 * attribute; its FunctionInfo records, for that target, the workspace bytes, the summed
 * bytes of the buffers bound to handle-typed parameters, the constant bytes and the
 * PrimFunc itself.
 *
 * \param mod The module whose functions are inspected.
 * \param workspace_byte_alignment Alignment passed to CalculateWorkspaceBytes.
 * \param constant_byte_alignment Alignment passed to CalculateConstantBytes.
 * \return A map from each PrimFunc's global name hint to its FunctionInfo.
 */
Map<String, backend::FunctionInfo> CalculateFunctionInfos(const IRModule& mod,
                                                          Integer workspace_byte_alignment,
                                                          Integer constant_byte_alignment) {
  Map<String, backend::FunctionInfo> infos_by_name;
  for (const auto& entry : mod->functions) {
    GlobalVar gvar = entry.first;
    BaseFunc func = entry.second;
    // Only TIR PrimFuncs are measured; skip everything else.
    if (!func->IsInstance<tir::PrimFuncNode>()) {
      continue;
    }
    tir::PrimFunc prim_func = Downcast<tir::PrimFunc>(func);
    Optional<Target> maybe_target = prim_func->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(maybe_target) << "Target must be defined for all primfuncs.";
    Target target = maybe_target.value();

    // Sum the sizes of the input/output buffers.
    int64_t io_bytes = 0;
    for (const auto& param : prim_func->params) {
      // Inputs/outputs will be handles, workspaces are pointers
      if (!param->dtype.is_handle()) {
        continue;
      }
      auto io_buffer = prim_func->buffer_map[param];
      io_bytes += GetMemorySizeBytes(io_buffer->shape, io_buffer->dtype);
    }

    const auto& workspace_bytes = CalculateWorkspaceBytes(prim_func, workspace_byte_alignment);
    const auto& constant_bytes = CalculateConstantBytes(prim_func, constant_byte_alignment);
    backend::FunctionInfo info{{{target, workspace_bytes}},
                               {{target, io_bytes}},
                               {{target, constant_bytes}},
                               {{target, prim_func}},
                               {}};
    infos_by_name.Set(gvar->name_hint, info);
  }
  return infos_by_name;
}

/*!
 * \brief Create the FunctionInfo map for a module and add USMP pool sizes to the main entry.
 *
 * The per-function infos come from CalculateFunctionInfos. If the module has a kPoolArgs
 * attribute, each allocated pool's size is added, per pool target, to the main function's
 * constant or workspace size depending on the pool kind. Any other pool kind is fatal.
 *
 * \param mod The module.
 * \param workspace_byte_alignment The alignment of the workspace pool.
 * \param constant_byte_alignment The alignment of the constant pool.
 * \return A map between function names and FunctionInfos.
 */
Map<String, backend::FunctionInfo> CreateFunctionMetadata(const IRModule& mod,
                                                          Integer workspace_byte_alignment,
                                                          Integer constant_byte_alignment) {
  // Sizes derived from the buffers that are explicitly allocated in each PrimFunc.
  Map<String, backend::FunctionInfo> metadata =
      CalculateFunctionInfos(mod, workspace_byte_alignment, constant_byte_alignment);

  // PoolInfo allocations made by the USMP are accounted for on the main function.
  Optional<Array<tir::usmp::AllocatedPoolInfo>> pool_args =
      mod->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
  backend::FunctionInfo main_info = metadata.Get(runtime::symbol::tvm_module_main).value();

  if (pool_args) {
    for (const tir::usmp::AllocatedPoolInfo& allocated : pool_args.value()) {
      for (const auto& target : allocated->pool_info->targets) {
        VLOG(1) << "USMP requires target " << target->ToDebugString() << " to have pool size "
                << allocated->allocated_size->value;
        // Start from the pool's size, then add whatever main already records for this target.
        size_t total = allocated->allocated_size->value;
        if (allocated->pool_info->IsInstance<ConstantPoolInfoNode>()) {
          if (main_info->constant_sizes.count(target)) {
            total += main_info->constant_sizes[target]->value;
          }
          main_info->constant_sizes.Set(target, total);
        } else if (allocated->pool_info->IsInstance<WorkspacePoolInfoNode>()) {
          if (main_info->workspace_sizes.count(target)) {
            total += main_info->workspace_sizes[target]->value;
          }
          main_info->workspace_sizes.Set(target, total);
        } else {
          LOG(FATAL) << "Unknown pool type: " << allocated->pool_info->GetTypeKey();
        }
      }
    }
  }

  metadata.Set(runtime::symbol::tvm_module_main, main_info);
  return metadata;
}

// Register CreateFunctionMetadata as the global function
// "relay.backend.aot.CreateFunctionMetadata".
TVM_REGISTER_GLOBAL("relay.backend.aot.CreateFunctionMetadata")
.set_body_typed(CreateFunctionMetadata);

} // namespace aot
} // namespace backend
} // namespace relay
} // namespace tvm
49 changes: 49 additions & 0 deletions src/relay/backend/aot/create_function_metadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_RELAY_BACKEND_AOT_CREATE_FUNCTION_METADATA_H_
#define TVM_RELAY_BACKEND_AOT_CREATE_FUNCTION_METADATA_H_

#include <tvm/ir/module.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>

#include "../utils.h"

namespace tvm {
namespace relay {
namespace backend {
namespace aot {

/*!
 * \brief Create FunctionInfo metadata for all the PrimFuncs in a module lowered
 *        for AOT execution.
 *
 * The entry for the module's main function additionally includes the sizes of any
 * pools recorded in the module's kPoolArgs attribute.
 *
 * \param mod The module.
 * \param workspace_byte_alignment The alignment of the workspace pool.
 * \param constant_byte_alignment The alignment of the constant pool.
 * \return A map between function names and FunctionInfos.
 */
Map<String, FunctionInfo> CreateFunctionMetadata(const IRModule& mod,
Integer workspace_byte_alignment,
Integer constant_byte_alignment);

} // namespace aot
} // namespace backend
} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_BACKEND_AOT_CREATE_FUNCTION_METADATA_H_
6 changes: 3 additions & 3 deletions src/tir/usmp/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size,
}

TVM_REGISTER_NODE_TYPE(AllocatedPoolInfoNode);
TVM_REGISTER_GLOBAL("tir.usmp.AllocatedPoolInfo")
.set_body_typed([](PoolInfo pool_info, Integer allocated_size) {
return AllocatedPoolInfo(pool_info, allocated_size);
TVM_REGISTER_GLOBAL("ir.AllocatedPoolInfo")
.set_body_typed([](PoolInfo pool_info, Integer allocated_size, Integer pool_var_idx) {
return AllocatedPoolInfo(pool_info, allocated_size, pool_var_idx);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down
Loading

0 comments on commit 0b54cd2

Please sign in to comment.