Skip to content

Commit

Permalink
Add Relay option to link parameters into runtime Modules (apache#6917)
Browse files Browse the repository at this point in the history
* refactor RPCSessionContext utils

* Make TVMLogf platform-independent.

 * Some platforms need to use an alternate printf() to support basic
   things like %zu. Since %zu is platform-specific, we prefer to
   use a printf() that supports it or allow the platform to fix it up
   as needed.
  • Loading branch information
areusch authored and trevor-m committed Dec 4, 2020
1 parent 9aff41d commit 87223ba
Show file tree
Hide file tree
Showing 41 changed files with 1,927 additions and 116 deletions.
1 change: 1 addition & 0 deletions cmake/modules/StandaloneCrt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ if(USE_MICRO)
"src/runtime/crt/include *.h -> include"
"src/runtime/crt/common *.c -> src/runtime/crt/common"
"src/runtime/crt/graph_runtime *.c -> src/runtime/crt/graph_runtime"
"src/runtime/crt/graph_runtime_module *.c -> src/runtime/crt/graph_runtime_module"
"src/runtime/crt/host crt_config.h -> src/runtime/crt/host"
"src/runtime/crt/utvm_rpc_common *.cc -> src/runtime/crt/utvm_rpc_common"
"src/runtime/crt/utvm_rpc_server *.cc -> src/runtime/crt/utvm_rpc_server"
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/runtime/crt/error_codes.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ typedef enum {
kTvmErrorCategoryWriteStream = 3,
kTvmErrorCategorySession = 4,
kTvmErrorCategoryPlatform = 5,
kTvmErrorCategoryGenerated = 6,
kTvmErrorCategoryGraphRuntime = 7,
kTvmErrorCategoryFunctionCall = 8,
} tvm_crt_error_category_t;

typedef enum {
Expand Down Expand Up @@ -74,6 +77,19 @@ typedef enum {
kTvmErrorPlatformMemoryManagerInitialized = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 1),
kTvmErrorPlatformShutdown = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 2),

// Common error codes returned from generated functions.
kTvmErrorGeneratedInvalidStorageId = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryGenerated, 0),

// Graph runtime
kTvmErrorGraphModuleAlreadyCreated = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryGraphRuntime, 0),
kTvmErrorGraphModuleBadContext = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryGraphRuntime, 1),
kTvmErrorGraphModuleNoSuchInput = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryGraphRuntime, 2),

// Function Calls - common problems encountered calling functions.
kTvmErrorFunctionCallNumArguments = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 0),
kTvmErrorFunctionCallWrongArgType = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 1),
kTvmErrorFunctionCallNotImplemented = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 2),

// System errors are always negative integers; this mask indicates presence of a system error.
// Cast tvm_crt_error_t to a signed integer to interpret the negative error code.
kTvmErrorSystemErrorMask = (1 << (sizeof(int) * 4 - 1)),
Expand Down
16 changes: 14 additions & 2 deletions include/tvm/runtime/crt/graph_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,20 @@ typedef struct TVMGraphRuntime TVMGraphRuntime;
* \brief Allocate a new GraphRuntime with vmalloc and initialize it.
*
* \param sym_json JSON-encoded graph.
* \param m TVM Module that exposes the functions to call.
* \param module_handle TVM Module that exposes the functions to call.
* \param ctxs runtime execution context.
*/
TVMGraphRuntime* TVMGraphRuntime_Create(const char* sym_json, const struct TVMModule* m,
TVMGraphRuntime* TVMGraphRuntime_Create(const char* sym_json, TVMModuleHandle module_handle,
const TVMContext* ctxs);

int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime* runtime, const char* name);

/*!
* \brief get number of input tensors allocated.
* \return integer number of tensors available to use.
*/
int TVMGraphRuntime_GetNumInputs();

/*!
* \brief set input to the graph based on name.
* \param runtime The graph runtime.
Expand All @@ -77,6 +83,12 @@ int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime* runtime, const char* name);
*/
void TVMGraphRuntime_SetInput(TVMGraphRuntime* runtime, const char* name, DLTensor* data_in);

/*!
* \brief get number of output tensors allocated.
* \return integer number of output tensors allocated.
*/
int TVMGraphRuntime_GetNumOutputs();

