From a01a38ec77ef200a8bc2d3805eff1ce500cd5b03 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Tue, 2 Feb 2021 19:14:25 +0000 Subject: [PATCH] [AOT] Introducing AOT in TVM This change adds the code generation and minimal runtime API to use the Ahead Of Time (AOT) compilation flow. The main logic is contained in: - src/relay/backend/aot_codegen.cc Which produces a TIR PrimFunc traversing the Relay graph The runtime interface (authored by @mousius) leaves a gap for future iterations using platform-specific features from RTOS. Currently AOT runs successfully on x86 in a host OS, running these tests on micro is coming soon. This PR is based on the RFC described here: https://discuss.tvm.apache.org/t/implementing-aot-in-tvm/9206 Co-authored-by: Christopher Sidebottom Change-Id: I9f731c953231f129e1472298915dddc01788efd7 --- cmake/modules/StandaloneCrt.cmake | 7 +- include/tvm/runtime/crt/aot/tvm_backend.h | 104 +++ include/tvm/runtime/crt/aot/tvm_error.h | 68 ++ include/tvm/runtime/crt/aot/tvm_executor.h | 97 +++ include/tvm/runtime/module.h | 2 + include/tvm/tir/builtin.h | 4 + python/tvm/driver/tvmc/compiler.py | 2 +- python/tvm/micro/model_library_format.py | 33 +- .../relay/backend/graph_runtime_codegen.py | 2 +- .../relay/backend/graph_runtime_factory.py | 15 +- python/tvm/relay/build_module.py | 18 +- src/relay/backend/aot_codegen.cc | 674 ++++++++++++++++++ src/relay/backend/build_module.cc | 69 +- src/relay/backend/graph_plan_memory.cc | 4 +- src/relay/backend/graph_runtime_codegen.cc | 4 +- src/runtime/crt/aot/tvm_executor.c | 91 +++ src/runtime/crt/graph_runtime/graph_runtime.c | 1 + src/runtime/meta_data.h | 32 + src/target/metadata_module.cc | 5 +- src/target/metadata_module.h | 5 +- src/target/source/codegen_c_host.cc | 69 +- src/target/source/codegen_c_host.h | 5 +- src/target/source/codegen_source_base.h | 7 +- src/target/source/source_module.cc | 32 +- src/target/source/source_module.h | 5 +- src/target/target_kind.cc | 4 +- src/tir/op/builtin.cc | 3 + src/tir/transforms/lower_tvm_builtin.cc | 2 +- tests/cpp/relay_build_module_test.cc | 2 +- tests/cpp/utvm_runtime_standalone_test.cc | 2 +- tests/crt/aot_executor_test.cc | 199 ++++++ tests/crt/aot_memory_test.cc | 105 +++ tests/python/relay/aot/aot_test.mk | 71 ++ tests/python/relay/aot/infra.py | 213 ++++++ tests/python/relay/aot/test_crt_aot.py | 258 +++++++ .../relay/test_backend_graph_runtime.py | 2 +- tests/python/relay/test_pass_annotation.py | 2 +- tests/python/unittest/test_crt.py | 2 +- tests/python/unittest/test_link_params.py | 4 +- .../test_micro_model_library_format.py | 2 +- .../test_runtime_module_based_interface.py | 2 +- 41 files changed, 2140 insertions(+), 88 deletions(-) create mode 100644 include/tvm/runtime/crt/aot/tvm_backend.h create mode 100644 include/tvm/runtime/crt/aot/tvm_error.h create mode 100644 include/tvm/runtime/crt/aot/tvm_executor.h create mode 100644 src/relay/backend/aot_codegen.cc create mode 100644 src/runtime/crt/aot/tvm_executor.c create mode 100644 tests/crt/aot_executor_test.cc create mode 100644 tests/crt/aot_memory_test.cc create mode 100644 tests/python/relay/aot/aot_test.mk create mode 100644 tests/python/relay/aot/infra.py create mode 100644 tests/python/relay/aot/test_crt_aot.py diff --git a/cmake/modules/StandaloneCrt.cmake b/cmake/modules/StandaloneCrt.cmake index dc1b3b2665f26..ff2d3e1745967 100644 --- a/cmake/modules/StandaloneCrt.cmake +++ b/cmake/modules/StandaloneCrt.cmake @@ -40,6 +40,7 @@ if(USE_MICRO) "3rdparty/dmlc-core/include *.h -> include" "include/tvm/runtime 
c_*_api.h -> include/tvm/runtime" "include/tvm/runtime/crt *.h -> include/tvm/runtime/crt" + "include/tvm/runtime/crt/aot *.h -> src/runtime/crt/aot" "src/runtime/crt Makefile -> ." "src/runtime/crt/include *.h -> include" "src/runtime/crt/common *.c -> src/runtime/crt/common" @@ -48,6 +49,7 @@ if(USE_MICRO) "src/runtime/crt/host crt_config.h -> template/host" "src/runtime/crt/host *.cc -> template/host" "src/runtime/crt/memory *.c -> src/runtime/crt/memory" + "src/runtime/crt/aot *.c -> src/runtime/crt/aot" "src/runtime/crt/utvm_rpc_common *.cc -> src/runtime/crt/utvm_rpc_common" "src/runtime/crt/utvm_rpc_server *.cc -> src/runtime/crt/utvm_rpc_server" "src/runtime/minrpc *.h -> src/runtime/minrpc" @@ -135,6 +137,7 @@ if(USE_MICRO) file(GLOB TEST_SRCS ${CMAKE_SOURCE_DIR}/tests/crt/*_test.cc) find_path(GTEST_INCLUDE_DIR gtest/gtest.h) find_library(GTEST_LIB gtest "$ENV{GTEST_LIB}") + set(aot_executor_src "${standalone_crt_base}/src/runtime/crt/aot/tvm_executor.c") # Create the `crttest` target if we can find GTest. If not, we create dummy # targets that give the user an informative error message. @@ -144,7 +147,9 @@ if(USE_MICRO) string(REPLACE ".cc" "" __execname ${__srcname}) add_executable(${__execname} ${__srcpath}) list(APPEND TEST_EXECS ${__execname}) - target_include_directories(${__execname} PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_SOURCE_DIR}/src/runtime/crt/host) + target_sources(${__execname} PRIVATE ${aot_executor_src}) + target_include_directories(${__execname} PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_SOURCE_DIR}/src/runtime/crt/host) + target_include_directories(${__execname} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/src/runtime/crt/aot) target_compile_options(${__execname} PRIVATE -pthread) target_link_libraries(${__execname} ${cmake_crt_libraries} ${GTEST_LIB} pthread) set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1) diff --git a/include/tvm/runtime/crt/aot/tvm_backend.h b/include/tvm/runtime/crt/aot/tvm_backend.h new file mode 100644 index 0000000000000..1875cea10a6b8 --- /dev/null +++ b/include/tvm/runtime/crt/aot/tvm_backend.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file include/tvm/runtime/crt/aot/tvm_backend.h + * \brief Backend functions for the AOT executor + * + * These are not designed to user-facing and may change without warning + */ + +#ifndef TVM_RUNTIME_CRT_AOT_TVM_BACKEND_H_ +#define TVM_RUNTIME_CRT_AOT_TVM_BACKEND_H_ + +#include +#include + +#include "tvm_error.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! Memory alignment for allocator */ +#ifndef TVM_RUNTIME_ALLOC_ALIGNMENT +#define TVM_RUNTIME_ALLOC_ALIGNMENT 16 +#endif + +/*! 
The AOT runtime links staticly */ +#define TVM_DLL + +/*! + * \brief Minimal TVMValue + */ +typedef union { + int64_t v_int64; /** Currently used for parameter lookup */ + void* v_handle; /** Pointer to other values */ +} TVMValue; + +/*! + * \brief Packed function signature definition + */ +typedef int32_t(tvm_function_t)(void* args, void* arg_type_ids, int32_t num_args, + void* out_ret_value, void* out_ret_tcode, void* resource_handle); + +/*! + * \brief Workspace memory structure + */ +typedef struct { + uint8_t* next_alloc; /** Pointer to the next block of bytes to allocate */ + uint8_t* workspace; /** Pointer to start of the workspace */ + size_t workspace_size; /** Total number of bytes in the workspace */ +} tvm_workspace_t; + +/** + * \brief Backend function to allocate temporal workspace. + * + * \note The result allocated space is ensured to be aligned to TVM_RUNTIME_ALLOC_ALIGNMENT. + * \note Currently matches CRT runtime signature but this will change in future to accommodate + * memory planning + * + * \param device_type Ignored + * \param device_id Ignored + * \param nbytes The size of the space requested. + * \param dtype_code_hint Ignored + * \param dtype_bits_hint Ignored + * \return void* NULL on error, a valid pointer on success + */ +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint, + int dtype_bits_hint); + +/*! + * \brief Backend function to free temporal workspace. + * + * \note Currently matches CRT runtime signature but this will change in future to accomodate memory + * planning + * + * \param ptr The result allocated space pointer. + * \param device_type Ignored + * \param device_id Ignored + * \return tvm_crt_error_t Containing any error statuses + */ +tvm_crt_error_t TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TVM_RUNTIME_CRT_AOT_TVM_BACKEND_H_ diff --git a/include/tvm/runtime/crt/aot/tvm_error.h b/include/tvm/runtime/crt/aot/tvm_error.h new file mode 100644 index 0000000000000..4b90c1afd9fe6 --- /dev/null +++ b/include/tvm/runtime/crt/aot/tvm_error.h @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file include/tvm/runtime/crt/aot/tvm_error.h + * \brief Defines a subset of error codes returned by the CRT AOT executor. 
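Before moving on to the error codes: TVMBackendAllocWorkspace above is a simple bump allocator over a caller-provided tvm_workspace_t. A minimal sketch of how an application might wire it up; the buffer size, names and the init helper are illustrative, and the assumption that the application reaches into the tvm_runtime_workspace global (defined later in src/runtime/crt/aot/tvm_executor.c) reflects the current test setup, not a stable API.

    #include <stdint.h>
    #include "tvm_backend.h"

    /* Defined in src/runtime/crt/aot/tvm_executor.c; the application points it at its
       own workspace before running the model. */
    extern tvm_workspace_t* tvm_runtime_workspace;

    /* Illustrative size and names only. */
    #define APP_WORKSPACE_SIZE (16 * 1024)
    static uint8_t app_workspace_buffer[APP_WORKSPACE_SIZE]
        __attribute__((aligned(TVM_RUNTIME_ALLOC_ALIGNMENT)));
    tvm_workspace_t app_workspace;

    void app_workspace_init(void) {
      app_workspace.workspace = app_workspace_buffer;
      app_workspace.next_alloc = app_workspace_buffer;
      app_workspace.workspace_size = APP_WORKSPACE_SIZE;
      tvm_runtime_workspace = &app_workspace;
    }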
+ */ + +#ifndef TVM_RUNTIME_CRT_AOT_TVM_ERROR_H_ +#define TVM_RUNTIME_CRT_AOT_TVM_ERROR_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +#define TVM_CRT_ERROR_CATEGORY_Pos 8 +#define TVM_CRT_ERROR_CATEGORY_Msk (0xff << TVM_CRT_ERROR_CATEGORY_Pos) +#define TVM_CRT_ERROR_CODE_Pos 0 +#define TVM_CRT_ERROR_CODE_Msk (0xff << TVM_CRT_ERROR_CODE_Pos) + +#define DEFINE_TVM_CRT_ERROR(category, code) \ + (((category) << TVM_CRT_ERROR_CATEGORY_Pos) | ((code) << TVM_CRT_ERROR_CODE_Pos)) +typedef enum { + kTvmErrorCategoryPlatform = 5, + kTvmErrorCategoryFunctionCall = 8, +} tvm_crt_error_category_t; + +typedef enum { + kTvmErrorNoError = 0, + + // Platform + kTvmErrorPlatformCheckFailure = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 0), + kTvmErrorPlatformMemoryManagerInitialized = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 1), + kTvmErrorPlatformShutdown = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 2), + kTvmErrorPlatformNoMemory = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 3), + kTvmErrorPlatformTimerBadState = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 4), + + // Function Calls - common problems encountered calling functions. + kTvmErrorFunctionCallNumArguments = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 0), + kTvmErrorFunctionCallWrongArgType = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 1), + kTvmErrorFunctionCallNotImplemented = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 2), + + // System errors are always negative integers; this mask indicates presence of a system error. + // Cast tvm_crt_error_t to a signed integer to interpret the negative error code. + kTvmErrorSystemErrorMask = (1 << (sizeof(int) * 4 - 1)), +} tvm_crt_error_t; + +#ifdef __cplusplus +} +#endif + +#endif // TVM_RUNTIME_CRT_AOT_TVM_ERROR_H_ diff --git a/include/tvm/runtime/crt/aot/tvm_executor.h b/include/tvm/runtime/crt/aot/tvm_executor.h new file mode 100644 index 0000000000000..efa5e7b06750f --- /dev/null +++ b/include/tvm/runtime/crt/aot/tvm_executor.h @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file include/tvm/runtime/crt/aot/tvm_executor.h + * \brief TVM Executor for the Ahead-of-Time Runtime + * + * AOT models are described by the TVM model descriptor format + * which can be passed to tvm_runtime_run. These descriptors will be + * generated by the AOT compilation process. This can optionally be + * augmented with platform specific context to be passed to the TVM + * operators. 
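The error values defined in tvm_error.h above encode a category in bits 8-15 and a code in bits 0-7. A small sketch of how a caller can decode them (the function name is illustrative):

    #include <assert.h>
    #include "tvm_error.h"

    /* kTvmErrorPlatformNoMemory == DEFINE_TVM_CRT_ERROR(5, 3) == (5 << 8) | 3 == 0x503 */
    static void check_error_layout(void) {
      tvm_crt_error_t err = kTvmErrorPlatformNoMemory;
      int category = (err & TVM_CRT_ERROR_CATEGORY_Msk) >> TVM_CRT_ERROR_CATEGORY_Pos;  /* 5 */
      int code = (err & TVM_CRT_ERROR_CODE_Msk) >> TVM_CRT_ERROR_CODE_Pos;              /* 3 */
      assert(category == kTvmErrorCategoryPlatform);
      assert(code == 3);
    }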
+ * + * Example: + * extern tvm_model_t my_network; + * int main() { + * void* data = get_data(); + * void* output[4] = {0, 0, 0, 0}; + * void* inputs = {data}; + * void* outputs = {output}; + * tvm_context_t my_context = { + * .driver = ...; + * }; + * tvm_runtime_run( + * &my_network, + * inputs, + * outputs + * &my_context + * ); + * return 0; + * } + */ + +#ifndef TVM_RUNTIME_CRT_AOT_TVM_EXECUTOR_H_ +#define TVM_RUNTIME_CRT_AOT_TVM_EXECUTOR_H_ + +#include + +#include "tvm_backend.h" +#include "tvm_error.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! + * \brief Context information for future integrations + * which is passed through to the operators. + * + * \note Can be used for drivers and platform specific information. + */ +typedef struct { +} tvm_context_t; + +/*! + * \brief TVM Model descriptor to describe the + * model to the runtime. + */ +typedef struct { + uint32_t num_input_tensors; /** Number of expected input tensors */ + uint32_t num_output_tensors; /** Number of expected output tensors */ + tvm_function_t* run_func; /** Generated model function, called through tvm_runtime_run */ + tvm_workspace_t* workspace; /** Memory workspace for the model to use */ +} tvm_model_t; + +/*! + * \brief Main entry point for + * \param model Model descriptor structure to reference for runtime information + * \param inputs Pointer to input pointer(s) + * \param outputs Pointer to output pointer(s) + * \param context Context information to be passed through to operators + * \return tvm_status_t containing success or errors from the model run + */ +tvm_crt_error_t tvm_runtime_run(const tvm_model_t* model, void** inputs, void** outputs, + tvm_context_t* context); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TVM_RUNTIME_CRT_AOT_TVM_EXECUTOR_H_ diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 04a5cf8bf25d7..689fe6fa53fce 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -230,6 +230,8 @@ constexpr const char* tvm_module_main = "__tvm_main__"; constexpr const char* tvm_param_prefix = "__tvm_param__"; /*! \brief A PackedFunc that looks up linked parameters by storage_id. */ constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param"; +/*! \brief The main AOT executor function */ +constexpr const char* tvm_run_func_prefix = "tvm__run_func"; } // namespace symbol // implementations of inline functions. diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 6a40d86b89848..e2920585a5d3e 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -343,6 +343,10 @@ TVM_DLL const Op& tvm_stack_make_array(); */ TVM_DLL const Op& tvm_call_packed(); +// This achieve the same of a packed call, but with an extern call +// directly to the operator +TVM_DLL const Op& tvm_call_unpacked(); + /*! 
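For reference, a compilable variant of the usage example from tvm_executor.h above (the in-header snippet is schematic: the inputs/outputs lists need to be arrays, and tvm_context_t currently has no members). The network symbol, buffer sizes and the workspace init helper are placeholders, not names generated by this patch.

    #include "tvm_executor.h"

    #define INPUT_SIZE  64   /* placeholder: actual sizes depend on the compiled model */
    #define OUTPUT_SIZE 64

    extern tvm_model_t my_network;        /* descriptor produced by the AOT build */
    extern void app_workspace_init(void); /* e.g. the workspace setup sketched earlier */

    int main(void) {
      static uint8_t input_data[INPUT_SIZE];
      static uint8_t output_data[OUTPUT_SIZE];
      void* inputs[1] = {input_data};
      void* outputs[1] = {output_data};
      tvm_context_t context;  /* currently empty; reserved for platform/driver data */

      app_workspace_init();
      tvm_crt_error_t err = tvm_runtime_run(&my_network, inputs, outputs, &context);
      return err == kTvmErrorNoError ? 0 : 1;
    }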
* \brief See pesudo code * diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 83791e50f6d5a..2c2c4960a04ef 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -241,7 +241,7 @@ def compile_model( # TODO we need to update this return to use the updated graph module APIs # as these getter functions will be deprecated in the next release (@leandron) - return graph_module.get_json(), graph_module.get_lib(), graph_module.get_params(), dumps + return graph_module.get_graph(), graph_module.get_lib(), graph_module.get_params(), dumps def save_module(module_path, graph, lib, params, cross=None): diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 4ce80be647c1b..9eac962ae2207 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -73,7 +73,7 @@ def _populate_codegen_dir(mod, codegen_dir: str): dso_mod.save(file_name) -def _build_memory_map(graph_json): +def _build_memory_map(graph_str): """Build a simpler memory map from graph JSON. Parameters @@ -86,10 +86,13 @@ def _build_memory_map(graph_json): list : A list with one entry per storage id describing that memory. """ - graph = json.loads(graph_json) + memory_map = [] + if graph_str.startswith("primfn"): + return memory_map + + graph = json.loads(graph_str) seen_storage_ids = set() - memory_map = [] for node_id, storage_id in enumerate(graph["attrs"]["storage_id"][1]): if storage_id in seen_storage_ids: continue @@ -132,14 +135,25 @@ def export_model_library_format(mod: graph_runtime_factory.GraphRuntimeFactoryMo Path to the .tar archive to generate. """ tempdir = utils.tempdir() + is_aot = False + for v in mod.target.values(): + if v.attrs.get("executor", "graph_runtime") == "aot": + is_aot = True + break + + runtime = ["graph"] + if is_aot: + runtime = ["aot"] + metadata = { "version": 1, "model_name": mod.libmod_name, "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), - "memory": _build_memory_map(mod.graph_json), + "memory": _build_memory_map(mod.graph), "target": {int(k): str(v) for k, v in mod.target.items()}, - "runtimes": ["graph"], + "runtimes": runtime, } + with open(tempdir.relpath("metadata.json"), "w") as json_f: json.dump(metadata, json_f, indent=2, sort_keys=True) @@ -156,10 +170,11 @@ def export_model_library_format(mod: graph_runtime_factory.GraphRuntimeFactoryMo with open(tempdir.relpath("relay.txt"), "w") as f: f.write(str(mod.ir_mod)) - graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph")) - os.makedirs(graph_config_dir_path) - with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f: - f.write(mod.graph_json) + if not is_aot: + graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph")) + os.makedirs(graph_config_dir_path) + with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f: + f.write(mod.graph) with tarfile.open(file_name, "w") as tar_f: diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index 81ab4cb4de250..6581bda27c854 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -46,7 +46,7 @@ def __init__(self, mod, target): self._mod = _build_module._GraphRuntimeCodegen() self._init = self._mod["init"] self._codegen = self._mod["codegen"] - self._get_graph_json = self._mod["get_graph_json"] + self._get_graph_json = 
self._mod["get_graph"] self._list_params_name = self._mod["list_params_name"] self._get_param_by_name = self._mod["get_param_by_name"] self._get_irmodule = self._mod["get_irmodule"] diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py index e92ae710ca0b0..6a37c5d202b7e 100644 --- a/python/tvm/relay/backend/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -41,17 +41,18 @@ class GraphRuntimeFactoryModule: The parameters of module """ - def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name, params): - assert isinstance(graph_json_str, string_types) + def __init__(self, ir_mod, target, graph_str, libmod, libmod_name, params): + assert isinstance(graph_str, string_types) fcreate = get_global_func("tvm.graph_runtime_factory.create") args = [] for k, v in params.items(): args.append(k) args.append(ndarray.array(v)) + self.ir_mod = ir_mod self.target = target - self.module = fcreate(graph_json_str, libmod, libmod_name, *args) - self.graph_json = graph_json_str + self.module = fcreate(graph_str, libmod, libmod_name, *args) + self.graph = graph_str self.lib = libmod self.libmod_name = libmod_name self.params = params @@ -66,8 +67,8 @@ def export_library(self, file_name, fcompile=None, addons=None, **kwargs): def get_params(self): return self.params - def get_json(self): - return self.graph_json + def get_graph(self): + return self.graph def get_lib(self): return self.lib @@ -90,7 +91,7 @@ def __next__(self): if self.iter_cnt > 2: raise StopIteration - objs = [self.graph_json, self.lib, self.params] + objs = [self.graph, self.lib, self.params] obj = objs[self.iter_cnt] self.iter_cnt += 1 return obj diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 8e69d288df12d..9894db9c16f80 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -76,7 +76,7 @@ class BuildModule(object): def __init__(self): self.mod = _build_module._BuildModule() - self._get_graph_json = self.mod["get_graph_json"] + self._get_graph = self.mod["get_graph"] self._get_module = self.mod["get_module"] self._build = self.mod["build"] self._optimize = self.mod["optimize"] @@ -133,11 +133,11 @@ def build(self, mod, target=None, target_host=None, params=None): autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent # Get artifacts - graph_json = self.get_json() + graph = self.get_graph() mod = self.get_module() params = self.get_params() - return graph_json, mod, params + return graph, mod, params def optimize(self, mod, target=None, params=None): """ @@ -177,9 +177,9 @@ def optimize(self, mod, target=None, params=None): def _set_params(self, params): self._set_params_func(_convert_param_map(params)) - def get_json(self): + def get_graph(self): """Return the json file of the built program.""" - return self._get_graph_json() + return self._get_graph() def get_module(self): """Return the built module.""" @@ -240,8 +240,8 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" Returns ------- - graph_json : str - The json string that can be accepted by graph runtime. + graph : str + The string representation of the graph mod : tvm.Module The module containing necessary libraries. 
@@ -280,9 +280,9 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" with tophub_context: bld_mod = BuildModule() - graph_json, runtime_mod, params = bld_mod.build(ir_mod, target, target_host, params) + graph, runtime_mod, params = bld_mod.build(ir_mod, target, target_host, params) runtime_mod = _graph_runtime_factory.GraphRuntimeFactoryModule( - ir_mod, target, graph_json, runtime_mod, mod_name, params + ir_mod, target, graph, runtime_mod, mod_name, params ) return runtime_mod diff --git a/src/relay/backend/aot_codegen.cc b/src/relay/backend/aot_codegen.cc new file mode 100644 index 0000000000000..401334ef11cf8 --- /dev/null +++ b/src/relay/backend/aot_codegen.cc @@ -0,0 +1,674 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/graph_codegen.cc + * \brief Graph runtime codegen + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../runtime/meta_data.h" +#include "compile_engine.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace backend { + +using IntegerArray = Array; +using ShapeVector = std::vector>; +using GraphAttrs = std::unordered_map; +using TargetsMap = std::unordered_map; + +/*! \brief Lowered outputs */ +struct AOTLoweredOutput { + std::string graph_tir; + Map lowered_funcs; + Array external_mods; + std::unordered_map> params; + runtime::AOTMetadata aot_metadata; +}; + +class AotReturnSidVisitor : public ExprVisitor { + public: + explicit AotReturnSidVisitor(Map> storage_device_map) + : storage_device_map_{storage_device_map}, return_sid_{-1} {} + + IntegerArray FindReturnSid(Function func) { + VisitExpr(func->body); + return return_sid_; + } + + protected: + void AssignReturnSid(Expr e) { + auto iter = storage_device_map_.find(e); + if (iter != storage_device_map_.end()) { + return_sid_ = (*iter).second[0]; + } + } + + void VisitExpr_(const ConstantNode* cn) override { + ExprVisitor::VisitExpr_(cn); + AssignReturnSid(GetRef(cn)); + } + + void VisitExpr_(const VarNode* vn) override { + ExprVisitor::VisitExpr_(vn); + AssignReturnSid(GetRef(vn)); + } + + void VisitExpr_(const CallNode* cn) override { + ExprVisitor::VisitExpr_(cn); + AssignReturnSid(GetRef(cn)); + } + + void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); } + + void VisitExpr_(const TupleNode* tn) override { + ExprVisitor::VisitExpr_(tn); + AssignReturnSid(GetRef(tn)); + } + + private: + Map> storage_device_map_; + IntegerArray return_sid_; +}; + +using TIRNetwork = tvm::Array; + +/*! \brief Code generator for graph runtime */ +class AOTCodegen : public ExprVisitor { + protected: + /*! 
+ * \brief Utility function to allocate a DLTensor or TVMValue + * \param type the type of allocation + * \param num the number of variable to allocate on the stack + * \return PrimExpr representing the allocated object + */ + PrimExpr StackAlloca(std::string type, size_t num) { + Array args = {tir::StringImm(type), ConstInt32(num)}; + return tir::Call(DataType::Handle(), tir::builtin::tvm_stack_alloca(), args); + } + + /*! + * \brief Utility function to allocate memory for storage identifiers + * \param memory_size_byte size in bytes of the allocation + * \return PrimExpr representing the allocated memory + */ + PrimExpr AllocateBackendMemory(int memory_size_byte) { + // TODO(giuseros): use tir::Allocate instead of TVMBackendAllocWorkspace + // to enable unified memory planning + static const Op& op = Op::Get("tir.TVMBackendAllocWorkspace"); + return tvm::tir::Call(DataType::Handle(), op, {1, 0, memory_size_byte, 2, 8}); + } + + /*! + * \brief Utility function to convert a concrete integer to a PrimExpr. + * \param num the number to convert + * \return PrimExpr representing num + */ + inline PrimExpr ConstInt32(size_t num) { + ICHECK_LE(num, std::numeric_limits::max()); + return tir::make_const(DataType::Int(32), static_cast(num)); + } + + /*! + * \brief Return a vector of variables that represents the sids for the given Relay Expr + */ + std::vector pack_sid(Expr expr) { + Array sids = storage_device_map_[expr]; + std::vector sid_vars; + + // Note that an expression can have multiple sids associated with it + // e.g., returning multiple values from a function + for (const auto& sid : sids[0]) { + // Determine if an sid is an output buffer + int sid_int = static_cast((sid.as())->value); + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid_int); + if (output_iter != return_sid_.end()) { + int output_index = std::distance(return_sid_.begin(), output_iter); + sid_vars.push_back(main_signature_[input_vars_.size() + output_index]); + continue; + } + // Pack the sid inside the TVMValue + auto sid_array = te::Var(make_string("sid_", sid, "_value"), DataType::Handle()); + auto sid_value = sids_table_[sid]; + tvm::PrimExpr set_tensor = + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {sid_array, 0, tir::builtin::kArrData, sid_value}); + stmts_.push_back(tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor))); + sid_vars.push_back(sid_array); + } + return sid_vars; + } + + /*! + * \brief Utility function to return a parameter associated with an expression + * \param expr Relay Expression assicated with the parameter + * \return Variable that represents the DLTensor associated with the parameters + */ + tir::Var pack_param(Expr expr) { + // TODO(giuseros): Using call_extern to call into lookup_linked_param. This is because the + // builtin::ret is not supported yet in the c target. Once return is supported we can use + // tvm_call_packed_lowered(). 
+ int param_sid = param_storage_ids_[reverse_params_lookup_[expr]]; + auto lookup_linked_param_fn = tir::StringImm(::tvm::runtime::symbol::tvm_lookup_linked_param); + auto param_array = te::Var(make_string("param_", param_sid, "_array"), DataType::Handle()); + + // Compose the lookup_call using a local stack + Array lookup_call; + auto param_var = te::Var(make_string("param_", param_sid, "_value"), DataType::Handle()); + auto ret_var = te::Var("ret_value", DataType::Handle()); + auto ret_code = te::Var("ret_value", DataType::Handle()); + + lookup_call.push_back(tir::Evaluate( + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {param_var, 0, tir::builtin::kTVMValueContent, ConstInt32(param_sid)}))); + lookup_call.push_back(tir::Evaluate( + tvm::tir::Call(DataType::Handle(), tir::builtin::call_extern(), + {lookup_linked_param_fn, param_var, 0, 0, ret_var, ret_code, 0}))); + auto ret_var_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), + {ret_var, 0, tir::builtin::kTVMValueContent}); + + // Set the param to the value returned by lookup_call + tvm::PrimExpr set_param_array = + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {param_array, 0, tir::builtin::kArrData, ret_var_handle}); + lookup_call.push_back(tir::Evaluate(set_param_array)); + + tir::Stmt lookup_body = tir::SeqStmt(lookup_call); + + // Allocate the DLTensors on the stack + lookup_body = tir::LetStmt(param_var, StackAlloca("arg_value", 1), lookup_body); + lookup_body = tir::LetStmt(ret_var, StackAlloca("arg_value", 1), lookup_body); + lookup_body = tir::LetStmt(ret_code, StackAlloca("arg_value", 1), lookup_body); + lookup_body = tir::LetStmt(param_array, StackAlloca("arg_value", 1), lookup_body); + stmts_.push_back(lookup_body); + return param_array; + } + + /*! + * brief Given an expression return the variable(s) associated with that expression + */ + std::vector find_expr(Expr arg) { + auto input_iter = std::find(input_vars_.begin(), input_vars_.end(), arg); + if (input_iter != input_vars_.end()) { + // Input variable + int main_index = std::distance(input_vars_.begin(), input_iter); + return {main_signature_[main_index]}; + } else if (reverse_params_lookup_.find(arg) != reverse_params_lookup_.end()) { + // Parameter of the network + return {pack_param(arg)}; + } else { + // Storage identifier (i.e., intermediate memory) + return pack_sid(arg); + } + } + + /*! + * brief Call a function with a given name + */ + void func_call(Call call, std::string func_name) { + tvm::Array args{tvm::tir::StringImm(func_name)}; + std::vector func_call_stmts; + + // Pack the inputs + for (Expr arg : call->args) { + auto var_arg = find_expr(arg); + args.push_back(var_arg[0]); + } + + auto ret_expr = Downcast(call); + + // Pack the return(s) value. A call node can produce multiple outputs + for (const auto& var : pack_sid(ret_expr)) { + args.push_back(var); + } + + // Use tvm_call_packed to execute the function + func_call_stmts.push_back(tir::Evaluate( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args))); + tir::Stmt body = tir::SeqStmt(func_call_stmts); + stmts_.push_back(body); + } + + /*! + * brief Copy a variable to the output. This function is mainly used in edge cases + * when we want to return an input or a parameter. 
+ */ + void copy_to_output(te::Var out, te::Var in, size_t size) { + auto retval_get = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), + {in, 0, tir::builtin::kArrData}); + + // Define intermediate DLTensor to load/store the data + auto tmp0 = te::Var("tmp0", DataType::Handle()); + auto tmp1 = te::Var("tmp1", DataType::Handle()); + te::Var loop_idx("i", DataType::Int(32)); + auto retval_i = tir::Load(DataType::UInt(8), tmp0, loop_idx, tir::const_true()); + auto tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), + {out, 0, tir::builtin::kArrData}); + + // Copy the variable from the input to the output + tir::Stmt copy = tir::For( + loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial, + tir::Store(tmp1, tir::Let(tmp0, retval_get, retval_i), loop_idx, tir::const_true())); + stmts_.push_back(tir::LetStmt(tmp1, tostore, copy)); + } + + /*! + * Utility function to string together different arguments + */ + template + std::string make_string(Args const&... args) { + std::ostringstream ss; + using List = int[]; + (void)List{0, ((void)(ss << args), 0)...}; + + return ss.str(); + } + + void VisitExpr_(const CallNode* op) override { + // Descend the call tree + for (auto arg : op->args) { + VisitExpr(arg); + } + + Expr expr = GetRef(op); + Function func; + if (op->op.as()) { + LOG(FATAL) << "Operators should be transformed away; try applying" + << "the fuse_ops transformation to the expression."; + } else if (op->op.as()) { + LOG(FATAL) << "Not implemented"; + } else if (op->op.as()) { + func = GetRef(op->op.as()); + } else { + LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); + } + if (!func->HasNonzeroAttr(attr::kPrimitive)) { + LOG(FATAL) << "TVM only support calls to primitive functions " + << "(i.e functions composed of fusable operator invocations)"; + } + + auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); + auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); + Target target; + // Handle external function + if (func->GetAttr(attr::kCompiler).defined()) { + target = Target("ext_dev"); + CCacheKey key = (*pf0)(func, target); + CachedFunc ext_func = (*pf1)(compile_engine_, key); + ICHECK(ext_func.defined()) << "External function is not defined."; + UpdateConstants(func, ¶ms_); + + // Generate the TIR function call + func_call(GetRef(op), ext_func->func_name); + } + + ICHECK_GE(storage_device_map_.count(expr), 0); + auto& device_type = storage_device_map_[expr][1]; + auto call_dev_type = device_type[0]->value; + // Normal Relay Function + if (targets_.size() == 1) { + // homogeneous execution. + const auto& it = targets_.begin(); + target = (*it).second; + } else { + // heterogeneous execution. 
+ std::string call_dev_name; + if (call_dev_type == 0) { + call_dev_name = "llvm"; + } else { + call_dev_name = runtime::DeviceName(call_dev_type); + } + if (targets_.count(call_dev_type) == 0) { + LOG(FATAL) << "No target is provided for device " << call_dev_name; + } + target = targets_[call_dev_type]; + } + CCacheKey key = (*pf0)(func, target); + CachedFunc lowered_func = (*pf1)(compile_engine_, key); + if (!lowered_funcs_.count(target->str())) { + lowered_funcs_[target->str()] = IRModule(Map({})); + } + lowered_funcs_[target->str()]->Update(lowered_func->funcs); + + // Generate the TIR function call + func_call(GetRef(op), lowered_func->func_name); + } + + void VisitExpr_(const VarNode* op) override { + Expr expr = GetRef(op); + + // If the Var node is an output node we need to copy the content of the variable to the output + // A Var node can only produce a single output + Array sids = storage_device_map_[expr]; + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), + static_cast((sids[0][0].as())->value)); + if (output_iter != return_sid_.end()) { + int output_index = std::distance(return_sid_.begin(), output_iter); + auto var_expr = find_expr(expr); + copy_to_output(main_signature_[input_vars_.size() + output_index], var_expr[0], sids[2][0]); + } + } + + void VisitExpr_(const ConstantNode* op) override { + Expr expr = GetRef(op); + size_t index = params_.size(); + std::string name = "p" + std::to_string(index); + + param_storage_ids_[name] = storage_device_map_[expr][0][0]->value; + params_[name] = op->data; + reverse_params_lookup_.Set(expr, name); + + // If the Constant node is an output node we need to copy the content of the parameter to the + // output A Var node can only produce a single output + Array sids = storage_device_map_[expr]; + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), + static_cast((sids[0][0].as())->value)); + if (output_iter != return_sid_.end()) { + int output_index = std::distance(return_sid_.begin(), output_iter); + copy_to_output(main_signature_[input_vars_.size() + output_index], pack_param(expr), + sids[2][0]); + } + } + + void VisitExpr_(const TupleNode* op) override { + for (auto field : op->fields) { + VisitExpr(field); + } + } + + void VisitExpr_(const LetNode* op) override { + // TODO(giuseros): support Let nodes in AOT + throw std::invalid_argument("Let not yet implemented in AOT"); + } + void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } + void VisitExpr_(const OpNode* op) override { + throw std::runtime_error("can not compile op in non-eta expanded form"); + } + void VisitExpr_(const GlobalVarNode* op) override { throw std::runtime_error(""); } + void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); } + void VisitExpr_(const FunctionNode* op) override { + ICHECK(op->GetAttr(attr::kCompiler).defined()) + << "Only functions supported by custom codegen"; + } + void VisitExpr_(const RefCreateNode* op) override { + throw std::invalid_argument("reference not supported"); + } + void VisitExpr_(const RefReadNode* op) override { + throw std::invalid_argument("reference not supported"); + } + void VisitExpr_(const RefWriteNode* op) override { + throw std::invalid_argument("reference not supported"); + } + void VisitExpr_(const ConstructorNode* op) override { + throw std::invalid_argument("ADT constructor case not yet implemented"); + } + void VisitExpr_(const MatchNode* op) override { + throw std::invalid_argument("match case not yet implemented"); + 
} + + // Create the main PrimFunc to execute the graph + tir::PrimFunc CreateMainFunc(unsigned int relay_params) { + tir::Stmt body = tir::SeqStmt(stmts_); + + // Allocate the sids + std::unordered_map allocated; + + for (auto kv : storage_device_map_) { + // Only allocate sids that are needed + const bool is_input = + (std::find(input_vars_.begin(), input_vars_.end(), kv.first) != input_vars_.end()); + const bool is_param = (reverse_params_lookup_.find(kv.first) != reverse_params_lookup_.end()); + if (is_input || is_param) { + continue; + } + + for (unsigned int i = 0; i < kv.second[0].size(); i++) { + int size = kv.second[2][i]; + int sid = static_cast((kv.second[0][i].as())->value); + + if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) { + continue; + } + + if (!allocated[sid]) { + body = tir::LetStmt(sids_table_[sid], AllocateBackendMemory(size), body); + } + allocated[sid] = true; + } + } + + // Define the attributes + body = tir::AttrStmt(PrimExpr(), tir::attr::device_context_type, 1, body); + body = tir::AttrStmt(PrimExpr(), tir::attr::device_context_id, 0, body); + + // Make the PrimFunc + return tir::PrimFunc(main_signature_, body, VoidType(), Map(), + DictAttrs(dict_attrs_)); + } + + protected: + /*! \brief nodes */ + /*! \brief mod */ + runtime::Module* mod_; + std::vector input_vars_; + Array main_signature_; + /*! \brief target device */ + TargetsMap targets_; + Target target_host_; + Map dict_attrs_; + + /*! + * \brief parameters (i.e. ConstantNodes found in the graph). + * These are take as inputs to the GraphRuntime. + * Maps param name to a pair of storage_id and NDArray. At runtime, the storage_id can be + * used to lookup the parameter. + */ + Map reverse_params_lookup_; + std::unordered_map params_; + std::unordered_map param_storage_ids_; + + /*! \brief plan memory of device result */ + Map> storage_device_map_; + std::unordered_map sids_table_; + /*! \brief lowered funcs */ + std::unordered_map lowered_funcs_; + /*! \brief name map */ + std::unordered_map name_map_; + /*! \brief compile engine */ + CompileEngine compile_engine_; + /*! \brief GraphPlanMemory module */ + runtime::Module graph_plan_memory_module_; + /*! \brief the IR module stored which represents the executor program */ + Map tir_module_; + /*! \brief the set of statements that make the program */ + std::vector stmts_; + /*! 
\brief the list of return sids (note that the function might return more then one output */ + IntegerArray return_sid_; + + public: + AOTCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host) + : mod_(mod), return_sid_() { + compile_engine_ = CompileEngine::Global(); + targets_ = targets; + target_host_ = target_host; + dict_attrs_.Set("global_symbol", runtime::String("tvm__run_func")); + } + + AOTLoweredOutput Codegen(relay::Function func) { + // Get the module, storage map and token sizes + auto pf = GetPackedFunc("relay.backend.GraphPlanMemory"); + storage_device_map_ = (*pf)(func); + + int input_index = 0; + for (auto input : func->params) { + input_vars_.push_back(input); + main_signature_.push_back(tir::Var(make_string("input_", input_index), DataType::Handle())); + } + + // Define the storage allocator ids + for (auto kv : storage_device_map_) { + for (const auto& sid : kv.second[0]) { + te::Var sid_var(make_string("sid_", sid), DataType::Handle()); + sids_table_[sid] = sid_var; + } + } + + // Find the return sid + return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func); + for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) { + main_signature_.push_back(tir::Var(make_string("output_", output_index), DataType::Handle())); + } + + VisitExpr(func->body); + + auto prim_func = CreateMainFunc(func->params.size()); + AOTLoweredOutput ret; + + ret.params = std::unordered_map>(); + for (auto param : params_) { + ret.params.emplace(std::make_pair( + param.first, + std::make_pair(static_cast(param_storage_ids_[param.first]), param.second))); + } + + for (auto& kv : lowered_funcs_) { + if (ret.lowered_funcs.count(kv.first) == 0) { + ret.lowered_funcs.Set(kv.first, IRModule(Map({}))); + } + auto& mod = ret.lowered_funcs[kv.first]; + mod->Update(kv.second); + ret.lowered_funcs.Set(kv.first, mod); + } + ret.external_mods = compile_engine_->LowerExternalFunctions(); + + auto target_host_str = target_host_->str(); + if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { + ret.lowered_funcs[target_host_str]->Add( + GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func); + } else { + Map symbol_map; + symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func); + ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map)); + } + + ret.graph_tir = PrettyPrint(prim_func); + ret.aot_metadata = runtime::AOTMetadata(input_vars_.size(), return_sid_.size()); + return ret; + } +}; + +class AOTCodegenModule : public runtime::ModuleNode { + public: + AOTCodegenModule() {} + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "init") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK_EQ(args.num_args, 3) << "The expected of arguments are: " + << "runtime::Module mod and Map targets"; + void* mod = args[0]; + Map tmp = args[1]; + tvm::Target target_host = args[2]; + TargetsMap targets; + for (const auto& it : tmp) { + auto dev_type = it.first.as(); + ICHECK(dev_type); + targets[dev_type->value] = it.second; + } + codegen_ = std::make_shared(reinterpret_cast(mod), targets, + target_host); + }); + } else if (name == "codegen") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + Function func = args[0]; + this->output_ = this->codegen_->Codegen(func); + }); + } else if (name == "get_graph") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->output_.graph_tir; }); + } else if (name == "list_params_name") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + Array ret; + for (const auto& kv : this->output_.params) { + ret.push_back(kv.first); + } + *rv = ret; + }); + } else if (name == "get_param_by_name") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + String key = args[0]; + auto it = this->output_.params.find(key); + CHECK(it != this->output_.params.end()) << "no such parameter " << key; + *rv = (*it).second.second; + }); + } else if (name == "get_param_id") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + String key = args[0]; + auto it = this->output_.params.find(key); + CHECK(it != this->output_.params.end()) << "no such parameter " << key; + *rv = (*it).second.first; + }); + } else if (name == "get_irmodule") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->output_.lowered_funcs; + }); + } else if (name == "get_external_modules") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->output_.external_mods; + }); + } else if (name == "get_aot_metadata") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->output_.aot_metadata; + }); + } else { + return PackedFunc([](TVMArgs args, TVMRetValue* rv) {}); + } + } + + const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } + + private: + std::shared_ptr codegen_; + AOTLoweredOutput output_; +}; + +runtime::Module CreateAOTCodegenMod() { + auto ptr = make_object(); + return runtime::Module(ptr); +} + +TVM_REGISTER_GLOBAL("relay.build_module._GraphAOTCodegen") + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateAOTCodegenMod(); }); + +} // namespace backend +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 08846925bede7..9cf071a0d30c0 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -43,12 +43,14 @@ namespace backend { using TargetsMap = Map; using namespace tvm::relay::transform; +enum class Executor { GraphRuntime, Aot }; + /*! 
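To make the output of this codegen concrete, here is a rough, hand-written illustration of the kind of C the AOT flow ends up emitting for the tvm__run_func entry point once the PrimFunc is lowered through the C host codegen. Operator names, storage ids and sizes are invented, and the DLTensor/TVMValue stack setup is elided; the real output differs in detail.

    #include "tvm_backend.h"

    /* Hypothetical fused operator; in the real output each fused Relay primitive gets
       its own C function in the same module. */
    extern int32_t fused_nn_dense(void* args, void* arg_type_ids, int32_t num_args,
                                  void* out_ret_value, void* out_ret_tcode,
                                  void* resource_handle);

    int32_t tvm__run_func(void* args, void* arg_type_ids, int32_t num_args,
                          void* out_ret_value, void* out_ret_tcode, void* resource_handle) {
      /* inputs/outputs arrive as DLTensor handles packed by tvm_runtime_run */
      void* input_0  = ((TVMValue*)args)[0].v_handle;
      void* output_0 = ((TVMValue*)args)[1].v_handle;

      /* intermediate storage id served by the bump allocator; the hint arguments are ignored */
      void* sid_1 = TVMBackendAllocWorkspace(1, 0, (uint64_t)4096, 2, 8);
      if (sid_1 == NULL) { return -1; }

      /* the real code builds DLTensor/TVMValue stacks (tvm_struct_set over stack_alloca
         storage) around input_0, sid_1 and output_0; with the AOT executor the C host
         codegen emits a direct call to the operator instead of going through TVMFuncCall */
      TVMValue op_values[3];
      int32_t op_tcodes[3];
      (void)input_0; (void)output_0;
      if (fused_nn_dense(op_values, op_tcodes, 3, NULL, NULL, NULL) != 0) { return -1; }

      TVMBackendFreeWorkspace(1, 0, sid_1);
      return 0;
    }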
* \brief Output of building module * */ struct BuildOutput { - std::string graph_json; + std::string graph; runtime::Module mod; std::unordered_map params; }; @@ -59,17 +61,35 @@ struct BuildOutput { */ struct GraphCodegen { public: - GraphCodegen() { - auto pf = GetPackedFunc("relay.build_module._GraphRuntimeCodegen"); - mod = (*pf)(); + explicit GraphCodegen(Target target_host) : target_host_(target_host) { + const String executor_str = target_host->GetAttr("executor").value_or("graph_runtime"); + if (executor_str == "graph_runtime") { + executor_ = Executor::GraphRuntime; + auto pf = GetPackedFunc("relay.build_module._GraphRuntimeCodegen"); + mod = (*pf)(); + } else if (executor_str == "aot") { + executor_ = Executor::Aot; + auto pf = GetPackedFunc("relay.build_module._GraphAOTCodegen"); + mod = (*pf)(); + } else { + LOG(FATAL) << "Executor not supported"; + } } ~GraphCodegen() {} - void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); } + void Init(runtime::Module* m, TargetsMap targets) { + if (executor_ == Executor::GraphRuntime) { + CallFunc("init", m, targets); + } else if (executor_ == Executor::Aot) { + CallFunc("init", m, targets, target_host_); + } else { + LOG(FATAL) << "Executor not supported"; + } + } void Codegen(const Function& func) { CallFunc("codegen", func); } - std::string GetJSON() { return CallFunc("get_graph_json", nullptr); } + std::string GetGraph() { return CallFunc("get_graph", nullptr); } Array GetExternalModules() { return CallFunc>("get_external_modules", nullptr); @@ -101,7 +121,18 @@ struct GraphCodegen { return ret; } + runtime::AOTMetadata GetAOTMetdata() { + if (executor_ == Executor::Aot) { + return CallFunc("get_aot_metadata"); + } else { + // Graph runtime does not need AOT metadata + return runtime::AOTMetadata(); + } + } + protected: + Executor executor_; + Target target_host_; tvm::runtime::Module mod; template R CallFunc(const std::string& name, Args... args) { @@ -129,9 +160,9 @@ class RelayBuildModule : public runtime::ModuleNode { * \return The corresponding member function. */ PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - if (name == "get_graph_json") { + if (name == "get_graph") { return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); }); + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraph(); }); } else if (name == "get_module") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); @@ -177,7 +208,7 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return const std::string graph_json */ - const std::string& GetGraphJSON() { return ret_.graph_json; } + const std::string& GetGraph() { return ret_.graph; } /*! * \brief Get the Module object @@ -462,25 +493,26 @@ class RelayBuildModule : public runtime::ModuleNode { const std::unordered_map& params) { // Relay IRModule -> IRModule optimizations. relay_module = Optimize(relay_module, targets_, params); + + Target target_host = GetTargetHost(); + // If no target_host has been set, we choose a default one, which is + // llvm if "codegen.LLVMModuleCreate" is accessible. + const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); + if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); + // Get the updated function. auto func = Downcast(relay_module->Lookup("main")); // Generate code for the updated function. 
- graph_codegen_ = std::unique_ptr(new GraphCodegen()); + graph_codegen_ = std::unique_ptr(new GraphCodegen(target_host)); graph_codegen_->Init(nullptr, targets_); graph_codegen_->Codegen(func); - ret_.graph_json = graph_codegen_->GetJSON(); + ret_.graph = graph_codegen_->GetGraph(); ret_.params = graph_codegen_->GetParams(); auto lowered_funcs = graph_codegen_->GetIRModule(); - Target target_host = GetTargetHost(); - // If no target_host has been set, we choose a default one, which is - // llvm if "codegen.LLVMModuleCreate" is accessible. - const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); - if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); - // Generate a placeholder function that attaches linked params as its arguments. if (target_host->GetAttr("link-params").value_or(Bool(false))) { CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen."; @@ -519,7 +551,8 @@ class RelayBuildModule : public runtime::ModuleNode { } auto ext_mods = graph_codegen_->GetExternalModules(); - ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost()); + ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost(), + graph_codegen_->GetAOTMetdata()); } private: diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 15173c2c79db8..d2594fe97c02f 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -209,6 +209,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { for (const auto& kv : token_map_) { std::vector storage_ids; std::vector device_types; + std::vector sid_sizes; for (StorageToken* tok : kv.second) { if (tok->device_type) { num_annotated_nodes++; @@ -216,8 +217,9 @@ class StorageAllocator : public StorageAllocaBaseVisitor { num_nodes++; storage_ids.push_back(tok->storage_id); device_types.push_back(tok->device_type); + sid_sizes.push_back(GetMemorySize(tok)); } - smap.Set(GetRef(kv.first), Array({storage_ids, device_types})); + smap.Set(GetRef(kv.first), Array({storage_ids, device_types, sid_sizes})); } // Either all or none of the nodes should be annotated. if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 7ed1504951048..9a709c22cda62 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -250,7 +250,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator storage_info; for (auto& v : storage_device_info[0]) { @@ -581,7 +581,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { Function func = args[0]; this->output_ = this->codegen_->Codegen(func); }); - } else if (name == "get_graph_json") { + } else if (name == "get_graph") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.graph_json; }); } else if (name == "list_params_name") { diff --git a/src/runtime/crt/aot/tvm_executor.c b/src/runtime/crt/aot/tvm_executor.c new file mode 100644 index 0000000000000..74069c6af26e3 --- /dev/null +++ b/src/runtime/crt/aot/tvm_executor.c @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// LINT_C_FILE + +/*! + * \file src/runtime/crt/aot/tvm_executor.c + * \brief Internal implementation of the AOT Executor + */ + +#include "tvm_executor.h" + +#include + +#include "tvm_backend.h" +#include "tvm_error.h" + +tvm_workspace_t* tvm_runtime_workspace; + +tvm_crt_error_t tvm_runtime_run(const tvm_model_t* model, void** inputs, void** outputs, + tvm_context_t* context) { + static DLContext fake_ctx = {kDLCPU, 0}; + static int64_t fake_dims = 0; + static int64_t fake_shape = {0}; + + DLTensor tensors[model->num_input_tensors + model->num_output_tensors]; // NOLINT + TVMValue tvm_values[model->num_input_tensors + model->num_output_tensors]; // NOLINT + int32_t tvm_typeids[model->num_input_tensors + model->num_output_tensors]; // NOLINT + + for (int i = 0; i < model->num_input_tensors; i++) { + tensors[i] = (DLTensor){ + .ctx = fake_ctx, + .data = inputs[i], + .shape = &fake_shape, + .ndim = fake_dims, + .byte_offset = 0, + .strides = NULL, + }; + tvm_values[i].v_handle = &tensors[i]; + } + + for (int i = 0; i < model->num_output_tensors; i++) { + tensors[model->num_input_tensors + i] = (DLTensor){ + .ctx = fake_ctx, + .data = outputs[i], + .shape = &fake_shape, + .ndim = fake_dims, + .byte_offset = 0, + .strides = NULL, + }; + tvm_values[model->num_input_tensors + i].v_handle = &tensors[model->num_input_tensors + i]; + } + + return model->run_func(&tvm_values, &tvm_typeids, 0, NULL, 0, context); +} + +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint, + int dtype_bits_hint) { + uint32_t offset = (~nbytes + 1) & (TVM_RUNTIME_ALLOC_ALIGNMENT - 1); + uint8_t* current_alloc = tvm_runtime_workspace->next_alloc; + uint8_t* next_alloc = tvm_runtime_workspace->next_alloc + nbytes + offset; + uint8_t* workspace_end = tvm_runtime_workspace->workspace + tvm_runtime_workspace->workspace_size; + + if (next_alloc > workspace_end) { + return NULL; + } + + tvm_runtime_workspace->next_alloc = next_alloc; + return current_alloc; +} + +tvm_crt_error_t TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { + tvm_runtime_workspace->next_alloc = ptr; + return 0; +} diff --git a/src/runtime/crt/graph_runtime/graph_runtime.c b/src/runtime/crt/graph_runtime/graph_runtime.c index 21b72f0e400c0..5fe993d8a766b 100644 --- a/src/runtime/crt/graph_runtime/graph_runtime.c +++ b/src/runtime/crt/graph_runtime/graph_runtime.c @@ -874,6 +874,7 @@ int TVMGraphRuntime_LoadParams(TVMGraphRuntime* runtime, const char* param_blob, void TVMGraphRuntime_Run(TVMGraphRuntime* runtime) { // setup the array and requirements. 
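    /* Side note on TVMBackendAllocWorkspace above: the padding expression
     *   (~nbytes + 1) & (TVM_RUNTIME_ALLOC_ALIGNMENT - 1)
     * rounds each request up to the alignment. Worked example with the default
     * 16-byte alignment:
     *   nbytes = 10  ->  (~10 + 1) & 15 = (-10) & 15 = 6, so next_alloc advances 16 bytes;
     *   nbytes = 32  ->  offset 0, so next_alloc advances exactly 32 bytes.
     * TVMBackendFreeWorkspace simply rewinds next_alloc to the freed pointer, so frees
     * only make space available again when they happen in reverse allocation order. */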
uint32_t idx; + for (idx = 0; idx < runtime->op_execs_count; ++idx) { if (runtime->op_execs[idx].fexec) { #if TVM_CRT_DEBUG diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 03dba399fcb40..aa819ea2343ca 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -32,6 +32,7 @@ #include #include +#include #include #include "runtime_base.h" @@ -39,6 +40,37 @@ namespace tvm { namespace runtime { +/*! + * \brief Structure used by the AOT to fill the tvm_module_t structure + */ +class AOTMetadataNode : public Object { + public: + /*! \brief number of inputs of the main function */ + int num_inputs = 1; + /*! \brief number of outputs of the main function */ + int num_outputs = 1; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "AOTMetadataObj"; + TVM_DECLARE_FINAL_OBJECT_INFO(AOTMetadataNode, Object); +}; + +/*! + * \brief Managed reference to AOTMetadataNode. + */ +class AOTMetadata : public ObjectRef { + public: + TVM_DLL AOTMetadata(int num_inputs, int num_outputs) { + auto n = make_object(); + n->num_inputs = num_inputs; + n->num_outputs = num_outputs; + data_ = std::move(n); + } + + TVM_DEFINE_OBJECT_REF_METHODS(AOTMetadata, ObjectRef, AOTMetadataNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AOTMetadataNode); +}; + /*! * \brief Create a metadata module object. * diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 0b30d42c876ce..aa52858d0b931 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -46,7 +46,8 @@ namespace codegen { */ runtime::Module CreateMetadataModule( const std::unordered_map& params, - tvm::runtime::Module target_module, const Array& ext_modules, Target target) { + tvm::runtime::Module target_module, const Array& ext_modules, Target target, + runtime::AOTMetadata aot_metadata) { // Here we split modules into two groups: // 1. Those modules which can be exported to C-runtime. These are DSO-exportable // (i.e. llvm or c) modules which return nothing from get_const_vars(). 
@@ -114,7 +115,7 @@ runtime::Module CreateMetadataModule( if (target->kind->name == "c") { crt_exportable_modules.push_back(target_module); - target_module = CreateCSourceCrtMetadataModule(crt_exportable_modules, target); + target_module = CreateCSourceCrtMetadataModule(crt_exportable_modules, target, aot_metadata); } else if (target->kind->name == "llvm") { #ifdef TVM_LLVM_VERSION crt_exportable_modules.push_back(target_module); diff --git a/src/target/metadata_module.h b/src/target/metadata_module.h index 83cb29dd5a461..49404a63fdeb5 100644 --- a/src/target/metadata_module.h +++ b/src/target/metadata_module.h @@ -33,12 +33,15 @@ #include #include +#include "../runtime/meta_data.h" + namespace tvm { namespace codegen { runtime::Module CreateMetadataModule( const std::unordered_map& params, - tvm::runtime::Module target_module, const Array& ext_modules, Target target); + tvm::runtime::Module target_module, const Array& ext_modules, Target target, + int num_inputs = 1, int num_outputs = 1); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 3ec64ed2ace92..0ea8f2250f894 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -40,13 +40,21 @@ namespace codegen { CodeGenCHost::CodeGenCHost() { module_name_ = GetUniqueName("__tvm_module_ctx"); } -void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_str) { +void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool is_aot_executor, + std::string target_str) { emit_asserts_ = emit_asserts; + is_aot_executor_ = is_aot_executor; declared_globals_.clear(); decl_stream << "// tvm target: " << target_str << "\n"; decl_stream << "#define TVM_EXPORTS\n"; - decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; - decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; + if (is_aot_executor) { + decl_stream << "#include \"tvm_executor.h\"\n"; + decl_stream << "#include \"dlpack/dlpack.h\"\n"; + } else { + decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; + decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; + } + decl_stream << "#include \n"; decl_stream << "void* " << module_name_ << " = NULL;\n"; CodeGenC::Init(output_ssa); @@ -211,21 +219,34 @@ void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name, this->stream << "}\n"; } -void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_args) { +void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, PrimExpr values, + int num_args) { this->PrintIndent(); + std::string stack_value = "stack_value"; + if (const VarNode* stack_value_var = values.as()) { + stack_value = stack_value_var->name_hint; + } std::string ret_val = GetUniqueName("ret_val"); std::string ret_type_code = GetUniqueName("ret_type_code"); this->stream << "TVMValue " << ret_val << ";\n"; this->PrintIndent(); this->stream << "int " << ret_type_code << ";\n"; this->PrintIndent(); - this->stream << "if (TVMFuncCall(" << packed_func_name << ", " - << "(TVMValue*) stack_value" - << ", " + + if (is_aot_executor_) { + this->stream << "if (" << packed_func_name << "( " + << "(TVMValue*) " << stack_value; + } else { + this->stream << "if (TVMFuncCall(" << packed_func_name << ", " + << "(TVMValue*) stack_value"; + } + this->stream << ", " << "(int*) stack_tcode" << ", " << num_args << ", " - << "&" << ret_val << ", " - << "&" << ret_type_code << ") != 0) {\n"; + << "&" << ret_val << ", "; + this->stream << "&" << 
ret_type_code; + this->stream << (is_aot_executor_ ? ", NULL" : "") << ") != 0) {\n"; + int func_call_scope = this->BeginScope(); this->PrintIndent(); this->stream << "return -1;\n"; @@ -274,8 +295,11 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT << "Expected name " << packed_func_name << " to not be taken"; decl_stream << "static void* " << packed_func_name << " = NULL;\n"; } - this->PrintGetFuncFromBackend(func_name, packed_func_name); - this->PrintFuncCall(packed_func_name, num_args); + if (!is_aot_executor_) { + this->PrintGetFuncFromBackend(func_name, packed_func_name); + } + this->PrintFuncCall(func_name, op->args[1], num_args); + } else if (op->op.same_as(builtin::tvm_throw_last_error())) { this->PrintIndent(); this->stream << "return -1;\n"; @@ -324,15 +348,19 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare, } runtime::Module BuildCHost(IRModule mod, Target target) { + bool is_aot_executor = (target->GetAttr("executor").value_or("graph_runtime") == "aot"); + using tvm::runtime::Registry; bool output_ssa = false; bool emit_asserts = false; CodeGenCHost cg; - cg.Init(output_ssa, emit_asserts, target->str()); + cg.Init(output_ssa, emit_asserts, is_aot_executor, target->str()); Map linked_params; bool found_linked_params = false; bool could_have_linked_params = target->GetAttr("link-params").value_or(Bool(false)); + PrimFunc aot_executor_fn; + for (auto kv : mod->functions) { if (could_have_linked_params && kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) { @@ -344,6 +372,17 @@ runtime::Module BuildCHost(IRModule mod, Target target) { found_linked_params = true; continue; } + // Make sure that the executor function is the last one to be code generated so that all the + // symbols are available to tvm_run_func + if (is_aot_executor) { + auto fun_name = std::string(kv.first->name_hint); + const bool is_aot_executor_fn = + (fun_name.rfind(::tvm::runtime::symbol::tvm_run_func_prefix, 0) == 0); + if (is_aot_executor_fn) { + aot_executor_fn = Downcast(kv.second); + continue; + } + } ICHECK(kv.second->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; auto f = Downcast(kv.second); @@ -355,6 +394,12 @@ runtime::Module BuildCHost(IRModule mod, Target target) { cg.LinkParameters(linked_params); } + if (is_aot_executor) { + ICHECK(aot_executor_fn.defined()) + << "When using aot executor the executor function should be defined"; + cg.AddFunction(aot_executor_fn); + } + if (target->GetAttr("system-lib").value_or(Bool(false))) { ICHECK_EQ(target->GetAttr("runtime").value_or(""), "c") << "c target only supports generating C runtime SystemLibs"; diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 97fe7ab39efac..06868280df3c2 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -38,7 +38,7 @@ namespace codegen { class CodeGenCHost final : public CodeGenC { public: CodeGenCHost(); - void Init(bool output_ssa, bool emit_asserts, std::string target_str); + void Init(bool output_ssa, bool emit_asserts, bool is_aot_executor, std::string target_str); void AddFunction(const PrimFunc& f); @@ -69,9 +69,10 @@ class CodeGenCHost final : public CodeGenC { Array function_names_; /*! 
\brief whether to emit asserts in the resulting C code */ bool emit_asserts_; + bool is_aot_executor_; void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name); - void PrintFuncCall(const std::string& packed_func_name, int num_args); + void PrintFuncCall(const std::string& packed_func_name, PrimExpr values, int num_args); /*! * \brief Print ternary conditional operator implementing binary `op` diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 3baa44eb639fa..e91d78f580f24 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -155,7 +155,8 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, */ runtime::Module CreateMetadataModule( const std::unordered_map& params, runtime::Module target_module, - const Array& ext_modules, Target target); + const Array& ext_modules, Target target, + runtime::AOTMetadata aot_metadata = runtime::AOTMetadata()); /*! * \brief Create a source module for viewing and limited saving for device. @@ -175,8 +176,8 @@ runtime::Module DeviceSourceModuleCreate( * \param target the target the modules are compiled for. * \return The wrapped module. */ -runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, - Target target); +runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, + runtime::AOTMetadata aot_metadata); } // namespace codegen } // namespace tvm diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 26f1850c0e475..68de392e06f60 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -130,8 +130,8 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { public: CSourceCrtMetadataModuleNode(const Array& func_names, const std::string& fmt, - Target target) - : fmt_(fmt), func_names_(func_names), target_(target) { + Target target, runtime::AOTMetadata aot_metadata) + : fmt_(fmt), func_names_(func_names), target_(target), aot_metadata_(aot_metadata) { CreateSource(); } const char* type_key() const { return "c"; } @@ -159,6 +159,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { std::string fmt_; Array func_names_; Target target_; + runtime::AOTMetadata aot_metadata_; void CreateFuncRegistry() { code_ << "#include \n"; @@ -191,17 +192,35 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { << "}\n"; } + void GenerateAOTDescriptor() { + code_ << "#include \n"; + code_ << "#ifdef __cplusplus\n"; + code_ << "extern \"C\"\n"; + code_ << "#endif\n"; + code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix; + code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " + "out_type_code, void* resource_handle);\n"; + code_ << "const tvm_model_t network = {\n" + << " .run_func = &" << ::tvm::runtime::symbol::tvm_run_func_prefix << ",\n" + << " .num_input_tensors = " << aot_metadata_->num_inputs << ",\n" + << " .num_output_tensors = " << aot_metadata_->num_outputs << ", \n" + << "};\n"; + } + void CreateSource() { if (target_->GetAttr("system-lib").value_or(Bool(false)) && !func_names_.empty()) { CreateFuncRegistry(); GenerateCrtSystemLib(); } + if (target_->GetAttr("executor").value_or("graph_runtime") == "aot") { + GenerateAOTDescriptor(); + } code_ << ";"; } }; -runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, - Target 
target) { +runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, + runtime::AOTMetadata aot_metadata) { Array func_names; for (runtime::Module mod : modules) { auto pf_funcs = mod.GetFunction("get_func_names"); @@ -212,7 +231,7 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array& mod } } } - auto n = make_object(func_names, "cc", target); + auto n = make_object(func_names, "cc", target, aot_metadata); auto csrc_metadata_module = runtime::Module(n); for (const auto& mod : modules) { csrc_metadata_module.Import(mod); @@ -283,7 +302,8 @@ TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") TVM_REGISTER_GLOBAL("runtime.CreateCSourceCrtMetadataModule") .set_body_typed([](const Array& modules, Target target) { - return CreateCSourceCrtMetadataModule(modules, target); + // Note that we don't need metadata when we compile a single operator + return CreateCSourceCrtMetadataModule(modules, target, runtime::AOTMetadata()); }); } // namespace codegen diff --git a/src/target/source/source_module.h b/src/target/source/source_module.h index 45858b9f4ef25..f4f52eccd1dd7 100644 --- a/src/target/source/source_module.h +++ b/src/target/source/source_module.h @@ -29,6 +29,8 @@ #include #include +#include "../../runtime/meta_data.h" + namespace tvm { namespace codegen { @@ -38,7 +40,8 @@ namespace codegen { * \param target TVM target. */ runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, - tvm::Target target); + tvm::Target target, + runtime::AOTMetadata aot_metadata); } // namespace codegen } // namespace tvm diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 863d99993f4a9..a7e418927c76a 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -227,6 +227,7 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("runtime") .add_attr_option("mcpu") .add_attr_option("march") + .add_attr_option("executor") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("cuda", kDLGPU) @@ -308,8 +309,7 @@ TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev) // line break TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break .add_attr_option("system-lib"); -TVM_REGISTER_TARGET_KIND("composite", kDLCPU) - .add_attr_option>("devices"); +TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option>("devices"); /********** Registry **********/ diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 1117571c8b756..c4d76d4a74940 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -174,6 +174,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array) TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(tvm_call_unpacked) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 1d12d57d10b42..06b205033d595 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -247,7 +247,7 @@ class BuiltinLower : public StmtExprMutator { Array packed_args = {op->args[0], stack_value_, stack_tcode_, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1)}; - return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args); + return Call(op->dtype, builtin::tvm_call_packed_lowered(), packed_args); } PrimExpr MakeCallTracePacked(const CallNode* op) { diff --git 
a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index a15cdcd3926b7..641fe12bb6e2f 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -112,7 +112,7 @@ TEST(Relay, BuildModule) { auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule"); tvm::runtime::Module build_mod = (*pfb)(); auto build_f = build_mod.GetFunction("build", false); - auto json_f = build_mod.GetFunction("get_graph_json", false); + auto json_f = build_mod.GetFunction("get_graph", false); auto mod_f = build_mod.GetFunction("get_module", false); Map targets; Target llvm_tgt = Target("llvm"); diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index e55431fe2413c..c46a1001e264c 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -85,7 +85,7 @@ TEST(MicroStandaloneRuntime, BuildModule) { auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule"); tvm::runtime::Module build_mod = (*pfb)(); auto build_f = build_mod.GetFunction("build", false); - auto json_f = build_mod.GetFunction("get_graph_json", false); + auto json_f = build_mod.GetFunction("get_graph", false); auto mod_f = build_mod.GetFunction("get_module", false); Map targets; diff --git a/tests/crt/aot_executor_test.cc b/tests/crt/aot_executor_test.cc new file mode 100644 index 0000000000000..753d9d9dc4deb --- /dev/null +++ b/tests/crt/aot_executor_test.cc @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include + +#include "tvm_executor.h" + +int32_t test_run_func(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, + void* out_ret_tcode, void* resource_handle) { + return kTvmErrorNoError; +} + +TEST(AOTRuntime, NoOp) { + const tvm_model_t test_model = { + .num_input_tensors = 0, + .num_output_tensors = 0, + .run_func = &test_run_func, + }; + + ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&test_model, NULL, NULL, NULL)); +} + +int32_t error_run_func(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, + void* out_ret_tcode, void* resource_handle) { + return kTvmErrorPlatformNoMemory; +} + +TEST(AOTRuntime, Error) { + const tvm_model_t error_model = { + .num_input_tensors = 0, + .num_output_tensors = 0, + .run_func = &error_run_func, + }; + + ASSERT_EQ(kTvmErrorPlatformNoMemory, tvm_runtime_run(&error_model, NULL, NULL, NULL)); +} + +int32_t identity_run_func(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, + void* out_ret_tcode, void* resource_handle) { + void* arg0 = (((TVMValue*)args)[0].v_handle); + void* arg1 = (((TVMValue*)args)[1].v_handle); + void* placeholder = (((DLTensor*)arg0)[0].data); + void* T_id = (((DLTensor*)arg1)[0].data); + ((uint32_t*)T_id)[(0)] = ((uint32_t*)placeholder)[(0)]; + return kTvmErrorNoError; +} + +TEST(AOTRuntime, Identity) { + const tvm_model_t identity_model = { + .num_input_tensors = 1, + .num_output_tensors = 1, + .run_func = &identity_run_func, + }; + + uint32_t inputs1[1] = {404}; + void* inputs[] = {inputs1}; + uint32_t outputs1[1]; + void* outputs[] = {outputs1}; + + ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&identity_model, inputs, outputs, NULL)); + ASSERT_EQ(outputs1[0], 404); +} + +int32_t add_run_func(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, + void* out_ret_tcode, void* resource_handle) { + void* arg0 = (((TVMValue*)args)[0].v_handle); + void* arg1 = (((TVMValue*)args)[1].v_handle); + void* placeholder = (((DLTensor*)arg0)[0].data); + void* T_add = (((DLTensor*)arg1)[0].data); + ((uint32_t*)T_add)[(0)] = ((uint32_t*)placeholder)[(0)] + ((uint32_t*)placeholder)[(1)]; + return kTvmErrorNoError; + + return kTvmErrorNoError; +} + +TEST(AOTRuntime, Add) { + const tvm_model_t add_model = { + .num_input_tensors = 1, + .num_output_tensors = 1, + .run_func = &add_run_func, + }; + + uint32_t inputs1[2] = {404, 500}; + void* inputs[] = {inputs1}; + uint32_t outputs1[1]; + void* outputs[] = {outputs1}; + + ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&add_model, inputs, outputs, NULL)); + ASSERT_EQ(outputs1[0], 904); +} + +int32_t multiple_inputs_run_func(void* args, void* arg_type_ids, int32_t num_args, + void* out_ret_value, void* out_ret_tcode, void* resource_handle) { + void* arg0 = (((TVMValue*)args)[0].v_handle); + void* arg1 = (((TVMValue*)args)[1].v_handle); + void* arg2 = (((TVMValue*)args)[2].v_handle); + void* placeholder = (((DLTensor*)arg0)[0].data); + void* placeholder1 = (((DLTensor*)arg1)[0].data); + void* T_add = (((DLTensor*)arg2)[0].data); + ((uint32_t*)T_add)[(0)] = ((uint32_t*)placeholder)[(0)] + ((uint32_t*)placeholder)[(1)] + + ((uint32_t*)placeholder1)[(0)] + ((uint32_t*)placeholder1)[(1)]; + return kTvmErrorNoError; +} + +TEST(AOTRuntime, MultipleInputs) { + const tvm_model_t multiple_inputs_model = { + .num_input_tensors = 2, + .num_output_tensors = 1, + .run_func = &multiple_inputs_run_func, + }; + + uint32_t inputs1[2] = {404, 500}; + uint32_t inputs2[2] = {200, 202}; + void* inputs[] = {inputs1, inputs2}; + + uint32_t 
outputs1[1]; + void* outputs[] = {outputs1}; + + ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&multiple_inputs_model, inputs, outputs, NULL)); + ASSERT_EQ(outputs1[0], 1306); +} + +int32_t multiple_outputs_run_func(void* args, void* arg_type_ids, int32_t num_args, + void* out_ret_value, void* out_ret_tcode, void* resource_handle) { + void* arg0 = (((TVMValue*)args)[0].v_handle); + void* arg1 = (((TVMValue*)args)[1].v_handle); + void* arg2 = (((TVMValue*)args)[2].v_handle); + void* placeholder = (((DLTensor*)arg0)[0].data); + void* T_split1 = (((DLTensor*)arg1)[0].data); + void* T_split2 = (((DLTensor*)arg2)[0].data); + ((uint32_t*)T_split1)[(0)] = ((uint32_t*)placeholder)[(0)]; + ((uint32_t*)T_split2)[(0)] = ((uint32_t*)placeholder)[(1)]; + return kTvmErrorNoError; +} + +TEST(AOTRuntime, MultipleOutputs) { + const tvm_model_t multiple_outputs_model = { + .num_input_tensors = 1, + .num_output_tensors = 2, + .run_func = &multiple_outputs_run_func, + }; + + uint32_t inputs1[2] = {404, 500}; + void* inputs[] = {inputs1}; + + uint32_t outputs1[1]; + uint32_t outputs2[1]; + void* outputs[] = {outputs1, outputs2}; + + ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&multiple_outputs_model, inputs, outputs, NULL)); + ASSERT_EQ(outputs1[0], 404); + ASSERT_EQ(outputs2[0], 500); +} + +int32_t resource_handle_check_run_func(void* args, void* arg_type_ids, int32_t num_args, + void* out_ret_value, void* out_ret_tcode, + void* resource_handle) { + if (resource_handle == NULL) { + return kTvmErrorFunctionCallWrongArgType; + } + return kTvmErrorNoError; +} + +TEST(AOTRuntime, ContextPassing) { + tvm_context_t stub_context = {}; + const tvm_model_t resource_handle_check_model = { + .num_input_tensors = 0, + .num_output_tensors = 0, + .run_func = &resource_handle_check_run_func, + }; + + ASSERT_EQ(kTvmErrorNoError, + tvm_runtime_run(&resource_handle_check_model, NULL, NULL, &stub_context)); + ASSERT_EQ(kTvmErrorFunctionCallWrongArgType, + tvm_runtime_run(&resource_handle_check_model, NULL, NULL, NULL)); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/crt/aot_memory_test.cc b/tests/crt/aot_memory_test.cc new file mode 100644 index 0000000000000..a5df9a5b64775 --- /dev/null +++ b/tests/crt/aot_memory_test.cc @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include + +#include "tvm_backend.h" + +// TODO(Mousius) - Move memory allocation to individual networks +extern tvm_workspace_t* tvm_runtime_workspace; + +/* + * Tests allocations are properly aligned when allocated + */ +TEST(AOTMemory, Allocate) { + static uint8_t model_memory[80]; + tvm_workspace_t workspace = { + .next_alloc = model_memory, + .workspace = model_memory, + .workspace_size = 80, + }; + tvm_runtime_workspace = &workspace; + + void* block_one = TVMBackendAllocWorkspace(0, 0, 1, 0, 0); + ASSERT_EQ(block_one, &model_memory[0]); + + void* block_two = TVMBackendAllocWorkspace(0, 0, 2, 0, 0); + ASSERT_EQ(block_two, &model_memory[16]); + + void* two_blocks = TVMBackendAllocWorkspace(0, 0, 24, 0, 0); + ASSERT_EQ(two_blocks, &model_memory[32]); + + void* block_three = TVMBackendAllocWorkspace(0, 0, 1, 0, 0); + ASSERT_EQ(block_three, &model_memory[64]); +} + +/* + * Tests resetting the stack after dealloc + */ +TEST(AOTMemory, Free) { + static uint8_t model_memory[80]; + tvm_workspace_t workspace = { + .next_alloc = model_memory, + .workspace = model_memory, + .workspace_size = 80, + }; + tvm_runtime_workspace = &workspace; + + void* block_one = TVMBackendAllocWorkspace(0, 0, 1, 0, 0); + ASSERT_EQ(block_one, &model_memory[0]); + + void* block_two = TVMBackendAllocWorkspace(0, 0, 1, 0, 0); + ASSERT_EQ(block_two, &model_memory[16]); + ASSERT_EQ(0, TVMBackendFreeWorkspace(0, 0, block_two)); + + void* two_blocks = TVMBackendAllocWorkspace(0, 0, 2, 0, 0); + ASSERT_EQ(two_blocks, &model_memory[16]); + ASSERT_EQ(0, TVMBackendFreeWorkspace(0, 0, two_blocks)); + + void* block_three = TVMBackendAllocWorkspace(0, 0, 1, 0, 0); + ASSERT_EQ(block_three, &model_memory[16]); +} + +/* + * Tests we return NULL if we over allocate + */ +TEST(AOTMemory, OverAllocate) { + static uint8_t model_memory[72]; + tvm_workspace_t workspace = { + .next_alloc = model_memory, + .workspace = model_memory, + .workspace_size = 72, + }; + tvm_runtime_workspace = &workspace; + + void* block_one = TVMBackendAllocWorkspace(0, 0, 1, 0, 0); + ASSERT_EQ(block_one, &model_memory[0]); + + void* block_two = TVMBackendAllocWorkspace(0, 0, 1, 0, 0); + ASSERT_EQ(block_two, &model_memory[16]); + + void* two_blocks = TVMBackendAllocWorkspace(0, 0, 64, 0, 0); + ASSERT_EQ(two_blocks, (void*)NULL); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/relay/aot/aot_test.mk b/tests/python/relay/aot/aot_test.mk new file mode 100644 index 0000000000000..66dd6e6ae21fd --- /dev/null +++ b/tests/python/relay/aot/aot_test.mk @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# Makefile to build aot_test_runner
+# Setup build environment
+#
+AOT_ROOT ?= $(TVM_ROOT)/src/runtime/crt/aot
+
+ENABLE_TVM_PLATFORM_ABORT_BACKTRACE = 0
+DMLC_CORE=$(TVM_ROOT)/3rdparty/dmlc-core
+PKG_COMPILE_OPTS = -g
+CC = gcc
+AR = ar
+RANLIB = ranlib
+CC_OPTS = CC=$(CC) AR=$(AR) RANLIB=$(RANLIB)
+
+
+PKG_CFLAGS = ${PKG_COMPILE_OPTS} \
+	-I$(TVM_ROOT)/include/tvm/runtime/crt/aot \
+	-I$(TVM_ROOT)/src/runtime/crt/include \
+	-I$(DMLC_CORE)/include \
+	-I$(TVM_ROOT)/3rdparty/dlpack/include \
+	-I$(AOT_ROOT) \
+	-I$(build_dir)
+
+ifeq ($(VERBOSE),1)
+QUIET ?=
+else
+QUIET ?= @
+endif
+
+CRT_SRCS = $(shell find $(CRT_ROOT))
+
+aot_test_runner: $(build_dir)/aot_test_runner
+
+$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/lib0.o $(build_dir)/lib1.o $(build_dir)/tvm_executor.o
+	$(QUIET)mkdir -p $(@D)
+	$(QUIET)$(CC) $(PKG_CFLAGS) -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) -lm
+
+$(build_dir)/lib1.o: $(build_dir)/../codegen/host/src/lib1.c
+	$(QUIET)mkdir -p $(@D)
+	$(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)
+
+$(build_dir)/lib0.o: $(build_dir)/../codegen/host/src/lib0.c
+	$(QUIET)mkdir -p $(@D)
+	$(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)
+
+$(build_dir)/tvm_executor.o: $(TVM_ROOT)/src/runtime/crt/aot/tvm_executor.c
+	$(QUIET)mkdir -p $(@D)
+	$(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)
+
+clean:
+	$(QUIET)rm -rf $(build_dir)/crt
+cleanall:
+	$(QUIET)rm -rf $(build_dir)
+# Don't define implicit rules; they tend to match on logical target names that aren't targets (e.g. bundle_static)
+.SUFFIXES:
+.DEFAULT: aot_test_runner
diff --git a/tests/python/relay/aot/infra.py b/tests/python/relay/aot/infra.py
new file mode 100644
index 0000000000000..475b150ccd651
--- /dev/null
+++ b/tests/python/relay/aot/infra.py
@@ -0,0 +1,213 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This module provides infrastructure to verify the correctness of
+the AOT code generation flow.
+It builds a Relay module with the AOT executor, generates a C main()
+that calls the compiled network with reference inputs, compiles the
+result with aot_test.mk, runs the produced binary and checks that the
+outputs match reference values computed with the graph runtime, for
+single operator and full network test cases.
+""" +import tflite +import os +import io +import struct +import numpy as np +import pathlib +import shutil +import subprocess +import tempfile +import tarfile + + +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.relay.op.contrib import get_pattern_table +from tvm.contrib import utils +from tvm.relay.backend import compile_engine +from tvm.contrib import utils +from tvm.contrib import graph_runtime +from tvm.micro import export_model_library_format + + +def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout): + """ + This method runs a process and logs the output to both a log file and stdout + """ + with subprocess.Popen( + cmd, cwd=cwd, shell=True, bufsize=0, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) as proc, open(logfile, "a") as f: + while True: + data = proc.stdout.readline() + result = proc.poll() + # process is done if there is no data and the result is valid + if data == b"" and result is not None: + return int(result) + if data: + text = data.decode("ascii", errors="backslashreplace") + f.write(text) + if stdout: + print(text, end="") + + +def create_main(test_name, input_list, output_list, output_path): + file_path = pathlib.Path(f"{output_path}/" + test_name).resolve() + # create header file + raw_path = file_path.with_suffix(".c").resolve() + with open(raw_path, "w") as main_file: + main_file.write("#include \n") + main_file.write("#include \n") + main_file.write("#define WORKSPACE_SIZE (16384*1024)\n") + main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n") + + for i in range(0, len(input_list)): + main_file.write('#include "input_data%i.h"\n' % i) + for i in range(0, len(output_list)): + main_file.write('#include "expected_output_data%i.h"\n' % i) + main_file.write('#include "output_data%i.h"\n' % i) + + main_file.write("extern tvm_model_t network;\n") + main_file.write("extern tvm_workspace_t *tvm_runtime_workspace;\n") + main_file.write("int main(){\n") + main_file.write("void* inputs[%i] = { " % (len(input_list))) + + for i in range(0, len(input_list)): + main_file.write("input_data%i, " % i) + main_file.write("};\n") + + main_file.write("void* outputs[%i] = { " % (len(output_list))) + for i in range(0, len(output_list)): + main_file.write("output_data%i, " % i) + main_file.write("};\n") + + main_file.write("") + main_file.write( + "tvm_workspace_t app_workspace = {.next_alloc=g_aot_memory, .workspace=g_aot_memory, .workspace_size=WORKSPACE_SIZE};\n" + ) + main_file.write("tvm_runtime_workspace = &app_workspace;\n") + main_file.write("tvm_runtime_run(&network, inputs, outputs, NULL);") + + for i in range(0, len(output_list)): + main_file.write("for (int i = 0; i\n") + header_file.write("#include \n") + header_file.write("#include \n") + header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n") + + if npy_data.dtype == "int8": + header_file.write(f"int8_t {tensor_name}[] =") + elif npy_data.dtype == "int32": + header_file.write(f"int32_t {tensor_name}[] = ") + elif npy_data.dtype == "uint8": + header_file.write(f"uint8_t {tensor_name}[] = ") + elif npy_data.dtype == "float32": + header_file.write(f"float {tensor_name}[] = ") + + header_file.write("{") + for i in np.ndindex(npy_data.shape): + header_file.write(f"{npy_data[i]}, ") + header_file.write("};\n\n") + + +def verify_source(mod, input_list, output_list, params=None): + """ + This method verifies the generated source + """ + target = "c -runtime=c --link-params --executor=aot" + + with tvm.transform.PassContext(opt_level=3, 
config={"tir.disable_vectorize": True}): + lib = tvm.relay.build(mod, target, target_host=target, params=params) + + tmp_path = utils.tempdir() + tmp_dir = tmp_path.temp_dir + + base_path = os.path.join(tmp_dir, "test") + build_path = os.path.join(base_path, "build") + os.makedirs(build_path, exist_ok=True) + + tar_file = os.path.join(base_path, "test.tar") + export_model_library_format(lib, tar_file) + t = tarfile.open(tar_file) + t.extractall(base_path) + + for i in range(len(input_list)): + create_header_file((f"input_data{i}"), input_list[i], build_path) + + for i in range(len(output_list)): + create_header_file( + (f"output_data{i}"), + np.zeros(output_list[i].shape, output_list[i].dtype), + build_path, + ) + create_header_file((f"expected_output_data{i}"), output_list[i], build_path) + + create_main("test.c", input_list, output_list, build_path) + + # Verify that compiles fine + file_dir = os.path.dirname(os.path.abspath(__file__)) + makefile = os.path.join(file_dir, "aot_test.mk") + make_cmd = f"make -f {makefile} build_dir=" + build_path + f" TVM_ROOT={file_dir}/../../../.." + + compile_log_path = os.path.join(build_path, "test_compile.log") + ret = subprocess_with_stdout_and_log(make_cmd, ".", compile_log_path, False) + assert ret == 0 + + # Verify that runs fine + run_log_path = os.path.join(build_path, "test_run.log") + ret = subprocess_with_stdout_and_log("./aot_test_runner", build_path, run_log_path, False) + assert ret == 0 + + +def generate_ref_data(mod, input_data, params=None, target="llvm"): + """Generate reference data through executing the relay module""" + compile_engine.get().clear() + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) + + lib_name = "mod.so" + temp = utils.tempdir() + lib_path = temp.relpath(lib_name) + lib.export_library(lib_path) + lib = tvm.runtime.load_module(lib_path) + grt_mod = graph_runtime.GraphModule(lib["default"](tvm.cpu())) + grt_mod.set_input(**input_data) + grt_mod.run() + output_count = grt_mod.get_num_outputs() + out = [grt_mod.get_output(i).asnumpy() for i in range(output_count)] + return out diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py new file mode 100644 index 0000000000000..b6480e039c61d --- /dev/null +++ b/tests/python/relay/aot/test_crt_aot.py @@ -0,0 +1,258 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tflite +import os +import io +import struct +import numpy as np +import pathlib +import shutil +import subprocess +import tempfile +import tarfile + + +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.relay.op.contrib import get_pattern_table +from tvm.contrib import utils +from tvm.relay.backend import compile_engine +from tvm.contrib import utils +from tvm.contrib import graph_runtime +from tvm.micro import export_model_library_format +from tvm.relay import testing + +from infra import * + + +def test_conv_with_params(): + RELAY_MODEL = """ +#[version = "0.0.5"] +def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), int8]) { + %1 = nn.conv2d( + %data, + %weight, + padding=[2, 2], + channels=8, + kernel_size=[5, 5], + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32"); + %1 +} +""" + mod = tvm.parser.fromtext(RELAY_MODEL) + main_func = mod["main"] + shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} + type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params} + + weight_data = np.ones(shape_dict["weight"]).astype(type_dict["weight"]) + input_data = np.ones(shape_dict["data"]).astype(type_dict["data"]) + + params = {"weight": weight_data} + inputs = {"data": input_data} + output_list = generate_ref_data(mod, inputs, params) + + input_list = [input_data] + verify_source(mod, input_list, output_list, params) + + +def test_add_with_params(): + x = relay.var("x", shape=(1, 10)) + y = relay.var("y", shape=(1, 10)) + z = relay.add(x, y) + func = relay.Function([x, y], z) + + x_in = np.ones((1, 10)).astype("float32") + y_in = np.random.uniform(size=(1, 10)).astype("float32") + + params = {"x": x_in} + inputs = {"y": y_in} + output_list = generate_ref_data(func, inputs, params) + + input_list = [y_in] + verify_source(func, input_list, output_list, params) + + +def test_conv2d(): + """Test a subgraph with a single conv2d operator.""" + + def conv2d_direct(): + dtype = "float32" + ishape = (1, 32, 14, 14) + w1shape = (32, 32, 3, 3) + + data0 = relay.var("data", shape=ishape, dtype=dtype) + weight0 = relay.var("weight", shape=w1shape, dtype=dtype) + out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1)) + main_f = relay.Function([data0, weight0], out) + mod = tvm.IRModule() + mod["main"] = main_f + mod = transform.InferType()(mod) + + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) + + return mod, {"data": i_data, "weight": w1_data}, (1, 32, 14, 14) + + def group_conv2d(): + dtype = "float32" + ishape = (1, 32, 14, 14) + w2shape = (32, 1, 3, 3) + + data0 = relay.var("data", shape=(ishape), dtype=dtype) + weight0 = relay.var("weight", shape=(w2shape), dtype=dtype) + out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=32) + main_f = relay.Function([data0, weight0], out) + mod = tvm.IRModule() + mod["main"] = main_f + mod = transform.InferType()(mod) + + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + w_data = np.random.uniform(0, 1, w2shape).astype(dtype) + + return mod, {"data": i_data, "weight": w_data}, (1, 32, 14, 14) + + for mod, inputs, out_shape in [conv2d_direct(), group_conv2d()]: + output_list = generate_ref_data(mod, inputs) + input_list = [inputs["data"], inputs["weight"]] + verify_source(mod, input_list, output_list) + + +def test_concatenate(): + dtype = "float32" + x = relay.var("x", shape=(10, 5), dtype=dtype) + y = relay.var("y", shape=(10, 
5), dtype=dtype) + t = relay.var("z", shape=(), dtype=dtype) + z = relay.concatenate((x, y), axis=1) + z = relay.add(z, t) + # Check result. + func = relay.Function([x, y, t], z) + x_data = np.random.rand(10, 5).astype(dtype) + y_data = np.random.rand(10, 5).astype(dtype) + t_data = np.random.uniform(size=()).astype(dtype) + inputs = {"x": x_data, "y": y_data, "z": t_data} + + output_list = generate_ref_data(func, inputs) + input_list = [inputs["x"], inputs["y"], inputs["z"]] + verify_source(func, input_list, output_list) + + +def test_nested_tuples(): + x = relay.var("x", shape=(10,)) + x1 = x + relay.const(1.0) + x2 = x1 + relay.const(1.0) + x3 = x2 + relay.const(1.0) + x4 = x3 + relay.const(1.0) + out = relay.Tuple([x1, relay.Tuple([relay.Tuple([x2, x3]), x4])]) + func = relay.Function([x], out) + + x_data = np.random.uniform(size=(10,)).astype(np.float32) + inputs = {"x": x_data} + output_list = generate_ref_data(func, inputs) + input_list = [x_data] + verify_source(func, input_list, output_list) + + +def test_tuple_getitem(): + func = relay.Function([], relay.TupleGetItem(relay.Tuple([relay.const(1), relay.const(2)]), 0)) + output_list = generate_ref_data(func, {}) + input_list = [] + verify_source(func, input_list, output_list) + + +def test_id(): + x = relay.var("x", "float32") + ident = relay.Function([x], x) + one = np.array(1.0, "float32") + inputs = {"x": one} + output_list = generate_ref_data(ident, inputs) + input_list = [one] + verify_source(ident, input_list, output_list) + + +def test_add_const(): + two = relay.add(relay.const(1), relay.const(1)) + func = relay.Function([], two) + output_list = generate_ref_data(func, {}) + input_list = [] + verify_source(func, input_list, output_list) + + +def test_mul_param(): + x = relay.var("x", shape=(10, 10)) + y = relay.var("y", shape=(1, 10)) + func = relay.Function([x, y], relay.multiply(x, y)) + x_data = np.random.rand(10, 10).astype("float32") + y_data = np.random.rand(1, 10).astype("float32") + inputs = {"x": x_data, "y": y_data} + output_list = generate_ref_data(func, inputs) + input_list = [inputs["x"], inputs["y"]] + verify_source(func, input_list, output_list) + + +def test_subtract(): + i = relay.var("i", shape=[], dtype="int32") + sub = relay.subtract(i, relay.const(1, dtype="int32")) + func = relay.Function([i], sub, ret_type=relay.TensorType([], "int32")) + i_data = np.array(1, dtype="int32") + inputs = {"i": i_data} + output_list = generate_ref_data(func, inputs) + input_list = [inputs["i"]] + verify_source(func, input_list, output_list) + + +def test_tuple_output(): + x = relay.var("x", shape=(6, 9)) + y = relay.split(x, 3).astuple() + a = relay.TupleGetItem(y, 0) + b = relay.TupleGetItem(y, 1) + c = relay.TupleGetItem(y, 2) + out = relay.Tuple([a, b]) + func = relay.Function([x], out) + x_data = np.random.rand(6, 9).astype("float32") + inputs = {"x": x_data} + output_list = generate_ref_data(func, inputs) + input_list = [inputs["x"]] + verify_source(func, input_list, output_list) + + +def test_mobilenet(): + mod, params = testing.mobilenet.get_workload(batch_size=1) + data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] + data = np.random.uniform(size=data_shape).astype("float32") + inputs = {"data": data} + output_list = generate_ref_data(mod, inputs, params) + input_list = [inputs["data"]] + verify_source(mod, input_list, output_list, params) + + +if __name__ == "__main__": + test_tuple_output() + test_mobilenet() + test_subtract() + test_mul_param() + test_id() + test_add_const() + 
test_tuple_getitem() + test_nested_tuples() + test_concatenate() + test_conv_with_params() + test_add_with_params() + test_conv2d() diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index 68708aaeb413f..b14b9a2932fcf 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -133,7 +133,7 @@ def test_plan_memory(): storage_ids = set() device_types = set() for k, v in smap.items(): - assert len(v) == 2 + assert len(v) == 3 for x in v[0]: storage_ids.add(x.value) for x in v[1]: diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index ff68d489c7c54..c5dcd6edab9b7 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -266,7 +266,7 @@ def check_storage_and_device_types(): storage_ids = [] device_types = [] for _, storage_dev_type in smap.items(): - assert len(storage_dev_type) == 2 + assert len(storage_dev_type) == 3 for sid in storage_dev_type[0]: storage_ids.append(sid.value) for did in storage_dev_type[1]: diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 1bd24c931b723..f6bc6f21d3fc5 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -157,7 +157,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) { with _make_session(workspace, factory.get_lib()) as sess: graph_mod = tvm.micro.create_local_graph_runtime( - factory.get_json(), sess.get_system_lib(), sess.context + factory.get_graph(), sess.get_system_lib(), sess.context ) A_data = tvm.nd.array(np.array([2, 3], dtype="uint8"), ctx=sess.context) assert (A_data.asnumpy() == np.array([2, 3])).all() diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index ffe859927ad7a..0955048d2e681 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -199,7 +199,7 @@ def test_llvm_link_params(): assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded assert mod.get_function("TVMSystemLibEntryPoint") != None - graph = json.loads(lib.graph_json) + graph = json.loads(lib.graph) for p in lib.params: _verify_linked_param(dtype, lib, mod, graph, p) or found_one @@ -310,7 +310,7 @@ def test_c_link_params(): lib_mod = tvm.runtime.load_module(lib_path) # lib_mod = lib_factory['default']() - graph = json.loads(lib.graph_json) + graph = json.loads(lib.graph) for p in lib.params: _verify_linked_param(dtype, lib, lib_mod, graph, p) diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index c999091cc3cce..1067acc3fadc2 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -35,7 +35,7 @@ def validate_graph_json(extract_dir, factory): with open(os.path.join(extract_dir, "runtime-config", "graph", "graph.json")) as graph_f: graph_json = graph_f.read() - assert graph_json == factory.graph_json + assert graph_json == factory.graph # Just check it parses and looks roughly right. 
graph = json.loads(graph_json) diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 930011d4fd333..1f9754650f55e 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -530,7 +530,7 @@ def test_debug_graph_runtime(): debug_g_mod = debug_runtime.GraphModuleDebug( complied_graph_lib["debug_create"]("default", ctx), [ctx], - complied_graph_lib.get_json(), + complied_graph_lib.get_graph(), None, ) debug_g_mod.set_input("data", data)
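
For readers trying out the new runtime interface, here is a minimal host-side sketch of how an application might drive the AOT API added in src/runtime/crt/aot/tvm_executor.c. It is not part of the patch: the `network` symbol is assumed to be the descriptor emitted by GenerateAOTDescriptor, and the buffer names, sizes and dtypes are placeholders that must match the compiled model. The generated test harness in tests/python/relay/aot/infra.py produces an equivalent main().

/* Hypothetical usage sketch, mirroring the main() generated by infra.py. */
#include <stdint.h>

#include "tvm_backend.h"  /* tvm_workspace_t, TVMBackendAllocWorkspace */
#include "tvm_executor.h" /* tvm_model_t, tvm_runtime_run */

#define WORKSPACE_SIZE (16384 * 1024)
static uint8_t g_aot_memory[WORKSPACE_SIZE];

extern tvm_model_t network;              /* emitted by GenerateAOTDescriptor */
extern tvm_workspace_t* tvm_runtime_workspace;

int main(void) {
  /* Point TVMBackendAllocWorkspace at a statically allocated arena. */
  tvm_workspace_t app_workspace = {
      .next_alloc = g_aot_memory,
      .workspace = g_aot_memory,
      .workspace_size = WORKSPACE_SIZE,
  };
  tvm_runtime_workspace = &app_workspace;

  /* One buffer per input/output tensor, in the order the model expects;
   * the shapes and dtypes used here are illustrative only. */
  static float input_data0[10];
  static float output_data0[10];
  void* inputs[] = {input_data0};
  void* outputs[] = {output_data0};

  /* Invokes the generated entry point; returns a tvm_crt_error_t. */
  if (tvm_runtime_run(&network, inputs, outputs, NULL) != kTvmErrorNoError) {
    return -1;
  }
  return 0;
}

The same flow is exercised on x86 by tests/python/relay/aot/test_crt_aot.py, which compiles with the "c -runtime=c --executor=aot" target and links the result against tvm_executor.c via aot_test.mk.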