From 87223ba2bb5cda598cf80e8572ac287c2ba99e07 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 26 Nov 2020 06:09:22 -0800 Subject: [PATCH] Add Relay option to link parameters into runtime Modules (#6917) * refactor RPCSessionContext utils * Make TVMLogf platform-independent. * Some platforms need to use an alternate printf() to support basic things like %zu. Since %zu is platform-specific, we prefer to use a printf() that supports it or allow the platform to fix it up as needed. --- cmake/modules/StandaloneCrt.cmake | 1 + include/tvm/runtime/crt/error_codes.h | 16 + include/tvm/runtime/crt/graph_runtime.h | 16 +- .../tvm/runtime/crt/graph_runtime_module.h | 42 ++ include/tvm/runtime/crt/module.h | 8 + include/tvm/runtime/module.h | 4 + include/tvm/tir/function.h | 48 +++ python/tvm/micro/build.py | 23 +- python/tvm/micro/debugger.py | 2 +- python/tvm/micro/session.py | 73 +++- python/tvm/micro/transport/base.py | 2 +- python/tvm/target/target.py | 3 +- src/relay/backend/build_module.cc | 49 ++- src/relay/backend/graph_runtime_codegen.cc | 36 +- src/runtime/crt/Makefile | 2 +- src/runtime/crt/common/crt_runtime_api.c | 24 +- src/runtime/crt/common/memory.c | 13 +- src/runtime/crt/graph_runtime/graph_runtime.c | 89 +++- .../graph_runtime_module.c | 221 ++++++++++ src/runtime/crt/host/main.cc | 9 + .../internal/graph_runtime/graph_runtime.h | 10 +- .../graph/debug/graph_runtime_debug.cc | 15 +- src/runtime/graph/graph_runtime.cc | 87 +++- src/runtime/graph/graph_runtime.h | 25 +- src/runtime/graph/graph_runtime_factory.cc | 2 +- src/runtime/rpc/rpc_module.cc | 85 ++-- src/target/llvm/codegen_llvm.cc | 88 ++++ src/target/llvm/codegen_llvm.h | 12 + src/target/llvm/codegen_params.cc | 176 ++++++++ src/target/llvm/codegen_params.h | 49 +++ src/target/llvm/llvm_module.cc | 20 +- src/target/source/codegen_c_host.cc | 64 +++ src/target/source/codegen_c_host.h | 3 + src/target/source/codegen_params.cc | 248 +++++++++++ src/target/source/codegen_params.h | 52 +++ src/target/target_kind.cc | 2 + src/tir/ir/function.cc | 7 + tests/cpp/target_test.cc | 3 +- tests/python/unittest/test_crt.py | 1 + tests/python/unittest/test_link_params.py | 408 ++++++++++++++++++ .../unittest/test_target_codegen_llvm.py | 5 +- 41 files changed, 1927 insertions(+), 116 deletions(-) create mode 100644 include/tvm/runtime/crt/graph_runtime_module.h create mode 100644 src/runtime/crt/graph_runtime_module/graph_runtime_module.c create mode 100644 src/target/llvm/codegen_params.cc create mode 100644 src/target/llvm/codegen_params.h create mode 100644 src/target/source/codegen_params.cc create mode 100644 src/target/source/codegen_params.h create mode 100644 tests/python/unittest/test_link_params.py diff --git a/cmake/modules/StandaloneCrt.cmake b/cmake/modules/StandaloneCrt.cmake index 73c85d13e2ef..256ce2a48a6c 100644 --- a/cmake/modules/StandaloneCrt.cmake +++ b/cmake/modules/StandaloneCrt.cmake @@ -44,6 +44,7 @@ if(USE_MICRO) "src/runtime/crt/include *.h -> include" "src/runtime/crt/common *.c -> src/runtime/crt/common" "src/runtime/crt/graph_runtime *.c -> src/runtime/crt/graph_runtime" + "src/runtime/crt/graph_runtime_module *.c -> src/runtime/crt/graph_runtime_module" "src/runtime/crt/host crt_config.h -> src/runtime/crt/host" "src/runtime/crt/utvm_rpc_common *.cc -> src/runtime/crt/utvm_rpc_common" "src/runtime/crt/utvm_rpc_server *.cc -> src/runtime/crt/utvm_rpc_server" diff --git a/include/tvm/runtime/crt/error_codes.h b/include/tvm/runtime/crt/error_codes.h index 16d0e793848b..93a332a5924f 100644 --- 
a/include/tvm/runtime/crt/error_codes.h
+++ b/include/tvm/runtime/crt/error_codes.h
@@ -41,6 +41,9 @@ typedef enum {
   kTvmErrorCategoryWriteStream = 3,
   kTvmErrorCategorySession = 4,
   kTvmErrorCategoryPlatform = 5,
+  kTvmErrorCategoryGenerated = 6,
+  kTvmErrorCategoryGraphRuntime = 7,
+  kTvmErrorCategoryFunctionCall = 8,
 } tvm_crt_error_category_t;
 
 typedef enum {
@@ -74,6 +77,19 @@ typedef enum {
   kTvmErrorPlatformMemoryManagerInitialized = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 1),
   kTvmErrorPlatformShutdown = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 2),
 
+  // Common error codes returned from generated functions.
+  kTvmErrorGeneratedInvalidStorageId = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryGenerated, 0),
+
+  // Graph runtime
+  kTvmErrorGraphModuleAlreadyCreated = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryGraphRuntime, 0),
+  kTvmErrorGraphModuleBadContext = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryGraphRuntime, 1),
+  kTvmErrorGraphModuleNoSuchInput = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryGraphRuntime, 2),
+
+  // Function Calls - common problems encountered calling functions.
+  kTvmErrorFunctionCallNumArguments = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 0),
+  kTvmErrorFunctionCallWrongArgType = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 1),
+  kTvmErrorFunctionCallNotImplemented = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 2),
+
   // System errors are always negative integers; this mask indicates presence of a system error.
   // Cast tvm_crt_error_t to a signed integer to interpret the negative error code.
   kTvmErrorSystemErrorMask = (1 << (sizeof(int) * 4 - 1)),
diff --git a/include/tvm/runtime/crt/graph_runtime.h b/include/tvm/runtime/crt/graph_runtime.h
index d2eb3b7785e9..e8413aa1723d 100644
--- a/include/tvm/runtime/crt/graph_runtime.h
+++ b/include/tvm/runtime/crt/graph_runtime.h
@@ -61,14 +61,20 @@ typedef struct TVMGraphRuntime TVMGraphRuntime;
 * \brief Allocate a new GraphRuntime with vmalloc and initialize it.
 *
 * \param sym_json JSON-encoded graph.
- * \param m TVM Module that exposes the functions to call.
+ * \param module_handle TVM Module that exposes the functions to call.
 * \param ctxs runtime execution context.
 */
-TVMGraphRuntime* TVMGraphRuntime_Create(const char* sym_json, const struct TVMModule* m,
+TVMGraphRuntime* TVMGraphRuntime_Create(const char* sym_json, TVMModuleHandle module_handle,
                                        const TVMContext* ctxs);
 
 int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime* runtime, const char* name);
 
+/*!
+ * \brief get number of input tensors allocated.
+ * \param runtime The graph runtime.
+ * \return integer number of tensors available to use.
+ */
+int TVMGraphRuntime_GetNumInputs(TVMGraphRuntime* runtime);
+
 /*!
 * \brief set input to the graph based on name.
 * \param runtime The graph runtime.
@@ -77,6 +83,12 @@ int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime* runtime, const char* name);
 */
 void TVMGraphRuntime_SetInput(TVMGraphRuntime* runtime, const char* name, DLTensor* data_in);
 
+/*!
+ * \brief get number of output tensors allocated.
+ * \param runtime The graph runtime.
+ * \return integer number of output tensors allocated.
+ */
+int TVMGraphRuntime_GetNumOutputs(TVMGraphRuntime* runtime);
+
 /*!
 * \brief Return NDArray for given output index.
 * \param runtime The graph runtime.
diff --git a/include/tvm/runtime/crt/graph_runtime_module.h b/include/tvm/runtime/crt/graph_runtime_module.h
new file mode 100644
index 000000000000..04e9184c8b8d
--- /dev/null
+++ b/include/tvm/runtime/crt/graph_runtime_module.h
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_runtime_module.h
+ * \brief Tiny graph runtime that can run graph containing only tvm PackedFunc.
+ */
+#ifndef TVM_RUNTIME_CRT_GRAPH_RUNTIME_MODULE_H_
+#define TVM_RUNTIME_CRT_GRAPH_RUNTIME_MODULE_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include <tvm/runtime/crt/error_codes.h>
+
+/*!
+ * \brief Register the "tvm.graph_runtime.create" constructor PackedFunc.
+ */
+tvm_crt_error_t TVMGraphRuntimeModule_Register();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // TVM_RUNTIME_CRT_GRAPH_RUNTIME_MODULE_H_
diff --git a/include/tvm/runtime/crt/module.h b/include/tvm/runtime/crt/module.h
index 2359025f6fe1..7b124c4faa3a 100644
--- a/include/tvm/runtime/crt/module.h
+++ b/include/tvm/runtime/crt/module.h
@@ -39,6 +39,14 @@ typedef struct TVMModule {
   const TVMFuncRegistry* registry;
 } TVMModule;
 
+/*!
+ * \brief Create a new module handle from the given TVMModule instance.
+ * \param mod The module instance to register.
+ * \param out_handle Pointer to receive the newly-minted handle for this module.
+ * \return 0 on success, non-zero on error.
+ */
+int TVMModCreateFromCModule(const TVMModule* mod, TVMModuleHandle* out_handle);
+
 /*! \brief Entry point for the system lib module. */
 const TVMModule* TVMSystemLibEntryPoint(void);
diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index 0e7cd2b08784..04a5cf8bf25d 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -226,6 +226,10 @@ constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
 constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
 /*! \brief Placeholder for the module's entry function. */
 constexpr const char* tvm_module_main = "__tvm_main__";
+/*! \brief Prefix for parameter symbols emitted into the main program. */
+constexpr const char* tvm_param_prefix = "__tvm_param__";
+/*! \brief A PackedFunc that looks up linked parameters by storage_id. */
+constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
 } // namespace symbol
 
 // implementations of inline functions.
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 64dbb5cf8ec3..97ee7f7211d4 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -25,6 +25,7 @@
 #define TVM_TIR_FUNCTION_H_
 
 #include
+#include <tvm/runtime/ndarray.h>
 #include
 #include
 #include
@@ -150,6 +151,42 @@ class PrimFunc : public BaseFunc {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
 };
 
+/*!
+ * \brief Describes one parameter that should be linked into the generated module.
+ *
+ * When parameters are to be linked in with generated code (i.e. on target_host-compatible
+ * backends), Relay attaches instances of this object to a global TIR function. Code-generators
+ * use the information contained in this node to include the parameter data in the generated
+ * module.
+ */
+class LinkedParamNode : public Object {
+ public:
+  /*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
+  int64_t id;
+
+  /*! \brief Parameter data which should get linked into the final module. */
+  ::tvm::runtime::NDArray param;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("id", &id);
+    v->Visit("param", &param);
+  }
+
+  static constexpr const char* _type_key = "tir.LinkedParam";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
+};
+
+/*!
+ * \brief Managed reference to LinkedParamNode.
+ */
+class LinkedParam : public ObjectRef {
+ public:
+  TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
+};
+
 /*!
 * \brief PrimFunc specific attribute names.
 *
@@ -192,6 +229,17 @@ constexpr const char* kNoAlias = "tir.noalias";
 * \note There can only be one entry function per module.
 */
 constexpr const char* kIsEntryFunc = "tir.is_entry_func";
+
+/*!
+ * \brief Parameters used in the module that should be linked by the codegen.
+ *
+ * Type: Map<String, LinkedParam>
+ *
+ * \note This should be present only on a function named
+ * ::tvm::runtime::symbol::tvm_lookup_linked_param.
+ */
+constexpr const char* kLinkedParams = "tir.linked_params";
+
 } // namespace attr
 } // namespace tir
 } // namespace tvm
diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py
index d1a3c4163755..4aec9ea5ecbb 100644
--- a/python/tvm/micro/build.py
+++ b/python/tvm/micro/build.py
@@ -23,6 +23,8 @@
 import re
 
 from tvm.contrib import utils
 
+from .micro_library import MicroLibrary
+
 _LOG = logging.getLogger(__name__)
 
@@ -109,7 +111,13 @@ def default_options(target_include_dir):
 
 
 def build_static_runtime(
-    workspace, compiler, module, lib_opts=None, bin_opts=None, generated_lib_opts=None
+    workspace,
+    compiler,
+    module,
+    lib_opts=None,
+    bin_opts=None,
+    generated_lib_opts=None,
+    extra_libs=None,
 ):
     """Build the on-device runtime, statically linking the given modules.
 
@@ -131,6 +139,12 @@ def build_static_runtime(
         The `options` parameter passed to compiler.library() when compiling the generated TVM C
         source module.
 
+    extra_libs : Optional[List[MicroLibrary|str]]
+        If specified, extra libraries to be compiled into the binary. If a MicroLibrary, it is
+        included into the binary directly. If a string, the path to a directory; all direct children
+        of this directory matching RUNTIME_SRC_REGEX are built into a library. These libraries are
+        placed before any common CRT libraries in the link order.
+
     Returns
     -------
     MicroBinary :
@@ -150,7 +164,12 @@ def build_static_runtime(
     module.save(mod_src_path, "cc")
 
     libs = []
-    for lib_src_dir in RUNTIME_LIB_SRC_DIRS:
+    for mod_or_src_dir in (extra_libs or []) + RUNTIME_LIB_SRC_DIRS:
+        if isinstance(mod_or_src_dir, MicroLibrary):
+            libs.append(mod_or_src_dir)
+            continue
+
+        lib_src_dir = mod_or_src_dir
         lib_name = os.path.basename(lib_src_dir)
         lib_build_dir = workspace.relpath(f"build/{lib_name}")
         os.makedirs(lib_build_dir)
diff --git a/python/tvm/micro/debugger.py b/python/tvm/micro/debugger.py
index 8119940a018c..65cafe7e9c8a 100644
--- a/python/tvm/micro/debugger.py
+++ b/python/tvm/micro/debugger.py
@@ -272,7 +272,7 @@ def read(self, n, timeout_sec):
             raise base.IoTimeoutError()
 
     def close(self):
-        pass  # Pipes closed by parent class.
+        pass  # Pipes closed by parent class (DebugWrapperTransport calls stop() next).
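An aside on the extra_libs hook added above: build_static_runtime() now accepts either prebuilt MicroLibrary objects or plain source directories, which are compiled with the same rules as the stock CRT libraries and linked ahead of them. A minimal caller sketch, not part of this patch: `compiler`, `module`, and `crt_include_dir` are assumed to come from the usual micro TVM flow, the opts dict keys follow default_options() above, and the extra_libs path is purely illustrative.

import tvm.micro

def build_runtime_with_extra_lib(compiler, module, crt_include_dir):
    """Statically link one extra CRT source directory into the device binary."""
    workspace = tvm.micro.Workspace()
    opts = tvm.micro.default_options(crt_include_dir)
    return tvm.micro.build_static_runtime(
        workspace,
        compiler,
        module,
        lib_opts=opts["lib_opts"],
        bin_opts=opts["bin_opts"],
        # Entries may be MicroLibrary instances or directory paths; directories are
        # scanned for sources matching RUNTIME_SRC_REGEX and built into a library.
        extra_libs=["src/runtime/crt/graph_runtime_module"],
    )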
def transport(self):
        return self._Transport(self)
diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py
index 88bdf6cd8b5a..fba612b84d1f 100644
--- a/python/tvm/micro/session.py
+++ b/python/tvm/micro/session.py
@@ -154,6 +154,43 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
         self.transport.__exit__(exc_type, exc_value, exc_traceback)
 
 
+def lookup_remote_linked_param(mod, storage_id, template_tensor, ctx):
+    """Lookup a parameter that has been pre-linked into a remote (i.e. over RPC) Module.
+
+    This function signature matches the lookup_linked_param_func signature expected by
+    tvm.graph_runtime.create, and it is passed there by create_local_graph_runtime below.
+
+    Parameters
+    ----------
+    mod : tvm.runtime.Module
+        The remote Module containing the pre-linked parameters.
+    storage_id : int
+        An integer identifying the pre-linked parameter to find.
+    template_tensor : DLTensor
+        A DLTensor containing metadata that should be filled in on the returned NDArray. This
+        function should generally not inspect it, and should just pass it along to
+        NDArrayFromRemoteOpaqueHandle.
+    ctx : TVMContext
+        The remote CPU context to be used with the returned NDArray.
+
+    Returns
+    -------
+    tvm.nd.NDArray :
+        NDArray containing the pre-linked parameter.
+    """
+    try:
+        lookup_linked_param = mod.get_function("_lookup_linked_param")
+    except AttributeError:
+        return None
+
+    remote_data = lookup_linked_param(storage_id)
+    if remote_data is None:
+        return None
+
+    return get_global_func("tvm.rpc.NDArrayFromRemoteOpaqueHandle")(
+        mod, remote_data, template_tensor, ctx, lambda: None
+    )
+
+
 def create_local_graph_runtime(graph_json_str, mod, ctx):
     """Create a local graph runtime driving execution on the remote CPU context given.
 
@@ -175,4 +212,38 @@ def create_local_graph_runtime(graph_json_str, mod, ctx):
     """
     device_type_id = [ctx.device_type, ctx.device_id]
     fcreate = get_global_func("tvm.graph_runtime.create")
-    return graph_runtime.GraphModule(fcreate(graph_json_str, mod, *device_type_id))
+    return graph_runtime.GraphModule(
+        fcreate(graph_json_str, mod, lookup_remote_linked_param, *device_type_id)
+    )
+
+
+def create_local_debug_runtime(graph_json_str, mod, ctx, dump_root=None):
+    """Create a local debug runtime driving execution on the remote CPU context given.
+
+    Parameters
+    ----------
+    graph_json_str : str
+        A string containing the graph representation.
+
+    mod : tvm.runtime.Module
+        The remote module containing functions in graph_json_str.
+
+    ctx : tvm.Context
+        The remote CPU execution context.
+
+    dump_root : Optional[str]
+        If given, passed as dump_root= to GraphModuleDebug.
+
+    Returns
+    -------
+    tvm.contrib.GraphRuntime :
+        A local graph runtime instance that executes on the remote device.
+ """ + device_type_id = [ctx.device_type, ctx.device_id] + fcreate = get_global_func("tvm.graph_runtime_debug.create") + return debug_runtime.GraphModuleDebug( + fcreate(graph_json_str, mod, lookup_remote_linked_param, *device_type_id), + [ctx], + graph_json_str, + dump_root=dump_root, + ) diff --git a/python/tvm/micro/transport/base.py b/python/tvm/micro/transport/base.py index 07a6a6ac7fdc..fdc7e9b2afce 100644 --- a/python/tvm/micro/transport/base.py +++ b/python/tvm/micro/transport/base.py @@ -64,7 +64,7 @@ class IoTimeoutError(Exception): ) -def debug_transport_timeouts(session_start_retry_timeout_sec=0.0): +def debug_transport_timeouts(session_start_retry_timeout_sec=0): return TransportTimeouts( session_start_retry_timeout_sec=session_start_retry_timeout_sec, session_start_timeout_sec=0, diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index c919fc31e9aa..edbb0fa3792a 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -236,7 +236,8 @@ def micro(model="unknown", options=None): "stm32f746xx": ["-mcpu=cortex-m7", "-march=armv7e-m"], } opts = _merge_opts( - trans_table[model] + ["-runtime=c", "--system-lib", f"-model={model}"], options + trans_table[model] + ["-runtime=c", "--system-lib", f"-model={model}"], + options, ) # NOTE: in the future, the default micro target will be LLVM except when diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ddea5456585b..82ac1c57018e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -22,6 +22,7 @@ * \brief Code generation for TVM's graph runtime. */ #include +#include #include #include #include @@ -30,6 +31,7 @@ #include +#include "../../target/func_registry_generator.h" #include "../../target/source/codegen_source_base.h" #include "compile_engine.h" #include "utils.h" @@ -88,6 +90,17 @@ struct GraphCodegen { return ret; } + std::unordered_map GetParamIds() { + std::unordered_map ret; + auto names = CallFunc>("list_params_name", nullptr); + for (const auto& expr : names) { + // Implicit cast from runtime::String to std::string + std::string key = expr; + ret[key] = CallFunc("get_param_id", key); + } + return ret; + } + protected: tvm::runtime::Module mod; template @@ -443,16 +456,36 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = graph_codegen_->GetIRModule(); - // When there is no lowered_funcs due to reasons such as optimization. - if (lowered_funcs.size() == 0) { - Target target_host = GetTargetHost(); + Target target_host = GetTargetHost(); + // If no target_host has been set, we choose a default one, which is + // llvm if "codegen.LLVMModuleCreate" is accessible. + const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); + if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); + + // Generate a placeholder function that attaches linked params as its arguments. + if (target_host->GetAttr("link-params").value_or(Bool(false))) { + CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen."; + auto param_ids = graph_codegen_->GetParamIds(); + auto link_params = Map(); + for (auto param : ret_.params) { + link_params.Set(param.first, tir::LinkedParam(param_ids[param.first], param.second)); + } - // If no target_host has been set, we choose a default one, which is - // llvm if "codegen.LLVMModuleCreate" is accessible. 
-      const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
-      if (!target_host.defined())
-        target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm");
+      Map<String, ObjectRef> dict;
+      dict.Set(tvm::tir::attr::kLinkedParams, link_params);
+      dict.Set(tvm::attr::kGlobalSymbol, String(::tvm::runtime::symbol::tvm_lookup_linked_param));
+      DictAttrs attrs{dict};
+      auto prim = tir::PrimFunc(Array<tir::Var>(), tir::SeqStmt(Array<tir::Stmt>()), VoidType(),
+                                Map<tir::Var, tir::Buffer>(), attrs);
+      if (lowered_funcs.find(target_host->str()) == lowered_funcs.end()) {
+        lowered_funcs.Set(target_host->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+      lowered_funcs[target_host->str()]->Add(
+          GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim);
+    }
+
+    // When there is no lowered_funcs due to reasons such as optimization.
+    if (lowered_funcs.size() == 0) {
       if (target_host.defined() && target_host->kind->name == "llvm") {
         // If we can decide the target is LLVM, we then create an empty LLVM module.
         ret_.mod = (*pf)(target_host->str(), "empty_module");
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index e24d18de931c..7ed150495104 100644
--- a/src/relay/backend/graph_runtime_codegen.cc
+++ b/src/relay/backend/graph_runtime_codegen.cc
@@ -56,7 +56,7 @@ struct LoweredOutput {
   std::string graph_json;
   Map<String, IRModule> lowered_funcs;
   Array<tvm::runtime::Module> external_mods;
-  std::unordered_map<std::string, runtime::NDArray> params;
+  std::unordered_map<std::string, std::pair<int64_t, tvm::runtime::NDArray>> params;
 };
 
 /*! \brief Node types */
@@ -203,7 +203,12 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>>
-    ret.params = params_;
+    ret.params = std::unordered_map<std::string, std::pair<int64_t, tvm::runtime::NDArray>>();
+    for (auto param : params_) {
+      ret.params.emplace(std::make_pair(
+          param.first,
+          std::make_pair(static_cast<int64_t>(param_storage_ids_[param.first]), param.second)));
+    }
 
     for (auto& kv : lowered_funcs_) {
       if (ret.lowered_funcs.count(kv.first) == 0) {
@@ -312,9 +317,12 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>>
     Expr expr = GetRef<Expr>(op);
     size_t index = params_.size();
     std::string name = "p" + std::to_string(index);
-    params_[name] = op->data;
     auto node = GraphInputNode::make_node_ptr(name, GraphAttrs());
-    return AddNode(node, expr);
+    auto to_return = AddNode(node, expr);
+    CHECK_EQ(to_return.size(), 1) << "Expected exactly 1 parameter node created";
+    param_storage_ids_[name] = storage_device_map_[expr][0][0]->value;
+    params_[name] = op->data;
+    return to_return;
   }
 
   std::vector<GraphNodeRef> VisitExpr_(const TupleNode* op) override {
@@ -531,8 +539,14 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>>
   std::unordered_map<Expr, std::vector<GraphNodeRef>> var_map_;
   /*! \brief target device */
   TargetsMap targets_;
-  /*! \brief params */
+  /*!
+   * \brief parameters (i.e. ConstantNodes found in the graph).
+   * These are taken as inputs to the GraphRuntime.
+   * Maps param name to a pair of storage_id and NDArray. At runtime, the storage_id can be
+   * used to lookup the parameter.
+   */
   std::unordered_map<std::string, runtime::NDArray> params_;
+  std::unordered_map<std::string, int64_t> param_storage_ids_;
   /*! \brief plan memory of device result */
   Map<Expr, Array<IntegerArray>> storage_device_map_;
   /*!
\brief lowered funcs */ @@ -581,8 +595,16 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { } else if (name == "get_param_by_name") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { String key = args[0]; - ICHECK_GT(this->output_.params.count(key), 0); - *rv = this->output_.params[key]; + auto it = this->output_.params.find(key); + CHECK(it != this->output_.params.end()) << "no such parameter " << key; + *rv = (*it).second.second; + }); + } else if (name == "get_param_id") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + String key = args[0]; + auto it = this->output_.params.find(key); + CHECK(it != this->output_.params.end()) << "no such parameter " << key; + *rv = (*it).second.first; }); } else if (name == "get_irmodule") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { diff --git a/src/runtime/crt/Makefile b/src/runtime/crt/Makefile index 8a24db4e8b2b..6e462431173f 100644 --- a/src/runtime/crt/Makefile +++ b/src/runtime/crt/Makefile @@ -65,7 +65,7 @@ $(notdir $(1)): $${BUILD_DIR}/lib$(notdir $(1)).a endef -LIBS = src/runtime/crt/common src/runtime/crt/graph_runtime src/runtime/crt/utvm_rpc_common src/runtime/crt/utvm_rpc_server +LIBS = src/runtime/crt/common src/runtime/crt/graph_runtime src/runtime/crt/graph_runtime_module src/runtime/crt/utvm_rpc_common src/runtime/crt/utvm_rpc_server $(foreach lib,$(LIBS),$(eval $(call LIB_template,$(lib)))) diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index d6f78d9e3a03..f2d67ccfbeab 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -127,7 +127,7 @@ static TVMModuleHandle EncodeModuleHandle(tvm_module_index_t module_index) { return (TVMModuleHandle)((uintptr_t)(module_index | 0x8000)); } -static int TVMModCreateFromCModule(const TVMModule* mod, TVMModuleHandle* out_handle) { +int TVMModCreateFromCModule(const TVMModule* mod, TVMModuleHandle* out_handle) { tvm_module_index_t idx; for (idx = 0; idx < TVM_CRT_MAX_REGISTERED_MODULES; idx++) { @@ -229,17 +229,17 @@ int TVMFuncCall(TVMFunctionHandle func_handle, TVMValue* arg_values, int* type_c return func(arg_values, type_codes, num_args, ret_val, ret_type_code, resource_handle); } -static int FindFunctionOrSetAPIError(tvm_module_index_t module_index, - const TVMFuncRegistry* registry, const char* name, - TVMFunctionHandle* out) { +static tvm_crt_error_t FindFunctionOrSetAPIError(tvm_module_index_t module_index, + const TVMFuncRegistry* registry, const char* name, + TVMFunctionHandle* out) { tvm_function_index_t function_index; - if (TVMFuncRegistry_Lookup(registry, name, &function_index) != 0) { - TVMAPIErrorf("failed to get function: mod_index=%04" PRIx16 ", name=%s", module_index, name); - return -1; + tvm_crt_error_t err = TVMFuncRegistry_Lookup(registry, name, &function_index); + if (err != kTvmErrorNoError) { + return err; } *out = EncodeFunctionHandle(module_index, function_index); - return 0; + return kTvmErrorNoError; } int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { @@ -279,6 +279,14 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r if (to_return == 0) { ret_type_codes[0] = kTVMPackedFuncHandle; + } else { + ret_value->v_handle = NULL; + } + + // NOTE: For compatibility with C++ runtime API, return no error (but NULL function) when the + // function lookup failed. 
+ if (to_return == kTvmErrorFunctionNameNotFound) { + to_return = kTvmErrorNoError; } return to_return; diff --git a/src/runtime/crt/common/memory.c b/src/runtime/crt/common/memory.c index 68cad3645146..876c10efe3ea 100644 --- a/src/runtime/crt/common/memory.c +++ b/src/runtime/crt/common/memory.c @@ -151,8 +151,8 @@ void* MemoryManager_Alloc(MemoryManager* mgr, tvm_index_t size) { } vleak_size++; #if TVM_CRT_DEBUG > 1 - printf("allocate: addr=%p, start=%" PRId64 "/%zu, npage=%" PRId64 ", vleak=%d\n", data, start, - ptable->max_pages, npage, vleak_size); + TVMLogf("allocate: addr=%p, start=%" PRId64 "/%zu, npage=%" PRId64 ", vleak=%d\n", data, start, + ptable->max_pages, npage, vleak_size); #endif // TVM_CRT_DEBUG return data; } @@ -229,9 +229,8 @@ void* MemoryManager_Realloc(MemoryManager* mgr, void* ptr, tvm_index_t size) { vleak_size++; } #if TVM_CRT_DEBUG > 1 - printf("reallocate: addr=%p, start=%" PRId64 "/%zu, npage=%" PRId64 ", vleak=%d, size=%" PRId64 - "\n", - data, start, mgr->ptable.max_pages, npage, vleak_size, size); + TVMLogf("reallocate: addr=%p, start=%" PRId64 "/%zu, npage=%" PRId64 ", vleak=%d, size=%zu", data, + start, mgr->ptable.max_pages, npage, vleak_size, size); #endif // TVM_CRT_DEBUG return data; } @@ -251,8 +250,8 @@ void MemoryManager_Free(MemoryManager* mgr, void* ptr) { free_map->insert(free_map, p->num_pages, p); vleak_size--; #if TVM_CRT_DEBUG > 1 - printf("release: addr=%p, start=%" PRId64 "/%zu, npage=%" PRId64 ", vleak=%d\n", ptr, - entry->page.ptable_begin, mgr->ptable.max_pages, entry->page.num_pages, vleak_size); + TVMLogf("release: addr=%p, start=%" PRId64 "/%zu, npage=%zu, vleak=%d", ptr, + entry->page.ptable_begin, mgr->ptable.max_pages, entry->page.num_pages, vleak_size); #endif // TVM_CRT_DEBUG } diff --git a/src/runtime/crt/graph_runtime/graph_runtime.c b/src/runtime/crt/graph_runtime/graph_runtime.c index a6cd77ad6a22..450272d8722b 100644 --- a/src/runtime/crt/graph_runtime/graph_runtime.c +++ b/src/runtime/crt/graph_runtime/graph_runtime.c @@ -539,6 +539,13 @@ uint32_t TVMGraphRuntime_GetEntryId(TVMGraphRuntime* runtime, uint32_t nid, uint return runtime->node_row_ptr[nid] + index; } +/*! + * \brief Get the number of input tensors allocated. + * \param runtime The graph runtime. + * \return the number of input tensors allocated. + */ +int TVMGraphRuntime_GetNumInputs(TVMGraphRuntime* runtime) { return runtime->input_nodes_count; } + /*! * \brief Get the input index given the name of input. * \param runtime The graph runtime. @@ -675,6 +682,13 @@ void TVMGraphRuntime_Run(TVMGraphRuntime* runtime) { } } +/*! + * \brief Get the number of output tensors allocated. + * \param runtime The graph runtime. + * \return the number of output tensors allocated. 
+ */ +int TVMGraphRuntime_GetNumOutputs(TVMGraphRuntime* runtime) { return runtime->outputs_count; } + int TVMGraphRuntime_GetOutput(TVMGraphRuntime* runtime, const int32_t idx, DLTensor* out) { int status = 0; uint32_t nid = runtime->outputs[idx].node_id; @@ -693,8 +707,20 @@ int TVMGraphRuntime_GetOutput(TVMGraphRuntime* runtime, const int32_t idx, DLTen } void TVMGraphRuntime_SetupStorage(TVMGraphRuntime* runtime) { + TVMPackedFunc lookup_linked_param; + int lookup_linked_param_valid; uint32_t idx; + { + TVMArgs temp_args; + temp_args.values[0].v_int64 = 0; + temp_args.tcodes[0] = kTVMArgInt; + temp_args.values_count = 1; + lookup_linked_param_valid = + (TVMPackedFunc_InitModuleFunc(&lookup_linked_param, runtime->module_handle, + "_lookup_linked_param", &temp_args) == 0); + } + // Grab saved optimization plan from graph. TVMGraphRuntimeGraphAttr* attrs = &(runtime->attrs); DLDataType* vtype = vmalloc(sizeof(DLDataType) * attrs->dltype_count); @@ -721,24 +747,47 @@ void TVMGraphRuntime_SetupStorage(TVMGraphRuntime* runtime) { if (sid >= pool_entry_count) { pool_entry_count = sid + 1; } + pool_entry[sid].entry_id = idx; pool_entry[sid].size = MAX(pool_entry[sid].size, bytes); pool_entry[sid].device_type = device_type; } // Allocate the space. for (idx = 0; idx < pool_entry_count; idx++) { - runtime->storage_pool = - vrealloc(runtime->storage_pool, sizeof(TVMNDArray) * (runtime->storage_pool_count + 1)); + runtime->storage_pool = vrealloc(runtime->storage_pool, sizeof(TVMGraphRuntimeStorageEntry) * + (runtime->storage_pool_count + 1)); TVMGraphRuntimePoolEntry pit = pool_entry[idx]; - int64_t shape[TVM_CRT_MAX_NDIM] = { - 0, - }; TVMContext ctx = runtime->ctxs[0]; - DLDataType dtype = {kDLFloat, 32, 1}; - shape[0] = (pit.size + 3) / 4; - runtime->storage_pool[runtime->storage_pool_count] = TVMNDArray_Empty(1, shape, dtype, ctx); - CHECK_NE(runtime->storage_pool[runtime->storage_pool_count].dl_tensor.data, 0, - "fail to create storage_pool with idx=%d\n", idx); + uint8_t did_find_linked_param = 0; + if (lookup_linked_param_valid) { + lookup_linked_param.args.values[0].v_int64 = idx; + CHECK_EQ(lookup_linked_param.Call(&lookup_linked_param), 0, "lookup_linked_param"); + + void* linked_param_data = lookup_linked_param.ret_value.values[0].v_handle; + if (linked_param_data != NULL) { + runtime->storage_pool[runtime->storage_pool_count].is_linked_param = 1; + DLTensor* tensor = &runtime->storage_pool[runtime->storage_pool_count].array.dl_tensor; + tensor->data = linked_param_data; + tensor->ctx = ctx; + tensor->ndim = attrs->ndim[pit.entry_id]; + tensor->shape = attrs->shape + idx * TVM_CRT_MAX_NDIM; + tensor->strides = NULL; + tensor->byte_offset = 0; + did_find_linked_param = 1; + } + } + if (did_find_linked_param == 0) { + int64_t shape[TVM_CRT_MAX_NDIM] = { + 0, + }; + DLDataType dtype = {kDLFloat, 32, 1}; + shape[0] = (pit.size + 3) / 4; + runtime->storage_pool[runtime->storage_pool_count].is_linked_param = 0; + runtime->storage_pool[runtime->storage_pool_count].array = + TVMNDArray_Empty(1, shape, dtype, ctx); + CHECK_NE(runtime->storage_pool[runtime->storage_pool_count].array.dl_tensor.data, 0, + "fail to create storage_pool with idx=%d\n", idx); + } runtime->storage_pool_count++; } @@ -751,7 +800,7 @@ void TVMGraphRuntime_SetupStorage(TVMGraphRuntime* runtime) { uint32_t storage_id = attrs->storage_id[idx]; CHECK(storage_id < runtime->storage_pool_count); runtime->data_entry[idx] = - TVMNDArray_CreateView(&(runtime->storage_pool[storage_id]), + 
TVMNDArray_CreateView(&(runtime->storage_pool[storage_id].array), attrs->shape + idx * TVM_CRT_MAX_NDIM, attrs->ndim[idx], vtype[idx]); CHECK_NE(runtime->data_entry[idx].dl_tensor.data, 0, "fail to create for node with idx=%d, storage_id=%u\n", idx, storage_id); @@ -858,28 +907,28 @@ int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime* runtime, const TVMOpParam* /*! * \brief Initialize the graph executor with graph and context. * \param graph_json The execution graph. - * \param module The module containing the compiled functions for the host + * \param module_handle The module containing the compiled functions for the host * processor. * \param ctxs The context of the host and devices where graph nodes will be * executed on. */ -void TVMGraphRuntime_Init(TVMGraphRuntime* runtime, const char* graph_json, const TVMModule* module, - const TVMContext* ctxs) { +void TVMGraphRuntime_Init(TVMGraphRuntime* runtime, const char* graph_json, + TVMModuleHandle module_handle, const TVMContext* ctxs) { JSONReader reader = JSONReader_Create(graph_json); TVMGraphRuntime_Load(runtime, &reader); JSONReader_Release(&reader); + runtime->module_handle = module_handle; runtime->ctxs[0] = ctxs[0]; TVMGraphRuntime_SetupStorage(runtime); TVMGraphRuntime_SetupOpExecs(runtime); } -TVMGraphRuntime* TVMGraphRuntime_Create(const char* sym_json, const TVMModule* m, +TVMGraphRuntime* TVMGraphRuntime_Create(const char* sym_json, TVMModuleHandle module_handle, const TVMContext* ctxs) { - CHECK_EQ(vleak_size, 1, "memory leak checking won't work with concurrent CRT use"); TVMGraphRuntime* runtime = (TVMGraphRuntime*)vmalloc(sizeof(TVMGraphRuntime)); // NOLINT(*) memset(runtime, 0, sizeof(TVMGraphRuntime)); // init - TVMGraphRuntime_Init(runtime, sym_json, m, ctxs); + TVMGraphRuntime_Init(runtime, sym_json, module_handle, ctxs); return runtime; } @@ -892,7 +941,9 @@ void TVMGraphRuntime_Release(TVMGraphRuntime** pptr) { vfree(runtime->nodes); TVMGraphRuntimeGraphAttr_Release(&(runtime->attrs)); for (idx = 0; idx < runtime->storage_pool_count; ++idx) { - TVMNDArray_Release(&(runtime->storage_pool[idx])); + if (runtime->storage_pool[idx].is_linked_param == 0) { + TVMNDArray_Release(&(runtime->storage_pool[idx].array)); + } } for (idx = 0; idx < runtime->data_entry_count; ++idx) { vfree(runtime->data_entry[idx].dl_tensor.shape); @@ -909,6 +960,4 @@ void TVMGraphRuntime_Release(TVMGraphRuntime** pptr) { vfree(g_fexecs); g_fexecs = 0; } - - CHECK_EQ(vleak_size, 1, "found memory leak, leak size=%d", vleak_size - 1); } diff --git a/src/runtime/crt/graph_runtime_module/graph_runtime_module.c b/src/runtime/crt/graph_runtime_module/graph_runtime_module.c new file mode 100644 index 000000000000..2a32a0251507 --- /dev/null +++ b/src/runtime/crt/graph_runtime_module/graph_runtime_module.c @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+// LINT_C_FILE
+
+/*!
+ * \file graph_runtime_module.c
+ * \brief Wrap graph_runtime into a TVMModule for use with RPC.
+ */
+
+#include
+#include
+#include
+#include
+
+#include "tvm/runtime/crt/internal/graph_runtime/graph_runtime.h"
+
+typedef struct {
+  TVMModule mod;
+  TVMGraphRuntime* runtime;
+} GraphRuntimeModule;
+
+static GraphRuntimeModule graph_runtime;
+
+int32_t TVMGraphRuntimeModule_Create(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values,
+                                     int* ret_tcodes, void* resource_handle) {
+  if (graph_runtime.runtime != NULL) {
+    return kTvmErrorGraphModuleAlreadyCreated;
+  }
+
+  if (nargs != 4) {
+    return kTvmErrorFunctionCallNumArguments;
+  }
+
+  if (tcodes[0] != kTVMStr || tcodes[1] != kTVMModuleHandle || tcodes[2] != kTVMArgInt ||
+      tcodes[3] != kTVMArgInt) {
+    return kTvmErrorFunctionCallWrongArgType;
+  }
+
+  if (args[2].v_int64 != kDLCPU || args[3].v_int64 != 0) {
+    return kTvmErrorGraphModuleBadContext;
+  }
+
+  TVMContext ctx = {(DLDeviceType)args[2].v_int64, (int)args[3].v_int64};
+  graph_runtime.runtime = TVMGraphRuntime_Create(args[0].v_str, args[1].v_handle, &ctx);
+
+  TVMModuleHandle out;
+  int ret_value = TVMModCreateFromCModule(&graph_runtime.mod, &out);
+  if (ret_value != 0) {
+    ret_tcodes[0] = kTVMNullptr;
+    TVMGraphRuntime_Release(&graph_runtime.runtime);
+    return ret_value;
+  }
+
+  ret_values[0].v_handle = out;
+  ret_tcodes[0] = kTVMModuleHandle;
+  return kTvmErrorNoError;
+}
+
+int32_t TVMGraphRuntimeModule_GetInput(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values,
+                                       int* ret_tcodes, void* resource_handle) {
+  if (nargs != 1) {
+    return kTvmErrorFunctionCallNumArguments;
+  }
+
+  if (tcodes[0] != kTVMStr) {
+    return kTvmErrorFunctionCallWrongArgType;
+  }
+
+  int index = TVMGraphRuntime_GetInputIndex(graph_runtime.runtime, args[0].v_str);
+  if (index < 0) {
+    return kTvmErrorGraphModuleNoSuchInput;
+  }
+
+  uint32_t eid = TVMGraphRuntime_GetEntryId(graph_runtime.runtime,
+                                            graph_runtime.runtime->input_nodes[index], 0);
+  ret_values[0].v_handle = (void*)&graph_runtime.runtime->data_entry[eid].dl_tensor;
+  ret_tcodes[0] = kTVMNDArrayHandle;
+  return 0;
+}
+
+int32_t TVMGraphRuntimeModule_GetNumInputs(TVMValue* args, int* tcodes, int nargs,
+                                           TVMValue* ret_values, int* ret_tcodes,
+                                           void* resource_handle) {
+  if (nargs != 0) {
+    return kTvmErrorFunctionCallNumArguments;
+  }
+
+  ret_values[0].v_int64 = TVMGraphRuntime_GetNumInputs(graph_runtime.runtime);
+  ret_tcodes[0] = kTVMArgInt;
+  return 0;
+}
+
+int32_t TVMGraphRuntimeModule_GetNumOutputs(TVMValue* args, int* tcodes, int nargs,
+                                            TVMValue* ret_values, int* ret_tcodes,
+                                            void* resource_handle) {
+  if (nargs != 0) {
+    return kTvmErrorFunctionCallNumArguments;
+  }
+
+  ret_values[0].v_int64 = TVMGraphRuntime_GetNumOutputs(graph_runtime.runtime);
+  ret_tcodes[0] = kTVMArgInt;
+  return 0;
+}
+
+int32_t TVMGraphRuntimeModule_GetOutput(TVMValue* args, int* tcodes, int nargs,
+                                        TVMValue* ret_values, int* ret_tcodes,
+                                        void* resource_handle) {
+  if (nargs != 1) {
+    return kTvmErrorFunctionCallNumArguments;
+  }
+
+  if (tcodes[0] != kTVMArgInt) {
+    return kTvmErrorFunctionCallWrongArgType;
+  }
+
+  int output_index = args[0].v_int64;
+  if (output_index < 0 || output_index >= TVMGraphRuntime_GetNumOutputs(graph_runtime.runtime)) {
+    return kTvmErrorGraphModuleNoSuchInput;
+  }
+
+  uint32_t nid = graph_runtime.runtime->outputs[output_index].node_id;
+  uint32_t index =
graph_runtime.runtime->outputs[output_index].index;
+  uint32_t eid = TVMGraphRuntime_GetEntryId(graph_runtime.runtime, nid, index);
+
+  ret_values[0].v_handle = (void*)&(graph_runtime.runtime->data_entry[eid].dl_tensor);
+  ret_tcodes[0] = kTVMNDArrayHandle;
+  return 0;
+}
+
+int32_t TVMGraphRuntimeModule_LoadParams(TVMValue* args, int* tcodes, int nargs,
+                                         TVMValue* ret_values, int* ret_tcodes,
+                                         void* resource_handle) {
+  if (nargs != 1) {
+    return kTvmErrorFunctionCallNumArguments;
+  }
+
+  if (tcodes[0] != kTVMBytes) {
+    return kTvmErrorFunctionCallWrongArgType;
+  }
+
+  ret_tcodes[0] = kTVMNullptr;
+
+  TVMByteArray* arr = (TVMByteArray*)args[0].v_handle;
+  return TVMGraphRuntime_LoadParams(graph_runtime.runtime, arr->data, arr->size);
+}
+
+int32_t TVMGraphRuntimeModule_Run(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values,
+                                  int* ret_tcodes, void* resource_handle) {
+  if (nargs != 0) {
+    return kTvmErrorFunctionCallNumArguments;
+  }
+
+  TVMGraphRuntime_Run(graph_runtime.runtime);
+
+  ret_tcodes[0] = kTVMNullptr;
+  return 0;
+}
+
+int32_t TVMGraphRuntimeModule_SetInput(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values,
+                                       int* ret_tcodes, void* resource_handle) {
+  if (nargs != 2) {
+    return kTvmErrorFunctionCallNumArguments;
+  }
+
+  if (tcodes[0] != kTVMStr || tcodes[1] != kTVMDLTensorHandle) {
+    return kTvmErrorFunctionCallWrongArgType;
+  }
+
+  TVMGraphRuntime_SetInput(graph_runtime.runtime, args[0].v_str, (DLTensor*)args[1].v_handle);
+
+  ret_tcodes[0] = kTVMNullptr;
+  return 0;
+}
+
+int32_t TVMGraphRuntimeModule_NotImplemented(TVMValue* args, int* tcodes, int nargs,
+                                             TVMValue* ret_values, int* ret_tcodes,
+                                             void* resource_handle) {
+  return kTvmErrorFunctionCallNotImplemented;
+}
+
+static const TVMBackendPackedCFunc graph_runtime_registry_funcs[] = {
+    &TVMGraphRuntimeModule_GetInput,   &TVMGraphRuntimeModule_GetNumInputs,
+    &TVMGraphRuntimeModule_GetNumOutputs, &TVMGraphRuntimeModule_GetOutput,
+    &TVMGraphRuntimeModule_LoadParams, &TVMGraphRuntimeModule_Run,
+    &TVMGraphRuntimeModule_SetInput,   &TVMGraphRuntimeModule_NotImplemented,
+};
+
+static const TVMFuncRegistry graph_runtime_registry = {
+    "\x08get_input\0"
+    "get_num_inputs\0"
+    "get_num_outputs\0"
+    "get_output\0"
+    "load_params\0"
+    "run\0"
+    "set_input\0"
+    "share_params\0",
+    graph_runtime_registry_funcs};
+
+tvm_crt_error_t TVMGraphRuntimeModule_Register() {
+  graph_runtime.mod.registry = &graph_runtime_registry;
+  graph_runtime.runtime = NULL;
+
+  return TVMFuncRegisterGlobal("tvm.graph_runtime.create", &TVMGraphRuntimeModule_Create, 0);
+}
diff --git a/src/runtime/crt/host/main.cc b/src/runtime/crt/host/main.cc
index 664dae7ab857..41f2dc3b0a1b 100644
--- a/src/runtime/crt/host/main.cc
+++ b/src/runtime/crt/host/main.cc
@@ -32,6 +32,10 @@
 #include "crt_config.h"
 
+#ifdef TVM_HOST_USE_GRAPH_RUNTIME_MODULE
+#include <tvm/runtime/crt/graph_runtime_module.h>
+#endif
+
 using namespace std::chrono;
 
 extern "C" {
@@ -95,6 +99,11 @@ int main(int argc, char** argv) {
   utvm_rpc_server_t rpc_server =
       UTvmRpcServerInit(memory, sizeof(memory), 8, &UTvmWriteFunc, nullptr);
 
+#ifdef TVM_HOST_USE_GRAPH_RUNTIME_MODULE
+  CHECK_EQ(TVMGraphRuntimeModule_Register(), kTvmErrorNoError,
+           "failed to register GraphRuntime TVMModule");
+#endif
+
   if (TVMFuncRegisterGlobal("tvm.testing.reset_server", (TVMFunctionHandle)&testonly_reset_server,
                             0)) {
     fprintf(stderr, "utvm runtime: internal error registering global packedfunc; exiting\n");
diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_runtime/graph_runtime.h
b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_runtime/graph_runtime.h
index 7ea7a4f035c8..8e0faaa4f199 100644
--- a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_runtime/graph_runtime.h
+++ b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_runtime/graph_runtime.h
@@ -33,6 +33,7 @@
 typedef struct TVMGraphRuntimePoolEntry {
   size_t size;
   int device_type;
+  int entry_id;
 } TVMGraphRuntimePoolEntry;
 
 // Node entry
@@ -44,6 +45,12 @@ typedef struct TVMGraphRuntimeNodeEntry {
   void (*Load)(JSONReader* reader);
 } TVMGraphRuntimeNodeEntry;
 
+// Storage entry.
+typedef struct TVMGraphRuntimeStorageEntry {
+  uint8_t is_linked_param;
+  TVMNDArray array;
+} TVMGraphRuntimeStorageEntry;
+
 // Node
 typedef struct TVMGraphRuntimeNode {
   // operator type in string
@@ -87,7 +94,7 @@ typedef struct TVMGraphRuntime {
   TVMContext ctxs[1];
   uint32_t ctxs_count;
   /*! \brief Common storage pool for all devices. */
-  TVMNDArray* storage_pool;
+  TVMGraphRuntimeStorageEntry* storage_pool;
   uint32_t storage_pool_count;
   /*! \brief Data entry of each node. */
   TVMNDArray* data_entry;
@@ -100,6 +107,7 @@ typedef struct TVMGraphRuntime {
 typedef DLTensor* DLTensorPtr;
 
 // private functions
+uint32_t TVMGraphRuntime_GetEntryId(TVMGraphRuntime* runtime, uint32_t nid, uint32_t index);
 void TVMGraphRuntime_SetInput(TVMGraphRuntime* runtime, const char* name, DLTensor* data_in);
 int TVMGraphRuntime_LoadParams(TVMGraphRuntime* runtime, const char* param_blob,
                                const uint32_t param_size);
diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc
index 3e9ff4f279e7..d02a6d9a0d64 100644
--- a/src/runtime/graph/debug/graph_runtime_debug.cc
+++ b/src/runtime/graph/debug/graph_runtime_debug.cc
@@ -202,9 +202,10 @@ PackedFunc GraphRuntimeDebug::GetFunction(const std::string& name,
 * \param ctxs All devices contexts.
 */
 Module GraphRuntimeDebugCreate(const std::string& sym_json, const tvm::runtime::Module& m,
-                               const std::vector<TVMContext>& ctxs) {
+                               const std::vector<TVMContext>& ctxs,
+                               PackedFunc lookup_linked_param_func) {
   auto exec = make_object<GraphRuntimeDebug>();
-  exec->Init(sym_json, m, ctxs);
+  exec->Init(sym_json, m, ctxs, lookup_linked_param_func);
   return Module(exec);
 }
 
@@ -212,7 +213,15 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create").set_body([](TVMArgs args, TVMRetValue* rv) {
   ICHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is "
                                  "at least 4, but it has "
                               << args.num_args;
-  *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args));
+  PackedFunc lookup_linked_param_func;
+  int ctx_start_arg = 2;
+  if (args[2].type_code() == kTVMPackedFuncHandle) {
+    lookup_linked_param_func = args[2];
+    ctx_start_arg++;
+  }
+
+  *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args, ctx_start_arg),
+                                lookup_linked_param_func);
 });
 } // namespace runtime
 } // namespace tvm
diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index 21960d9d4b1b..9e1670e67fc0 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -64,14 +64,20 @@ void GraphRuntime::Run() {
 * processor.
 * \param ctxs The context of the host and devices where graph nodes will be
 * executed on.
+ * \param lookup_linked_param_func Linked parameter lookup function.
*/
 void GraphRuntime::Init(const std::string& graph_json, tvm::runtime::Module module,
-                        const std::vector<TVMContext>& ctxs) {
+                        const std::vector<TVMContext>& ctxs, PackedFunc lookup_linked_param_func) {
   std::istringstream is(graph_json);
   dmlc::JSONReader reader(&is);
   this->Load(&reader);
   module_ = module;
   ctxs_ = ctxs;
+  lookup_linked_param_ = lookup_linked_param_func;
+  if (lookup_linked_param_ == nullptr) {
+    lookup_linked_param_ = PackedFunc(
+        [this](TVMArgs args, TVMRetValue* rv) { this->DefaultLookupLinkedParam(args, rv); });
+  }
   this->SetupStorage();
   this->SetupOpExecs();
   for (size_t i = 0; i < input_nodes_.size(); i++) {
@@ -286,6 +292,43 @@ void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
   this->SetupOpExecs();
 }
 
+void GraphRuntime::LinkedNDArrayDeleter(Object* container) {
+  // container is the NDArray::Container which needs to get deleted.
+  // The data member points to global const memory, so it does not need deleting.
+  delete static_cast<NDArray::Container*>(container);
+}
+
+void GraphRuntime::DefaultLookupLinkedParam(TVMArgs args, TVMRetValue* rv) {
+  Module mod = args[0];
+  int64_t storage_id = args[1];
+  DLTensor* template_tensor = args[2];
+  TVMContext ctx = args[3];
+  // Get pre-linked parameter lookup function, if it was generated. When pf == nullptr, no linked
+  // params are present.
+  if (!module_lookup_linked_param_valid_) {
+    module_lookup_linked_param_ =
+        mod.GetFunction(::tvm::runtime::symbol::tvm_lookup_linked_param, true);
+  }
+  if (module_lookup_linked_param_ == nullptr) {
+    *rv = nullptr;
+    return;
+  }
+
+  TVMRetValue opaque_handle = module_lookup_linked_param_(storage_id);
+  if (opaque_handle.type_code() == kTVMNullptr) {
+    *rv = nullptr;
+    return;
+  }
+
+  std::vector<int64_t> shape_vec{template_tensor->shape,
+                                 template_tensor->shape + template_tensor->ndim};
+
+  std::unique_ptr<NDArray::Container> container{new NDArray::Container(
+      static_cast<void*>(opaque_handle), shape_vec, template_tensor->dtype, ctx)};
+  container->SetDeleter(GraphRuntime::LinkedNDArrayDeleter);
+  *rv = NDArray(GetObjectPtr<Object>(container.release()));
+}
+
 void GraphRuntime::SetupStorage() {
   // Grab saved optimization plan from graph.
   std::vector<DLDataType> vtype;
@@ -320,21 +363,37 @@ void GraphRuntime::SetupStorage() {
       ICHECK(pool_entry[sid].device_type == -1 || pool_entry[sid].device_type == device_type)
           << "The same pool entry cannot be assigned to multiple devices";
     }
+    TVMRetValue lookup_rv;
+    {
+      std::vector<int64_t> shape_vec{attrs_.shape[i].begin(), attrs_.shape[i].end()};
+      DLTensor template_tensor{nullptr,  TVMContext{kDLCPU, 0}, static_cast<int>(shape_vec.size()),
+                               vtype[i], shape_vec.data(),      nullptr,
+                               0};
+      lookup_rv = lookup_linked_param_(module_, sid, &template_tensor, ctxs_[0]);
+    }
+    if (lookup_rv.type_code() != kTVMNullptr) {
+      pool_entry[sid].linked_param = lookup_rv;
+    }
+    pool_entry[sid].param_data_entry = i;
     pool_entry[sid].size = std::max(pool_entry[sid].size, bytes);
     pool_entry[sid].device_type = device_type;
   }
 
   // Allocate the space.
   for (const auto& pit : pool_entry) {
-    std::vector<int64_t> shape;
     // This for loop is very fast since there are usually only a couple of
     // devices available on the same hardware.
     const auto& cit = std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) {
       return pit.device_type == static_cast<int>(c.device_type);
     });
    TVMContext ctx = cit == ctxs_.end() ?
ctxs_[0] : *cit;
-    shape.push_back(static_cast<int64_t>(pit.size + 3) / 4);
-    storage_pool_.push_back(NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx));
+    if (pit.linked_param.defined()) {
+      storage_pool_.push_back(pit.linked_param);
+    } else {
+      std::vector<int64_t> shape;
+      shape.push_back(static_cast<int64_t>(pit.size + 3) / 4);
+      storage_pool_.push_back(NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx));
+    }
   }
 
   // Assign the pooled entries. A unified memory pool is used to simplify
@@ -346,6 +405,7 @@ void GraphRuntime::SetupStorage() {
     int storage_id = attrs_.storage_id[i];
     ICHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
     data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
+
     const DLTensor* tmp = data_entry_[i].operator->();
     data_alignment_[i] = details::GetDataAlignment(*tmp);
   }
@@ -504,18 +564,19 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name,
 }
 
 Module GraphRuntimeCreate(const std::string& sym_json, const tvm::runtime::Module& m,
-                          const std::vector<TVMContext>& ctxs) {
+                          const std::vector<TVMContext>& ctxs,
+                          const PackedFunc lookup_linked_param_func) {
   auto exec = make_object<GraphRuntime>();
-  exec->Init(sym_json, m, ctxs);
+  exec->Init(sym_json, m, ctxs, lookup_linked_param_func);
   return Module(exec);
 }
 
 // Get all context for the host and other runtime devices.
-std::vector<TVMContext> GetAllContext(const TVMArgs& args) {
+std::vector<TVMContext> GetAllContext(const TVMArgs& args, int ctx_start_arg) {
   // Reserve the first item as the fallback device.
   std::vector<TVMContext> ret;
   TVMContext ctx;
-  for (int i = 2; i < args.num_args; i += 2) {
+  for (int i = ctx_start_arg; i < args.num_args; i += 2) {
     int dev_type = args[i];
     ctx.device_type = static_cast<DLDeviceType>(dev_type);
     ctx.device_id = args[i + 1];
@@ -533,8 +594,14 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
   ICHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is "
                                  "at least 4, but it has "
                               << args.num_args;
-  const auto& contexts = GetAllContext(args);
-  *rv = GraphRuntimeCreate(args[0], args[1], contexts);
+  PackedFunc lookup_linked_param_func;
+  int ctx_start_arg = 2;
+  if (args[2].type_code() == kTVMPackedFuncHandle) {
+    lookup_linked_param_func = args[2];
+    ctx_start_arg++;
+  }
+  const auto& contexts = GetAllContext(args, ctx_start_arg);
+  *rv = GraphRuntimeCreate(args[0], args[1], contexts, lookup_linked_param_func);
 });
 } // namespace runtime
 } // namespace tvm
diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h
index c08f5e671a08..81aa87d6ed90 100644
--- a/src/runtime/graph/graph_runtime.h
+++ b/src/runtime/graph/graph_runtime.h
@@ -94,10 +94,13 @@ class TVM_DLL GraphRuntime : public ModuleNode {
 * processor.
 * \param ctxs The context of the host and devices where graph nodes will be
 * executed on.
+ * \param lookup_linked_param_func If given, a PackedFunc invoked to lookup linked parameters
+ * by storage_id. If not given, linked parameters are looked-up using an internal implementation,
+ * which is not compatible with RPCModules.
 */
   void Init(const std::string& graph_json, tvm::runtime::Module module,
-            const std::vector<TVMContext>& ctxs);
+            const std::vector<TVMContext>& ctxs, const PackedFunc lookup_linked_param_func);
 
 /*!
 * \brief Get the input index given the name of input.
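For reference, here is how the new optional argument surfaces to Python callers. This is an illustrative sketch (not part of the patch) of the two call shapes now accepted by the "tvm.graph_runtime.create" PackedFunc; `graph_json`, `mod`, and `ctx` are assumed to exist:

import tvm
from tvm.contrib import graph_runtime

fcreate = tvm.get_global_func("tvm.graph_runtime.create")

# Old call shape, still accepted: contexts begin at argument 2.
m = graph_runtime.GraphModule(fcreate(graph_json, mod, ctx.device_type, ctx.device_id))

# New call shape: a PackedFunc at argument 2 is detected by its type code and used
# as lookup_linked_param_func; contexts then begin at argument 3.
def lookup_linked_param(mod, storage_id, template_tensor, ctx):
    return None  # returning None selects normal storage-pool allocation

m = graph_runtime.GraphModule(
    fcreate(graph_json, mod, lookup_linked_param, ctx.device_type, ctx.device_id)
)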
@@ -209,7 +212,10 @@ class TVM_DLL GraphRuntime : public ModuleNode {
   struct PoolEntry {
     size_t size;
     int device_type;
-    PoolEntry(int s, int dev_type) : size(s), device_type(dev_type) {}
+    int param_data_entry;
+    NDArray linked_param;
+    // PoolEntry(int s, int dev_type, void* pre_linked_param) :
+    //     size(s), device_type(dev_type), pre_linked_param(std::move(pre_linked_param)) {}
   };
   // Node entry
   struct NodeEntry {
@@ -390,6 +396,10 @@ class TVM_DLL GraphRuntime : public ModuleNode {
     }
     ICHECK_EQ(bitmask, 1 | 2 | 4 | 8 | 16) << "invalid format";
   }
+  /*! \brief PackedFunc to lookup a linked parameter from a local Module. */
+  void DefaultLookupLinkedParam(TVMArgs args, TVMRetValue* rv);
+  /*! \brief Delete NDArray::Container with linked (i.e. static) data. */
+  static void LinkedNDArrayDeleter(Object* container);
   /*! \brief Setup the temporal storage */
   void SetupStorage();
   /*! \brief Setup the executors. */
@@ -437,9 +447,18 @@ class TVM_DLL GraphRuntime : public ModuleNode {
   std::vector<size_t> data_alignment_;
   /*! \brief Operator on each node. */
   std::vector<std::function<void()>> op_execs_;
+  /*! \brief Linked parameter lookup function. */
+  PackedFunc lookup_linked_param_;
+  /*! \brief Module's _lookup_linked_param function, used by DefaultLookupLinkedParam. */
+  PackedFunc module_lookup_linked_param_;
+  /*!
+   * \brief True when module_lookup_linked_param_ is valid.
+   * When the module does not include linked parameters, module_lookup_linked_param_ will be nullptr.
+   */
+  bool module_lookup_linked_param_valid_;
 };
 
-std::vector<TVMContext> GetAllContext(const TVMArgs& args);
+std::vector<TVMContext> GetAllContext(const TVMArgs& args, int ctx_start_arg);
 } // namespace runtime
 } // namespace tvm
diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc
index 632a25c987bc..2c055e16cc9f 100644
--- a/src/runtime/graph/graph_runtime_factory.cc
+++ b/src/runtime/graph/graph_runtime_factory.cc
@@ -97,7 +97,7 @@ void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) {
 
 Module GraphRuntimeFactory::RuntimeCreate(const std::vector<TVMContext>& ctxs) {
   auto exec = make_object<GraphRuntime>();
-  exec->Init(this->graph_json_, this->imports_[0], ctxs);
+  exec->Init(this->graph_json_, this->imports_[0], ctxs, PackedFunc());
   // set params
   SetParams(exec.get(), this->params_);
   return Module(exec);
diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc
index 165c0fe73b36..4f721e122a4c 100644
--- a/src/runtime/rpc/rpc_module.cc
+++ b/src/runtime/rpc/rpc_module.cc
@@ -22,6 +22,7 @@
 * \brief RPC runtime module.
 */
 #include
+#include
 #include
 #include
 
@@ -36,6 +37,44 @@
 namespace tvm {
 namespace runtime {
 
+// deleter of RPC remote array
+static void RemoteNDArrayDeleter(Object* obj) {
+  auto* ptr = static_cast<NDArray::Container*>(obj);
+  RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
+  if (ptr->manager_ctx != nullptr) {
+    space->sess->FreeHandle(ptr->manager_ctx, kTVMNDArrayHandle);
+  }
+  delete space;
+  delete ptr;
+}
+
+/*!
+ * \brief Build a local NDArray with remote backing storage.
+ * \param sess the RPCSession which owns the given handle.
+ * \param handle A pointer valid on the remote end which should form the `data` field of the
+ * underlying DLTensor.
+ * \param template_tensor An empty DLTensor whose shape and dtype fields are used to fill the newly
+ * created array. Needed because it's difficult to pass a shape vector as a PackedFunc arg.
+ * \param ctx Remote context used with this tensor. Must have non-zero RPCSessMask.
+ * \param remote_ndarray_handle The handle returned by RPC server to identify the NDArray.
+ */
+NDArray NDArrayFromRemoteOpaqueHandle(std::shared_ptr<RPCSession> sess, void* handle,
+                                      DLTensor* template_tensor, TVMContext ctx,
+                                      void* remote_ndarray_handle) {
+  ICHECK_EQ(sess->table_index(), GetRPCSessionIndex(ctx))
+      << "The TVMContext given does not belong to the given session";
+  RemoteSpace* space = new RemoteSpace();
+  space->sess = sess;
+  space->data = handle;
+  std::vector<int64_t> shape_vec{template_tensor->shape,
+                                 template_tensor->shape + template_tensor->ndim};
+  NDArray::Container* data = new NDArray::Container(static_cast<void*>(space), std::move(shape_vec),
+                                                    template_tensor->dtype, ctx);
+  data->manager_ctx = remote_ndarray_handle;
+  data->SetDeleter(RemoteNDArrayDeleter);
+  return NDArray(GetObjectPtr<Object>(data));
+}
+
 /*!
 * \brief A wrapped remote function as a PackedFunc.
 */
@@ -113,41 +152,6 @@ class RPCWrappedFunc : public Object {
         << "Can not pass in context with a different remote session";
     return RemoveRPCSessionMask(ctx);
   }
-
-  // deleter of RPC remote array
-  static void RemoteNDArrayDeleter(Object* obj) {
-    auto* ptr = static_cast<NDArray::Container*>(obj);
-    RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
-    space->sess->FreeHandle(ptr->manager_ctx, kTVMNDArrayHandle);
-    delete space;
-    delete ptr;
-  }
-
-  // wrap return value as remote NDArray.
-  NDArray WrapRemoteNDArray(DLTensor* tensor, void* nd_handle) const {
-    NDArray::Container* data = new NDArray::Container();
-    data->manager_ctx = nd_handle;
-    data->SetDeleter(RemoteNDArrayDeleter);
-    RemoteSpace* space = new RemoteSpace();
-    space->sess = sess_;
-    space->data = tensor->data;
-    data->dl_tensor.data = space;
-    NDArray ret(GetObjectPtr<Object>(data));
-    // RAII now in effect
-    data->shape_ = std::vector<int64_t>(tensor->shape, tensor->shape + tensor->ndim);
-    data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
-    data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
-    // setup dtype
-    data->dl_tensor.dtype = tensor->dtype;
-    // setup ctx, encode as remote session
-    data->dl_tensor.ctx = AddRPCSessionMask(tensor->ctx, sess_->table_index());
-    // check strides.
-    ICHECK(tensor->strides == nullptr);
-    // setup byteoffset
-    data->dl_tensor.byte_offset = tensor->byte_offset;
-
-    return ret;
-  }
 };
 
 // RPC that represents a remote module session.
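Before the codegen pieces below, note how the whole feature is switched on: a new boolean link-params attribute on the target (registered in src/target/target_kind.cc, per the diffstat). A hedged Python sketch, assuming the `--link-params` flag spelling follows the `--system-lib` convention used in target.py above, that `relay_mod` and `params` exist, and that the 3-tuple unpacking of relay.build()'s backward-compatible interface applies:

import tvm
from tvm import relay

target = tvm.target.Target("llvm --link-params")
with tvm.transform.PassContext(opt_level=3):
    graph_json, lib, params = relay.build(relay_mod, target, params=params)

# The generated module now carries the parameter data and exposes the lookup
# PackedFunc under the symbol defined in include/tvm/runtime/module.h.
assert lib.get_function("_lookup_linked_param", query_imports=True) is not None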
@@ -280,7 +284,9 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const {
     ICHECK_EQ(args.size(), 3);
     DLTensor* tensor = args[1];
     void* nd_handle = args[2];
-    *rv = WrapRemoteNDArray(tensor, nd_handle);
+    *rv = NDArrayFromRemoteOpaqueHandle(sess_, tensor->data, tensor,
+                                        AddRPCSessionMask(tensor->ctx, sess_->table_index()),
+                                        nd_handle);
   } else {
     ICHECK_EQ(args.size(), 2);
     *rv = args[1];
@@ -466,5 +472,12 @@ TVM_REGISTER_GLOBAL("rpc.SessTableIndex").set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
 });

+TVM_REGISTER_GLOBAL("tvm.rpc.NDArrayFromRemoteOpaqueHandle")
+    .set_body_typed([](Module mod, void* remote_array, DLTensor* template_tensor, TVMContext ctx,
+                       void* ndarray_handle) -> NDArray {
+      return NDArrayFromRemoteOpaqueHandle(RPCModuleGetSession(mod), remote_array, template_tensor,
+                                           ctx, ndarray_handle);
+    });
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index faa483d019c0..d10ed311949c 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -25,6 +25,7 @@
 #include "codegen_llvm.h"

 #include 
+#include 
 #include 
 #include 
@@ -32,7 +33,10 @@

 #include "../../arith/pattern_match.h"
 #include "../build_common.h"
+#include "../func_registry_generator.h"
 #include "codegen_cpu.h"
+#include "codegen_params.h"
+#include "llvm/Support/raw_os_ostream.h"

 namespace tvm {
 namespace codegen {
@@ -184,6 +188,90 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
   }
 }

+void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
+  // It would be nice to de-dupe these declarations from src/tir/transforms/make_packed_api.cc,
+  // but they are at a different layer in the compiler...
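The six parameter types assembled below spell out TVM's backend PackedFunc calling convention: args, argument type codes, argument count, return-value slot, return type-code slot, and resource handle. As a rough ctypes rendering of the signature being materialized (a sketch; the pointer types are simplified and the constant's name is invented):

    import ctypes

    # Approximate shape of the generated function's C ABI (hypothetical name).
    LOOKUP_LINKED_PARAM_CFUNC = ctypes.CFUNCTYPE(
        ctypes.c_int,                  # return code
        ctypes.c_void_p,               # args
        ctypes.POINTER(ctypes.c_int),  # arg type codes
        ctypes.c_int,                  # num_args
        ctypes.c_void_p,               # out_ret_value
        ctypes.POINTER(ctypes.c_int),  # out_ret_tcode
        ctypes.c_void_p,               # resource_handle
    )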
+  std::vector<llvm::Type*> param_types;
+  // args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // num_args
+  param_types.push_back(t_int_);
+  // ret_args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // ret_tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // resource_handle
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+
+  llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false);
+
+  llvm::Function* function =
+      llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
+                             ::tvm::runtime::symbol::tvm_lookup_linked_param, module_.get());
+  function->setCallingConv(llvm::CallingConv::C);
+  function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+
+  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function);
+  builder_->SetInsertPoint(entry);
+  std::vector<llvm::Value*> zero_index_list{llvm::ConstantInt::get(t_int32_, 0)};
+  std::vector<llvm::Value*> zero_array_index_list{llvm::ConstantInt::get(t_int32_, 0),
+                                                  llvm::ConstantInt::get(t_int32_, 0)};
+  auto args_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50
+      &function->arg_begin()[0],
+#else
+      &(*(function->arg_begin())),
+#endif
+      llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1));
+  llvm::Value* sid = builder_->CreateBitCast(
+      builder_->CreateLoad(t_void_->getPointerTo(GetGlobalAddressSpace()),
+                           builder_->CreateInBoundsGEP(args_array, zero_index_list)),
+      t_int64_);
+
+  llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function);
+  auto ret_types_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50
+      &function->arg_begin()[4],
+#else
+      &(*(std::next(function->arg_begin(), 4))),
+#endif
+      llvm::ArrayType::get(t_int_, 1)->getPointerTo());
+  auto retval_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50
+      &function->arg_begin()[3],
+#else
+      &(*std::next(function->arg_begin(), 3)),
+#endif
+      llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1)->getPointerTo());
+  llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1);
+
+  builder_->SetInsertPoint(default_block);
+  builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr),
+                        builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list));
+  builder_->CreateRet(ConstInt32(kTvmErrorNoError));
+
+  // Add data to the global section.
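The loop that follows emits one switch case per linked parameter. The behavior of the finished lookup function can be modeled in Python as below; this is a sketch only, since the real artifact is LLVM IR, and the storage ids and payloads here are invented:

    kTVMOpaqueHandle = 3  # TVM type code: raw pointer to linked data
    kTVMNullptr = 4       # TVM type code: no linked parameter for this id

    LINKED_PARAMS = {0: b"\x00\x01\x02\x03", 3: b"\x04\x05\x06\x07"}

    def lookup_linked_param(storage_id):
        """Return (out_ret_value[0], out_ret_tcode[0]) as the generated switch would."""
        if storage_id not in LINKED_PARAMS:  # corresponds to default_block above
            return None, kTVMNullptr
        return LINKED_PARAMS[storage_id], kTVMOpaqueHandle

    assert lookup_linked_param(1) == (None, kTVMNullptr)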
+  for (auto kv : params) {
+    auto array = NDArrayToLLVMArray(ctx_, kv.second->param);
+    std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first;
+    llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable(
+        *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name);
+
+    llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function);
+    switch_inst->addCase(
+        llvm::cast<llvm::ConstantInt>(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block);
+    builder_->SetInsertPoint(case_block);
+    builder_->CreateStore(
+        builder_->CreatePointerCast(param_symbol, t_void_->getPointerTo(GetGlobalAddressSpace())),
+        builder_->CreateInBoundsGEP(retval_array, zero_array_index_list));
+    builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle),
+                          builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list));
+    builder_->CreateRet(ConstInt32(0));
+  }
+}
+
 std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
   this->AddStartupFunction();
   for (size_t i = 0; i < link_modules_.size(); ++i) {
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 78eb5e2dcac7..71583708da2c 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -98,6 +98,18 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
    * \param mod The module to be linked.
    */
   void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
+  /*!
+   * \brief Link parameters into the module so they don't need to be supplied at runtime.
+   * Parameters can be linked into the module so that the generated code is easier to use, or so
+   * that RAM space doesn't need to be allocated for them. This function adds the given parameters
+   * to the generated LLVM module.
+   * \param params Parameters to link, keyed by parameter name. Each entry carries the parameter's
+   * storage id plus the NDArray holding its data; the storage id becomes the switch key in the
+   * generated lookup function.
+   */
+  void LinkParameters(const Map<String, LinkedParam> params);
   /*!
    * \brief Create Value for expression e
    * \param e The expression to be created value for.
diff --git a/src/target/llvm/codegen_params.cc b/src/target/llvm/codegen_params.cc
new file mode 100644
index 000000000000..694be5621606
--- /dev/null
+++ b/src/target/llvm/codegen_params.cc
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file codegen_params.cc
+ */
+#ifdef TVM_LLVM_VERSION
+
+#include "codegen_params.h"
+
+#include 
+#include 
+#include 
+
+namespace tvm {
+namespace codegen {
+
+template <typename T, typename Enable = void>
+struct LLVMConstantGetter {
+  static llvm::Constant* getElement(llvm::Type* ty, T t);
+};
+
+template <typename T>
+struct LLVMConstantGetter<
+    T, std::enable_if_t<(std::is_integral<T>::value && std::is_signed<T>::value)>> {
+  static llvm::Constant* getElement(llvm::Type* ty, T t) {
+    return llvm::ConstantInt::getSigned(ty, t);
+  }
+};
+
+template <typename T>
+struct LLVMConstantGetter<
+    T, std::enable_if_t<(std::is_integral<T>::value && !std::is_signed<T>::value)>> {
+  static llvm::Constant* getElement(llvm::Type* ty, T t) { return llvm::ConstantInt::get(ty, t); }
+};
+
+template <typename T>
+struct LLVMConstantGetter<T, std::enable_if_t<std::is_floating_point<T>::value>> {
+  static llvm::Constant* getElement(llvm::Type* ty, T t) { return llvm::ConstantFP::get(ty, t); }
+};
+
+template <typename T, typename Enable = std::enable_if_t<std::is_arithmetic<T>::value>>
+void BuildLLVMVector(llvm::Type* element_type, void* tensor_data, size_t num_elements,
+                     std::vector<llvm::Constant*>* elements) {
+  elements->resize(num_elements, nullptr);
+  std::transform(static_cast<T*>(tensor_data), static_cast<T*>(tensor_data) + num_elements,
+                 elements->begin(),
+                 [&](T t) { return LLVMConstantGetter<T>::getElement(element_type, t); });
+}
+
+llvm::ConstantArray* NDArrayToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::NDArray arr) {
+  llvm::Type* element_type = nullptr;
+
+  auto arr_type = arr.DataType();
+  CHECK(arr.IsContiguous()) << "CodegenParams: only support contiguous arrays";
+  CHECK_EQ(arr->ctx.device_type, kDLCPU) << "CodegenParams: only support CPU-resident arrays";
+  CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw "
+                                << arr_type.lanes();
+
+  auto shape = arr.Shape();
+  int num_elements = 1;
+  for (auto shape_elem : shape) {
+    num_elements *= shape_elem;
+  }
+
+  std::vector<llvm::Constant*> elements;
+
+  switch (arr_type.code()) {
+    case runtime::DataType::kInt:
+      CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 ||
+            arr_type.bits() == 64)
+          << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw "
+          << arr_type.bits() << "-bit array";
+      element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());
+
+      switch (arr_type.bits()) {
+        case 8:
+          BuildLLVMVector<int8_t>(element_type, arr->data, num_elements, &elements);
+          break;
+        case 16:
+          BuildLLVMVector<int16_t>(element_type, arr->data, num_elements, &elements);
+          break;
+        case 32:
+          BuildLLVMVector<int32_t>(element_type, arr->data, num_elements, &elements);
+          break;
+        case 64:
+          BuildLLVMVector<int64_t>(element_type, arr->data, num_elements, &elements);
+          break;
+        default:
+          ICHECK(false) << "should not get here";
+          break;
+      }
+      break;
+
+    case runtime::DataType::TypeCode::kUInt:
+      CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 ||
+            arr_type.bits() == 64)
+          << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw "
+          << arr_type.bits() << "-bit array";
+      element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());
+
+      switch (arr_type.bits()) {
+        case 8:
+          BuildLLVMVector<uint8_t>(element_type, arr->data, num_elements, &elements);
+          break;
+        case 16:
+          BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements);
+          break;
+        case 32:
+          BuildLLVMVector<uint32_t>(element_type, arr->data, num_elements, &elements);
+          break;
+        case 64:
+          BuildLLVMVector<uint64_t>(element_type, arr->data, num_elements, &elements);
+          break;
+        default:
+          ICHECK(false) << "should not get here";
+          break;
+      }
+      break;
+
+    case runtime::DataType::TypeCode::kFloat:
+      switch (arr_type.bits()) {
+        case 16:
+          // NOTE: float16 is treated as uint16_t.
+          element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());
+          BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements);
+          break;
+        case 32:
+          element_type = llvm::Type::getFloatTy(*ctx);
+          BuildLLVMVector<float>(element_type, arr->data, num_elements, &elements);
+          break;
+        case 64:
+          element_type = llvm::Type::getDoubleTy(*ctx);
+          BuildLLVMVector<double>(element_type, arr->data, num_elements, &elements);
+          break;
+        default:
+          CHECK(false) << "CodegenParams: only support 16-, 32-, or 64-bit floating point; saw "
+                       << arr_type.bits() << "-bit array";
+          break;
+      }
+      break;
+
+    case runtime::DataType::TypeCode::kBFloat:
+      CHECK(arr_type.bits() == 16)
+          << "CodegenParams: only support 16-bit bfloat; saw " << arr_type.bits() << "-bit array";
+      element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());
+      BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements);
+      break;
+
+    default:
+      CHECK(false) << "Data type not supported";
+  }
+
+  return llvm::cast<llvm::ConstantArray>(llvm::ConstantArray::get(
+      llvm::ArrayType::get(element_type, num_elements), llvm::ArrayRef<llvm::Constant*>(elements)));
+}
+
+}  // namespace codegen
+}  // namespace tvm
+
+#endif  // TVM_LLVM_VERSION
diff --git a/src/target/llvm/codegen_params.h b/src/target/llvm/codegen_params.h
new file mode 100644
index 000000000000..771bc201f7aa
--- /dev/null
+++ b/src/target/llvm/codegen_params.h
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file codegen_params.h
+ */
+
+#ifndef TVM_TARGET_LLVM_CODEGEN_PARAMS_H_
+#define TVM_TARGET_LLVM_CODEGEN_PARAMS_H_
+
+#include 
+#include 
+
+#include "llvm_common.h"
+
+namespace tvm {
+namespace codegen {
+
+/*!
+ * \brief Convert an NDArray to an LLVM array of constants.
+ *
+ * The supplied NDArray is flattened, and each element is converted to the appropriate LLVM type.
+ *
+ * \param ctx LLVM context used to create the various primitive datatypes.
+ * \param arr NDArray to convert.
+ * \return LLVM array containing the array data.
+ */
+llvm::ConstantArray* NDArrayToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::NDArray arr);
+
+}  // namespace codegen
+}  // namespace tvm
+
+#endif  // TVM_TARGET_LLVM_CODEGEN_PARAMS_H_
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 569082022852..73a3594427d3 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -32,6 +32,7 @@
 #include "../../runtime/file_utils.h"
 #include "../../runtime/library_module.h"
+#include "../func_registry_generator.h"
 #include "codegen_blob.h"
 #include "codegen_llvm.h"
 #include "llvm_common.h"
@@ -199,7 +200,21 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     std::vector<PrimFunc> funcs;
     std::string entry_func;
+    Map<String, LinkedParam> linked_params;
+    bool found_linked_params = false;
+    bool could_have_linked_params = target->GetAttr<Bool>("link-params").value_or(Bool(false));
     for (auto kv : mod->functions) {
+      if (could_have_linked_params &&
+          kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) {
+        Map<String, ObjectRef> attrs_dict =
+            Downcast<Map<String, ObjectRef>>(kv.second->attrs->dict);
+        CHECK(attrs_dict.find(::tvm::tir::attr::kLinkedParams) != attrs_dict.end())
+            << "no " << ::tvm::tir::attr::kLinkedParams << " attribute found!";
+        linked_params =
+            Downcast<Map<String, LinkedParam>>(attrs_dict[::tvm::tir::attr::kLinkedParams]);
+        found_linked_params = true;
+        continue;
+      }
       ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
       auto f = Downcast<PrimFunc>(kv.second);
       if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
@@ -209,7 +224,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
       }
       funcs.push_back(f);
     }
-    ICHECK_NE(funcs.size(), 0U);
+    ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params));
     // TODO(tqchen): remove the entry function behavior as it does not
     // make sense when we start to use multiple modules.
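For context, the `link-params` attribute consulted above is requested from the Python API through the target string. A hedged usage sketch, assuming a Relay module `mod` and a parameter dict `params` are already in scope:

    import tvm
    import tvm.relay

    target = "llvm --system-lib --link-params"  # "--link-params" sets the Bool attr read above
    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.relay.build(mod, target, params=params)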
cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); @@ -222,6 +237,9 @@ class LLVMModuleNode final : public runtime::ModuleNode { cg->AddMainFunction(entry_func); } + if (found_linked_params) { + cg->LinkParameters(linked_params); + } module_ = cg->Finish(); module_->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx_, LLVMTargetToString(target))); diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 6ae11f4f9af8..0a19fc1399b7 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -23,6 +23,8 @@ #include "codegen_c_host.h" #include +#include +#include #include #include @@ -31,6 +33,7 @@ #include "../../support/str_escape.h" #include "../build_common.h" #include "../func_registry_generator.h" +#include "codegen_params.h" namespace tvm { namespace codegen { @@ -57,6 +60,48 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) { CodeGenC::AddFunction(f); } +void CodeGenCHost::LinkParameters(Map params) { + PrintFuncPrefix(); + stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param + << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, " + << "int* out_ret_tcode, void* resource_handle) {\n"; + ICHECK_EQ(GetUniqueName(tvm::runtime::symbol::tvm_lookup_linked_param), + tvm::runtime::symbol::tvm_lookup_linked_param) + << "builtin PackedFunc name already taken: " << tvm::runtime::symbol::tvm_lookup_linked_param; + stream << " switch (((int64_t*) args)[0]) {\n" + << " default:\n" + << " out_ret_tcode[0] = " << kTVMNullptr << ";\n" + << " return 0;\n"; + + function_names_.emplace_back(tvm::runtime::symbol::tvm_lookup_linked_param); + for (auto kv : params) { + decl_stream << "\n" + << "#ifdef __cplusplus\n" + << "extern \"C\" {\n" + << "#endif\n" + << "static const "; + int64_t num_elements = 1; + for (int64_t dim : kv.second->param.Shape()) { + num_elements *= dim; + } + PrintType(kv.second->param.DataType(), decl_stream); + decl_stream << " " << ::tvm::runtime::symbol::tvm_param_prefix << kv.first << "[" + << num_elements << "] = {\n"; + NDArrayDataToC(kv.second->param, 4, decl_stream); + decl_stream << "};\n" + << "#ifdef __cplusplus\n" + << "} // extern \"C\"\n" + << "#endif\n"; + stream << " case " << kv.second->id << ":\n" + << " ((uint64_t*)out_ret_value)[0] = (uint64_t) (uintptr_t) " + << ::tvm::runtime::symbol::tvm_param_prefix << kv.first << ";\n" + << " out_ret_tcode[0] = " << kTVMOpaqueHandle << ";\n" + << " return 0;\n"; + } + stream << " }\n" + << "}\n"; +} + void CodeGenCHost::PrintFuncPrefix() { // NOLINT(*) stream << "#ifdef __cplusplus\n" << "extern \"C\"\n" @@ -307,12 +352,31 @@ runtime::Module BuildCHost(IRModule mod, Target target) { CodeGenCHost cg; cg.Init(output_ssa, emit_asserts, target->str()); + Map linked_params; + bool found_linked_params = false; + bool could_have_linked_params = target->GetAttr("link-params").value_or(Bool(false)); for (auto kv : mod->functions) { + if (could_have_linked_params && + kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) { + Map attrs_dict = Downcast>(kv.second->attrs->dict); + CHECK(attrs_dict.find(::tvm::tir::attr::kLinkedParams) != attrs_dict.end()) + << "no " << ::tvm::tir::attr::kLinkedParams << " attribute found!"; + linked_params = + Downcast>(attrs_dict[::tvm::tir::attr::kLinkedParams]); + found_linked_params = true; + continue; + } + ICHECK(kv.second->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; auto f = Downcast(kv.second); 
cg.AddFunction(f); } + if (could_have_linked_params) { + ICHECK(found_linked_params) << "-link-params given but none found"; + cg.LinkParameters(linked_params); + } + if (target->GetAttr("system-lib").value_or(Bool(false))) { ICHECK_EQ(target->GetAttr("runtime").value_or(""), "c") << "c target only supports generating C runtime SystemLibs"; diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 1bf378be1422..b54b6fbfcfeb 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -42,6 +42,9 @@ class CodeGenCHost final : public CodeGenC { void AddFunction(const PrimFunc& f); + /*! \brief Add linked parameters, if they are present. */ + void LinkParameters(Map params); + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintFuncPrefix() final; // NOLINT(*) void PrintFinalReturn() final; // NOLINT(*) diff --git a/src/target/source/codegen_params.cc b/src/target/source/codegen_params.cc new file mode 100644 index 000000000000..cc7695abfd25 --- /dev/null +++ b/src/target/source/codegen_params.cc @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_params.cc + */ + +#include "codegen_params.h" + +#include + +#include +#include +#include +#include + +namespace tvm { +namespace codegen { + +/*! \brief maximum line length of generated parameters, including indent. */ +static constexpr const int kMaxLineLength = 80; + +static int ComputeNumElementsPerRow(int one_element_size_bytes, int indent_chars) { + if (one_element_size_bytes > kMaxLineLength - indent_chars) { + return 1; + } + // When multiple elements fit per line, divide the available space by the size of one element, + // and return the largest power of 2 less than the result. Using power-of-2-sized elements allows + // for easily traversing the generated code. + int elements_per_row = (kMaxLineLength - indent_chars) / one_element_size_bytes; + + // Implementation of fls. Iteratively clear the LSB until one bit remains. 
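The while loop that follows implements the fls step by clearing low-order bits. The complete helper behaves like this Python sketch (a model of the logic, not shared code; K_MAX_LINE_LENGTH mirrors kMaxLineLength above):

    K_MAX_LINE_LENGTH = 80

    def compute_num_elements_per_row(one_element_size_bytes, indent_chars):
        budget = K_MAX_LINE_LENGTH - indent_chars
        if one_element_size_bytes > budget:
            return 1
        n = budget // one_element_size_bytes
        while n & (n - 1):  # clear low-order bits until one (power-of-two) bit remains
            n &= n - 1
        return n

    assert compute_num_elements_per_row(8, 4) == 8  # 76 // 8 == 9, rounded down to 8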
+  while ((elements_per_row & (elements_per_row - 1)) > 0) {
+    elements_per_row &= elements_per_row - 1;
+  }
+  return elements_per_row;
+}
+
+template <typename T, typename Enable = std::enable_if_t<std::is_integral<T>::value>>
+void PrintIntegralArray(void* data, size_t num_elements, int indent_chars, std::ostream& os) {
+  int one_element_size_bytes = (sizeof(T) / 4) + (2 /* "0x" */) + (2 /* ", " */);
+  if (std::is_signed<T>::value) {
+    one_element_size_bytes += 1;  // sign character
+    if (sizeof(T) == 64 / 8) {
+      one_element_size_bytes += 2;  // "LL"
+    }
+  } else {
+    if (sizeof(T) == 64 / 8) {
+      one_element_size_bytes += 3;  // "ULL"
+    }
+  }
+
+  int elements_per_row = ComputeNumElementsPerRow(one_element_size_bytes, indent_chars);
+  std::string indent_str(indent_chars, ' ');
+
+  for (size_t i = 0; i < num_elements; i++) {
+    if ((i % elements_per_row) == 0) {
+      if (i != 0) {
+        os << std::endl;
+      }
+      os << indent_str;
+    }
+    int64_t elem = static_cast<T*>(data)[i];
+    if (std::is_signed<T>::value) {
+      uint64_t to_print;
+      if (elem < 0) {
+        os << "-";
+        to_print = -elem;
+      } else {
+        os << "+";
+        to_print = elem;
+      }
+      os << "0x" << std::setw(sizeof(T) * 8 / 4) << static_cast<uint64_t>(to_print);
+      if (sizeof(T) == 64 / 8) {
+        os << "LL";
+      }
+    } else {
+      os << "0x" << std::setw(sizeof(T) * 8 / 4) << static_cast<uint64_t>(elem);
+      if (sizeof(T) == 64 / 8) {
+        os << "ULL";
+      }
+    }
+    if (i < num_elements - 1) {
+      os << ", ";
+    }
+  }
+
+  if ((num_elements % elements_per_row) != 0) {
+    os << "\n";
+  }
+}
+
+template <typename T, typename Enable = std::enable_if_t<std::is_floating_point<T>::value>>
+void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, std::ostream& os) {
+  // Floats and doubles are printed as hex but casted.
+  int one_element_size_bytes = (sizeof(T) / 4) + (2 /* "0x" */) + (2 /* ", " */) + 1 /* sign */ +
+                               1 /* decimal point */ + 1 /* exponent sign */;
+  if (sizeof(T) == 64 / 8) {
+    one_element_size_bytes += 2; /* 4 decimal digits in exponent, relative to bits / 4 */
+  } else if (sizeof(T) == 32 / 8) {
+    one_element_size_bytes += 1; /* extra decimal digit in exponent, relative to bits / 4 */
+  }
+
+  int elements_per_row = ComputeNumElementsPerRow(one_element_size_bytes, indent_chars);
+  std::string indent_str(indent_chars, ' ');
+
+  std::stringstream ss;
+  if (std::is_signed<T>::value) {
+    ss.setf(std::ios::hex | std::ios::showbase | std::ios::fixed | std::ios::scientific,
+            std::ios::basefield | std::ios::showbase | std::ios::floatfield);
+  } else {
+    ss.setf(std::ios::hex | std::ios::fixed | std::ios::scientific,
+            std::ios::basefield | std::ios::showbase | std::ios::floatfield);
+  }
+  for (size_t i = 0; i < num_elements; i++) {
+    if ((i % elements_per_row) == 0) {
+      if (i != 0) {
+        os << std::endl;
+      }
+      os << indent_str;
+    }
+
+    T elem = static_cast<T*>(data)[i];
+    if (std::isinf(elem)) {
+      // C99 standard.
+      os << (elem < 0 ? "-" : " ") << std::setw(one_element_size_bytes - 1) << "INFINITY";
+    } else if (std::isnan(elem)) {
+      // GNU extension, implementation-dependent.
+ os << std::setw(one_element_size_bytes) << "NAN"; + } else { + ss << elem; + os << std::setw(one_element_size_bytes) << ss.str(); + ss.str(""); + } + if (i < num_elements - 1) { + os << ", "; + } + } + + if ((num_elements % elements_per_row) != 0) { + os << "\n"; + } +} + +void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os) { + auto arr_type = arr.DataType(); + CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw " + << arr_type.lanes(); + + auto shape = arr.Shape(); + int num_elements = 1; + for (auto shape_elem : shape) { + num_elements *= shape_elem; + } + + auto old_fmtflags = os.flags(); + os.setf(std::ios::internal | std::ios::hex, + std::ios::adjustfield | std::ios::basefield | std::ios::showbase); + os.fill('0'); + switch (arr_type.code()) { + case runtime::DataType::kInt: + CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || + arr_type.bits() == 64) + << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw " + << arr_type.bits() << "-bit array"; + if (arr_type.bits() == 8) { + PrintIntegralArray(arr->data, num_elements, indent_chars, os); + } else if (arr_type.bits() == 16) { + PrintIntegralArray(arr->data, num_elements, indent_chars, os); + } else if (arr_type.bits() == 32) { + PrintIntegralArray(arr->data, num_elements, indent_chars, os); + } else if (arr_type.bits() == 64) { + PrintIntegralArray(arr->data, num_elements, indent_chars, os); + } else { + CHECK(false) << "should not get here"; + } + break; + + case runtime::DataType::TypeCode::kUInt: + CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || + arr_type.bits() == 64) + << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw " + << arr_type.bits() << "-bit array"; + + if (arr_type.bits() == 8) { + PrintIntegralArray(arr->data, num_elements, indent_chars, os); + } else if (arr_type.bits() == 16) { + PrintIntegralArray(arr->data, num_elements, indent_chars, os); + } else if (arr_type.bits() == 32) { + PrintIntegralArray(arr->data, num_elements, indent_chars, os); + } else if (arr_type.bits() == 64) { + PrintIntegralArray(arr->data, num_elements, indent_chars, os); + } else { + CHECK(false) << "should not get here"; + } + break; + + case runtime::DataType::TypeCode::kFloat: { + os.fill(' '); + os.setf(std::ios::left, std::ios::adjustfield); + if (arr_type.bits() == 16) { + // NOTE: print types not widely supported by C as uint16_t. + PrintIntegralArray(arr->data, num_elements, indent_chars, os); + } else if (arr_type.bits() == 32) { + PrintFloatingPointArray(arr->data, num_elements, indent_chars, os); + } else if (arr_type.bits() == 64) { + PrintFloatingPointArray(arr->data, num_elements, indent_chars, os); + } else { + CHECK(false) << "CodegenParams: only support 32- or 64-bit floating point; saw " + << arr_type.bits() << "-bit array"; + } + break; + } + + case runtime::DataType::TypeCode::kBFloat: { + // NOTE: print types not widely supported by C as uint16_t. 
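The reinterpretation described by the note above can be reproduced from Python; a small sketch with invented values:

    import numpy as np

    # float16 payloads are emitted as uint16 words, bit for bit.
    halfs = np.array([1.0, -2.5, 0.125], dtype=np.float16)
    words = halfs.view(np.uint16)
    print(", ".join(f"0x{w:04x}" for w in words))  # 0x3c00, 0xc100, 0x3000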
+      CHECK(arr_type.bits() == 16)
+          << "CodegenParams: only support generating 16-bit bfloat params; saw " << arr_type.bits()
+          << "-bit array";
+      PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os);
+      break;
+    }
+
+    default:
+      CHECK(false) << "Data type not supported";
+  }
+
+  os.flags(old_fmtflags);
+}
+
+}  // namespace codegen
+}  // namespace tvm
diff --git a/src/target/source/codegen_params.h b/src/target/source/codegen_params.h
new file mode 100644
index 000000000000..cc126c767c58
--- /dev/null
+++ b/src/target/source/codegen_params.h
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file codegen_params.h
+ */
+
+#ifndef TVM_TARGET_SOURCE_CODEGEN_PARAMS_H_
+#define TVM_TARGET_SOURCE_CODEGEN_PARAMS_H_
+
+#include 
+
+#include 
+
+namespace tvm {
+namespace codegen {
+
+/*!
+ * \brief Write a C representation of arr to os.
+ *
+ * This function generates a comma-separated, indented list of C integer literals suitable for use
+ * in an initializer. The NDArray is flattened and then the list is produced element by element.
+ * For the int16_t NDArray [-3, -2, -1, 0, 1, 2, 3, ...], and indent_chars = 4, the following
+ * output is produced:
+ *     -0x0003, -0x0002, -0x0001, +0x0000, +0x0001, +0x0002, +0x0003
+ *
+ * \param arr The array to generate
+ * \param indent_chars Number of chars to indent
+ * \param os Output stream where the array data should be written.
+ */ +void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os); + +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_SOURCE_CODEGEN_PARAMS_H_ diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 6bef8b3c5cd7..903c3dcfefb5 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -213,10 +213,12 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mfloat-abi") .add_attr_option("system-lib") .add_attr_option("runtime") + .add_attr_option("link-params", Bool(false)) .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("system-lib") + .add_attr_option("link-params", Bool(false)) .add_attr_option("runtime") .add_attr_option("mcpu") .add_attr_option("march") diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index ef7f4f8e16dd..101d80a52ea1 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -28,6 +28,13 @@ namespace tvm { namespace tir { +LinkedParam::LinkedParam(int64_t id, ::tvm::runtime::NDArray param) { + auto n = make_object(); + n->id = id; + n->param = param; + data_ = std::move(n); +} + // Get the function type of a PrimFunc PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, Map buffer_map, DictAttrs attrs, Span span) { diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 3d528f821059..a422f12b04d7 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -147,8 +147,9 @@ TEST(TargetCreation, DeduplicateKeys) { ICHECK_EQ(target->keys.size(), 2U); ICHECK_EQ(target->keys[0], "cpu"); ICHECK_EQ(target->keys[1], "arm_cpu"); - ICHECK_EQ(target->attrs.size(), 1U); + ICHECK_EQ(target->attrs.size(), 2U); ICHECK_EQ(target->GetAttr("device"), "arm_cpu"); + ICHECK_EQ(target->GetAttr("link-params"), false); } int main(int argc, char** argv) { diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 3b5471d0bb8b..3d6923342652 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -28,6 +28,7 @@ import tvm import tvm.relay +import tvm.testing from tvm.topi.utils import get_const_tuple from tvm.topi.testing import conv2d_nchw_python diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py new file mode 100644 index 000000000000..7b6910b0ea57 --- /dev/null +++ b/tests/python/unittest/test_link_params.py @@ -0,0 +1,408 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
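The tests in this new file all hinge on one round trip: build with `--link-params`, call the module's `_lookup_linked_param` PackedFunc with a storage id, and view the returned pointer as host memory. A condensed sketch of that read-back, assuming `mod`, `sid`, and `shape` for a float32 parameter are in scope:

    import ctypes
    import numpy as np

    # query_imports=True: the loaded artifact may be a GraphRuntimeFactory
    # wrapping the library module rather than the library module itself.
    lookup = mod.get_function("_lookup_linked_param", True)
    param_ptr = lookup(sid)
    n = int(np.prod(shape))
    buf = (ctypes.c_float * n).from_address(param_ptr.value)
    arr = np.ndarray(shape=shape, dtype="float32", buffer=buf, order="C")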
+import collections
+import ctypes
+import json
+import os
+import re
+import struct
+import sys
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.relay
+import tvm.testing
+from tvm.contrib import utils
+
+
+INPUT_SHAPE = (1, 3, 16, 16)
+
+
+KERNEL_SHAPE = (3, 3, 3, 3)
+
+
+# The data types that are linkable.
+LINKABLE_DTYPES = (
+    [f"uint{b}" for b in (8, 16, 32, 64)]
+    + [f"int{b}" for b in (8, 16, 32, 64)]
+    + ["float32", "float64"]
+)
+
+
+def dtype_info(dtype):
+    """Lookup numpy type info for the given string dtype (of LINKABLE_DTYPES above)."""
+    if "int" in dtype:
+        return np.iinfo(getattr(np, dtype))
+    else:
+        return np.finfo(getattr(np, dtype))
+
+
+# Note: for debugging, set this to an integer (e.g. 1). Then all "random" tensors will become
+# predictable.
+RANDOM_TENSOR_START = None
+
+
+def _make_random_tensor(dtype, shape):
+    """Create a random test tensor with given shape and dtype."""
+    global RANDOM_TENSOR_START
+    if RANDOM_TENSOR_START is not None:
+        to_return = np.arange(
+            RANDOM_TENSOR_START, RANDOM_TENSOR_START + np.prod(shape), dtype=dtype
+        ).reshape(shape)
+        RANDOM_TENSOR_START += np.prod(shape)
+        return to_return
+
+    dinfo = dtype_info(dtype)
+    if "int" in dtype:
+        return np.random.randint(dinfo.min, dinfo.max, shape, dtype=dtype)
+    else:
+        to_return = np.random.uniform(0, dinfo.max, shape).astype(dtype)
+        np.reshape(to_return, np.prod(shape))[::2] *= -1
+        return to_return
+
+
+def _lookup_sid(graph, name):
+    """Lookup the storage id of a named parameter.
+
+    Arguments
+    ---------
+    graph : dict
+        Parsed JSON graph.
+
+    name : str
+        Name of the tensor parameter to lookup.
+
+    Returns
+    -------
+    int :
+        The storage_id of the parameter.
+    """
+    num_outputs_seen = 0
+    for i, n in enumerate(graph["nodes"]):
+        if n["name"] == name:
+            print("sid", name, graph["attrs"]["storage_id"][1], num_outputs_seen)
+            return graph["attrs"]["storage_id"][1][num_outputs_seen]
+        else:
+            if "attrs" in n and "num_outputs" in n["attrs"]:
+                num_outputs_seen += int(n["attrs"]["num_outputs"])
+            else:
+                num_outputs_seen += 1
+
+    raise KeyError(f"no such param: {name}")
+
+
+def _get_ctypes_dtype(dt):
+    """Return a ctypes c_* datatype given a string data type."""
+    if "int" in dt:
+        return getattr(ctypes, f"c_{dt}")
+    elif dt == "float32":
+        return ctypes.c_float
+    elif dt == "float64":
+        return ctypes.c_double
+    else:
+        assert False, f"unknown dtype: {dt}"
+
+
+def _verify_linked_param(dtype, lib, mod, graph, name):
+    """Directly read memory from the linked library to verify the linked parameter is correct."""
+    sid = _lookup_sid(graph, name)
+    # NOTE: query_imports=True because when loading a module from disk (i.e. for the C backend),
+    # a GraphRuntimeFactory module is created instead of the module itself.
+    param_ptr = mod.get_function("_lookup_linked_param", True)(sid)
+    gen_param = lib.params[name]
+    arr_data = (_get_ctypes_dtype(dtype) * np.prod(gen_param.shape)).from_address(param_ptr.value)
+    arr = np.ndarray(shape=gen_param.shape, dtype=gen_param.dtype, buffer=arr_data, order="C")
+    if "int" in gen_param.dtype:
+        np.testing.assert_equal(gen_param.asnumpy(), arr)
+    else:
+        np.testing.assert_allclose(gen_param.asnumpy(), arr)
+    return dtype == gen_param.dtype
+
+
+def _make_mod_and_params(dtype):
+    """Create a Relay module and parameters to test the given datatype."""
+    param_decls = collections.OrderedDict()
+    param_init = {}
+
+    def _add_decl(name, dtype):
+        param_decls[name] = f"%{name} : Tensor[{KERNEL_SHAPE}, {dtype}]"
+        param_init[name] = _make_random_tensor(dtype, KERNEL_SHAPE)
+
+    # Add several parameters so that the number of parameters is greater than one.
+    _add_decl(f"{dtype}_a", dtype)
+    _add_decl(f"{dtype}_b", dtype)
+
+    mod_lines = [
+        '#[version = "0.0.5"]',
+        f"def @main(%rand_input : Tensor[{INPUT_SHAPE}, {dtype}], { ', '.join(param_decls.values()) } ) {{",
+        # This program ensures that GraphPlanMemory alternates between the same two storage IDs for
+        # a while. In doing this, it ensures that param %{dtype}_b will be placed into the graph at
+        # an index unequal to its storage_id. This ensures that GraphRuntimeCodegen encodes the
+        # storage_id and not the parameter index into the graph.
+        (
+            f'  %0 = nn.conv2d(%rand_input, %{dtype}_a, data_layout="NCHW", kernel_layout="OIHW", '
+            f'kernel_size=[3, 3], out_dtype="{dtype}");'
+        ),
+        (
+            f'  %1 = nn.conv2d(%0, %{dtype}_a, data_layout="NCHW", kernel_layout="OIHW", '
+            f'kernel_size=[3, 3], out_dtype="{dtype}");'
+        ),
+        (
+            f'  %2 = nn.conv2d(%1, %{dtype}_a, data_layout="NCHW", kernel_layout="OIHW", '
+            f'kernel_size=[3, 3], out_dtype="{dtype}");'
+        ),
+        (
+            f'  %3 = nn.conv2d(%2, %{dtype}_b, data_layout="NCHW", kernel_layout="OIHW", '
+            f'kernel_size=[3, 3], out_dtype="{dtype}");'
+        ),
+        "  %3",
+        "}",
+    ]
+
+    mod = tvm.parser.fromtext("\n".join(mod_lines))
+    return mod, param_init
+
+
+@tvm.testing.requires_llvm
+def test_llvm_link_params():
+    for dtype in LINKABLE_DTYPES:
+        mod, param_init = _make_mod_and_params(dtype)
+        rand_input = _make_random_tensor(dtype, INPUT_SHAPE)
+        main_func = mod["main"]
+        target = "llvm --runtime=c --system-lib --link-params"
+        with tvm.transform.PassContext(opt_level=3):
+            lib = tvm.relay.build(mod, target, params=param_init)
+            assert set(lib.params.keys()) == {"p0", "p1"}  # NOTE: op folded
+
+            print("graph", lib.graph_json)
+            graph = json.loads(lib.graph_json)
+            for p in lib.params:
+                assert _verify_linked_param(dtype, lib, lib.lib, graph, p)
+
+        # Wrap in function to explicitly deallocate the runtime.
+        def _run_linked(lib):
+            graph_json, mod, _ = lib
+            graph_rt = tvm.contrib.graph_runtime.create(graph_json, mod, tvm.cpu(0))
+            graph_rt.set_input("rand_input", rand_input)  # NOTE: params not required.
+ graph_rt.run() + return graph_rt.get_output(0) + + linked_output = _run_linked(lib) + + with tvm.transform.PassContext(opt_level=3): + lib = tvm.relay.build(mod, "llvm --system-lib", params=param_init) + + def _run_unlinked(lib): + graph_json, mod, lowered_params = lib + graph_rt = tvm.contrib.graph_runtime.create(graph_json, mod, tvm.cpu(0)) + graph_rt.set_input("rand_input", rand_input, **lowered_params) + graph_rt.run() + return graph_rt.get_output(0) + + unlinked_output = _run_unlinked(lib) + + if "int" in dtype: + np.testing.assert_equal(unlinked_output.asnumpy(), linked_output.asnumpy()) + else: + np.testing.assert_allclose(unlinked_output.asnumpy(), linked_output.asnumpy()) + + +def _get_c_datatype(dtype): + """Translate LINKABLE_DTYPES element to c datatype.""" + if "int" in dtype: + return f"{dtype}_t" + elif dtype == "float32": + return "float" + elif dtype == "float64": + return "double" + else: + assert False, f"unknown dtype {dtype}" + + +def _format_c_value(dtype, width, x): + if "int" in dtype: + hex_formatstr = f'{{:{"+" if dtype.startswith("int") else ""}#0{width}x}}' + return hex_formatstr.format(x) + elif "float" in dtype: + to_ret = float(x).hex() + if "inf" in to_ret: + return ("-" if x < 0 else "") + "INFINITY" + elif "nan" in to_ret: + return "NAN" + + before, after = to_ret.split("p") + return f'{before.rstrip("0")}p{after}' + else: + assert False, f"don't know dtype {dtype}" + + +HEX_NUM_RE = re.compile(r"[+\-]?(?:(?:0x[0-9A-Fa-f.p+-]+)|(?:INFINITY)|(?:NAN))") + + +def test_c_link_params(): + temp_dir = utils.tempdir() + for dtype in LINKABLE_DTYPES: + mod, param_init = _make_mod_and_params(dtype) + rand_input = _make_random_tensor(dtype, INPUT_SHAPE) + main_func = mod["main"] + target = "c --link-params" + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + lib = tvm.relay.build(mod, target, params=param_init) + assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded + + src = lib.lib.get_source() + lib.lib.save("test.c", "cc") + c_dtype = _get_c_datatype(dtype) + src_lines = src.split("\n") + param = lib.params["p0"].asnumpy().reshape(np.prod(KERNEL_SHAPE)) + param_def = f"static const {c_dtype} __tvm_param__p0[{np.prod(param.shape)}] = {{" + for i, line in enumerate(src_lines): + if line == param_def: + i += 1 + break + else: + assert False, f'did not find parameter definition "{param_def}":\n{src}' + + cursor = 0 + width = dtype_info(dtype).bits // 4 + 2 + if dtype.startswith("int"): + width += 1 # Account for sign + + while "};" not in src_lines[i]: + for match in HEX_NUM_RE.finditer(src_lines[i]): + assert match.group() == _format_c_value(dtype, width, param[cursor]), ( + f'p0 byte {cursor}: want "{_format_c_value(dtype, width, param[cursor])}" got ' + f'"{match.group(0)}"; full p0 follows:\n{src}' + ) + cursor += 1 + i += 1 + + assert cursor == np.prod(param.shape) + temp = utils.tempdir() + + # Need a unique name per library to avoid dlopen caching the lib load. + lib_path = temp_dir.relpath(f"test-{dtype}-linked.so") + lib["remove_params"]().export_library(lib_path) + lib_mod = tvm.runtime.load_module(lib_path) + + # lib_mod = lib_factory['default']() + graph = json.loads(lib.graph_json) + for p in lib.params: + _verify_linked_param(dtype, lib, lib_mod, graph, p) + + # Wrap in function to explicitly deallocate the runtime. + def _run_linked(lib_mod): + graph_rt = tvm.contrib.graph_runtime.GraphModule(lib_mod["default"](tvm.cpu(0))) + graph_rt.set_input("rand_input", rand_input) # NOTE: params not required. 
+ graph_rt.run() + + return graph_rt.get_output(0) + + linked_output = _run_linked(lib_mod) + + linked_params = lib.params + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + lib = tvm.relay.build(mod, "c", params=param_init) + _, _, params = lib + # Need a unique name per library to avoid dlopen caching the lib load. + lib_path = temp_dir.relpath(f"test-{dtype}-unlinked.so") + lib.export_library(lib_path) + lib_mod = tvm.runtime.load_module(lib_path) + + def _run_unlinked(lib_mod): + graph_rt = tvm.contrib.graph_runtime.GraphModule(lib_mod["default"](tvm.cpu(0))) + graph_rt.set_input("rand_input", rand_input, **params) + graph_rt.run() + return graph_rt.get_output(0) + + unlinked_output = _run_unlinked(lib_mod) + + if "int" in dtype: + np.testing.assert_equal(unlinked_output.asnumpy(), linked_output.asnumpy()) + else: + np.testing.assert_allclose(unlinked_output.asnumpy(), linked_output.asnumpy()) + + +@tvm.testing.requires_micro +def test_crt_link_params(): + import tvm.micro + + for dtype in LINKABLE_DTYPES: + mod, param_init = _make_mod_and_params(dtype) + rand_input = _make_random_tensor(dtype, INPUT_SHAPE) + main_func = mod["main"] + target = "c -mcpu=native --system-lib --runtime=c --link-params" + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + graph_json, lib, params = tvm.relay.build(mod, target, params=param_init) + assert set(params.keys()) == {"p0", "p1"} # NOTE: op folded + + workspace = tvm.micro.Workspace() + compiler = tvm.micro.DefaultCompiler(target=target) + opts = tvm.micro.default_options(os.path.join(tvm.micro.CRT_ROOT_DIR, "host")) + opts["bin_opts"]["ldflags"].append("-DTVM_HOST_USE_GRAPH_RUNTIME_MODULE") + + micro_binary = tvm.micro.build_static_runtime( + # the x86 compiler *expects* you to give the exact same dictionary for both + # lib_opts and bin_opts. so the library compiler is mutating lib_opts and + # the binary compiler is expecting those mutations to be in bin_opts. + # TODO(weberlo) fix this very bizarre behavior + workspace, + compiler, + lib, + lib_opts=opts["bin_opts"], + bin_opts=opts["bin_opts"], + extra_libs=[ + os.path.join(tvm.micro.CRT_ROOT_DIR, m) + for m in ("graph_runtime_module", "graph_runtime") + ], + ) + + flasher_kw = { + "debug": False, + } + flasher = compiler.flasher(**flasher_kw) + with tvm.micro.Session(binary=micro_binary, flasher=flasher) as sess: + rpc_lib = sess.get_system_lib() + graph_rt = tvm.contrib.graph_runtime.create(graph_json, rpc_lib, sess.context) + + # NOTE: not setting params here. 
+ graph_rt.set_input("rand_input", rand_input) + graph_rt.run() + linked_output = graph_rt.get_output(0).asnumpy() + + with tvm.transform.PassContext(opt_level=3): + lib = tvm.relay.build(mod, "llvm --system-lib", params=param_init) + + def _run_unlinked(lib): + graph_json, mod, lowered_params = lib + graph_rt = tvm.contrib.graph_runtime.create(graph_json, mod, tvm.cpu(0)) + graph_rt.set_input("rand_input", rand_input, **lowered_params) + graph_rt.run() + return graph_rt.get_output(0).asnumpy() + + unlinked_output = _run_unlinked(lib) + + if "int" in dtype: + np.testing.assert_equal(unlinked_output, linked_output) + else: + np.testing.assert_allclose(unlinked_output, linked_output) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 3599493a74cb..162481bfdb6e 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -14,11 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import collections +import ctypes +import json import tvm import tvm.testing from tvm import te from tvm import topi -from tvm.contrib import utils, clang +from tvm.contrib import utils import numpy as np import ctypes import math
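As a closing illustration, the parameter arrays that the C backend bakes in can be spotted directly in the generated source. A sketch, assuming `lib` came from tvm.relay.build(mod, "c --link-params", params=...) as in test_c_link_params above:

    src = lib.lib.get_source()
    for line in src.split("\n"):
        if "__tvm_param__" in line:
            print(line)  # e.g. "static const float __tvm_param__p0[81] = {"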