/*!
* \brief Return NDArray for given output index.
* \param runtime The graph runtime.
Expand Down
42 changes: 42 additions & 0 deletions include/tvm/runtime/crt/graph_runtime_module.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file graph_runtime.h
* \brief Tiny graph runtime that can run graph containing only tvm PackedFunc.
*/
#ifndef TVM_RUNTIME_CRT_GRAPH_RUNTIME_MODULE_H_
#define TVM_RUNTIME_CRT_GRAPH_RUNTIME_MODULE_H_

#ifdef __cplusplus
extern "C" {
#endif

#include <tvm/runtime/crt/error_codes.h>

/*!
* \brief Register the "tvm.graph_runtime.create" constructor PackedFunc.
*/
tvm_crt_error_t TVMGraphRuntimeModule_Register();

#ifdef __cplusplus
} // extern "C"
#endif

#endif // TVM_RUNTIME_CRT_GRAPH_RUNTIME_MODULE_H_
8 changes: 8 additions & 0 deletions include/tvm/runtime/crt/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ typedef struct TVMModule {
const TVMFuncRegistry* registry;
} TVMModule;

/*!
* \brief Create a new module handle from the given TVMModule instance.
* \param mod The module instance to register.
* \param out_handle Pointer to recieve the newly-minted handle for this module.
* \return 0 on success, non-zero on error.
*/
int TVMModCreateFromCModule(const TVMModule* mod, TVMModuleHandle* out_handle);

/*! \brief Entry point for the system lib module. */
const TVMModule* TVMSystemLibEntryPoint(void);

Expand Down
4 changes: 4 additions & 0 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@ constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
/*! \brief Prefix for parameter symbols emitted into the main program. */
constexpr const char* tvm_param_prefix = "__tvm_param__";
/*! \brief A PackedFunc that looks up linked parameters by storage_id. */
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
} // namespace symbol

// implementations of inline functions.
Expand Down
48 changes: 48 additions & 0 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_TIR_FUNCTION_H_

#include <tvm/ir/function.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
Expand Down Expand Up @@ -150,6 +151,42 @@ class PrimFunc : public BaseFunc {
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
};

/*!
* \brief Describes one parameter that should be linked into the generated module.
*
* When parameters are to be linked in with generated code (i.e. on target_host-compatible
* backends), Relay attaches instances of this object to a global TIR function. Code-generators
* use the information contained in this node to include the parameter data in the generated
* module.
*/
class LinkedParamNode : public Object {
public:
/*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
int64_t id;

/*! \brief Parameter data which should get linked into the final module. */
::tvm::runtime::NDArray param;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("id", &id);
v->Visit("param", &param);
}

static constexpr const char* _type_key = "tir.LinkedParam";
TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
};

/*!
* \brief Managed reference to LinkedParamNode.
*/
class LinkedParam : public ObjectRef {
public:
TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);

TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

/*!
* \brief PrimFunc specific attribute names.
*
Expand Down Expand Up @@ -192,6 +229,17 @@ constexpr const char* kNoAlias = "tir.noalias";
* \note There can only be one entry function per module.
*/
constexpr const char* kIsEntryFunc = "tir.is_entry_func";

/*!
* \brief Parameters used in the module that should be linked by the codegen.
*
* Type: Map<String, LinkableParam>
*
* \note This should be present only on a function named
* tvm::target::packed_func::kLookupLinkedParam.
*/
constexpr const char* kLinkedParams = "tir.linked_params";

} // namespace attr
} // namespace tir
} // namespace tvm
Expand Down
23 changes: 21 additions & 2 deletions python/tvm/micro/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import re
from tvm.contrib import utils

from .micro_library import MicroLibrary


_LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -109,7 +111,13 @@ def default_options(target_include_dir):


def build_static_runtime(
workspace, compiler, module, lib_opts=None, bin_opts=None, generated_lib_opts=None
workspace,
compiler,
module,
lib_opts=None,
bin_opts=None,
generated_lib_opts=None,
extra_libs=None,
):
"""Build the on-device runtime, statically linking the given modules.
Expand All @@ -131,6 +139,12 @@ def build_static_runtime(
The `options` parameter passed to compiler.library() when compiling the generated TVM C
source module.
extra_libs : Optional[List[MicroLibrary|str]]
If specified, extra libraries to be compiled into the binary. If a MicroLibrary, it is
included into the binary directly. If a string, the path to a directory; all direct children
of this directory matching RUNTIME_SRC_REGEX are built into a library. These libraries are
placed before any common CRT libraries in the link order.
Returns
-------
MicroBinary :
Expand All @@ -150,7 +164,12 @@ def build_static_runtime(
module.save(mod_src_path, "cc")

libs = []
for lib_src_dir in RUNTIME_LIB_SRC_DIRS:
for mod_or_src_dir in (extra_libs or []) + RUNTIME_LIB_SRC_DIRS:
if isinstance(mod_or_src_dir, MicroLibrary):
libs.append(mod_or_src_dir)
continue

lib_src_dir = mod_or_src_dir
lib_name = os.path.basename(lib_src_dir)
lib_build_dir = workspace.relpath(f"build/{lib_name}")
os.makedirs(lib_build_dir)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/micro/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def read(self, n, timeout_sec):
raise base.IoTimeoutError()

def close(self):
pass # Pipes closed by parent class.
pass # Pipes closed by parent class (DebugWrapperTransport calls stop() next).

def transport(self):
return self._Transport(self)
Expand Down
73 changes: 72 additions & 1 deletion python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,43 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
self.transport.__exit__(exc_type, exc_value, exc_traceback)


def lookup_remote_linked_param(mod, storage_id, template_tensor, ctx):
"""Lookup a parameter that has been pre-linked into a remote (i.e. over RPC) Module.
This function signature matches the signature built by
Parameters
----------
mod : tvm.runtime.Module
The remote Module containing the pre-linked parameters.
storage_id : int
An integer identifying the pre-linked paramter to find
template_tensor : DLTensor
A DLTensor containing metadata that should be filled-in to the returned NDArray. This
function should mostly not inspect this, and just pass it along to
NDArrayFromRemoteOpaqueHandle.
ctx : TVMContext
The remote CPU context to be used with the returned NDArray.
Returns
-------
tvm.nd.NDArray :
NDArray containing the pre-linked parameter.
"""
try:
lookup_linked_param = mod.get_function("_lookup_linked_param")
except AttributeError:
return None

remote_data = lookup_linked_param(storage_id)
if remote_data is None:
return None

return get_global_func("tvm.rpc.NDArrayFromRemoteOpaqueHandle")(
mod, remote_data, template_tensor, ctx, lambda: None
)


def create_local_graph_runtime(graph_json_str, mod, ctx):
"""Create a local graph runtime driving execution on the remote CPU context given.
Expand All @@ -175,4 +212,38 @@ def create_local_graph_runtime(graph_json_str, mod, ctx):
"""
device_type_id = [ctx.device_type, ctx.device_id]
fcreate = get_global_func("tvm.graph_runtime.create")
return graph_runtime.GraphModule(fcreate(graph_json_str, mod, *device_type_id))
return graph_runtime.GraphModule(
fcreate(graph_json_str, mod, lookup_remote_linked_param, *device_type_id)
)


def create_local_debug_runtime(graph_json_str, mod, ctx, dump_root=None):
"""Create a local debug runtime driving execution on the remote CPU context given.
Parameters
----------
graph_json_str : str
A string containing the graph representation.
mod : tvm.runtime.Module
The remote module containing functions in graph_json_str.
ctx : tvm.Context
The remote CPU execution context.
dump_root : Optional[str]
If given, passed as dump_root= to GraphModuleDebug.
Returns
-------
tvm.contrib.GraphRuntime :
A local graph runtime instance that executes on the remote device.
"""
device_type_id = [ctx.device_type, ctx.device_id]
fcreate = get_global_func("tvm.graph_runtime_debug.create")
return debug_runtime.GraphModuleDebug(
fcreate(graph_json_str, mod, lookup_remote_linked_param, *device_type_id),
[ctx],
graph_json_str,
dump_root=dump_root,
)
2 changes: 1 addition & 1 deletion python/tvm/micro/transport/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class IoTimeoutError(Exception):
)


def debug_transport_timeouts(session_start_retry_timeout_sec=0.0):
def debug_transport_timeouts(session_start_retry_timeout_sec=0):
return TransportTimeouts(
session_start_retry_timeout_sec=session_start_retry_timeout_sec,
session_start_timeout_sec=0,
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ def micro(model="unknown", options=None):
"stm32f746xx": ["-mcpu=cortex-m7", "-march=armv7e-m"],
}
opts = _merge_opts(
trans_table[model] + ["-runtime=c", "--system-lib", f"-model={model}"], options
trans_table[model] + ["-runtime=c", "--system-lib", f"-model={model}"],
options,
)

# NOTE: in the future, the default micro target will be LLVM except when
Expand Down
Loading

0 comments on commit 87223ba

Please sign in to comment.