diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 7922e978c381..7a6cfa364447 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -287,6 +287,11 @@ TVM_DLL Pass LowerThreadAllreduce(); */ TVM_DLL Pass InferFragment(); +/*! + * \brief This annotation is for nodes to be disabled for builtin lowering + */ +static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin"; + /*! * \brief Lower builtin intrinsics. * \return The pass. diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index e51f1702773b..3b412cb646ca 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Codegen for Arm(R) Ethos(TM)-U""" +"""Codegen for Arm(R) Ethos(TM)-U NPU""" + import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir @@ -133,24 +134,6 @@ def transform_function( return OptimizeLUTs().visit(func) -@tvm._ffi.register_func("relay.ext.ethos-u") -def ethosu_compiler(external_function): - """The entry-point to a compile a external relay function of - NPU compatible operators to generated command stream. - Such generated command stream would be used to create c-source r - runtime module that interfaces with NPU driver. 
- """ - assert isinstance(external_function, tvm.ir.function.BaseFunc) - func_name = external_function.attrs["global_symbol"] - # There should only be a single input - assert len(external_function.params) == 1 - input_size = util.calculate_size_bytes(external_function.params[0]) - output_size = util.calculate_size_bytes(external_function.body) - cmms, encoded_constants, scratch_size = _compile(external_function) - ethosu_runtime = tvm._ffi.get_global_func("runtime.module.ethos-u.create") - return ethosu_runtime(func_name, cmms, encoded_constants, scratch_size, input_size, output_size) - - @tvm._ffi.register_func("relay.ext.ethos-u.constant_updater") def constant_updater(expr, symbol): # pylint: disable=unused-argument """ @@ -161,25 +144,25 @@ def constant_updater(expr, symbol): # pylint: disable=unused-argument return dict() -def _compile(ext_func): +@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir_func") +def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: """ - This is the main wrapper that accepts an external - relay function and runs all the passes to lower it down - to command stream + This is the hook for python-based lowering of relay function + that gets offloaded to the microNPU. + Parameters ---------- - ext_func : tvm.relay.function.Function - The partitioned relay function + ext_func : relay.Function + This is the partitioned relay function + Returns ------- - cs : str - An hex string of the bytes of command stream - encoded_constants : str - An hex string of the bytes that includes concat'd - encoded weights, encoded biases and scales. - scratch_size : int - The size of the scratch buffer needed. 
+ primfunc : tir.PrimFunc + This returns the scheduled PrimFunc """ + assert len(ext_func.params) == 1 + input_size = util.calculate_size_bytes(ext_func.params[0]) + output_size = util.calculate_size_bytes(ext_func.body) mod = tvm.IRModule() mod["main"] = ext_func mod = LegalizeEthosU()(mod) @@ -189,6 +172,51 @@ def _compile(ext_func): # this should be a single intelligent and a composite scheduler # that can perform scheduling based on user inputs such as # scratch memory size. - tir_mod, params = lower_to_tir(mod["main"], copy_constants()) - cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate(tir_mod, params) - return cmms, encoded_constants, scratch_size + tir_mod, const_dict = lower_to_tir(mod["main"], copy_constants()) + + for idx in const_dict.keys(): + const_dict[idx] = tvm.nd.array(const_dict[idx]) + + primfunc = tir_mod["main"] + primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"]) + primfunc = primfunc.with_attr("ethos-u.constants", const_dict) + primfunc = primfunc.with_attr("ethos-u.input_size", input_size) + primfunc = primfunc.with_attr("ethos-u.output_size", output_size) + return primfunc + + +@tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact") +def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact: + """ + This is the hook for python-based lowering of TIR PrimFunc + that has undergone unified optimization to compilation + artifact destined for the microNPU. 
+ + Parameters + ---------- + primfunc : tir.PrimFunc + TIR PrimFunc that has undergone unified optimizations + + Returns + ------- + CompilationArtifact + This is a structure that holds the binary artifacts + for the microNPU + """ + symbol = str(primfunc.attrs["global_symbol"]) + const_dict = primfunc.attrs["ethos-u.constants"] + input_size = primfunc.attrs["ethos-u.input_size"] + output_size = primfunc.attrs["ethos-u.output_size"] + tir_mod = tvm.IRModule() + tir_mod[symbol] = primfunc + + const_dict_with_int_keys = dict() + for idx in const_dict.keys(): + const_dict_with_int_keys[int(idx)] = const_dict[idx].numpy() + + cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate( + tir_mod, const_dict_with_int_keys + ) + return util.CompilationArtifact( + cmms, encoded_constants, scratch_size, input_size, output_size, symbol + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index b68a5ad14a6f..b3ffecb2ec22 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -21,7 +21,7 @@ from tvm.relay.expr_functor import ExprMutator from tvm.driver.build_module import schedule_to_module -from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants +from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants, AnnotateAllocates from .scheduler import schedule @@ -88,6 +88,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod, const_dict = EncodeConstants(const_dict)(mod) mod = tvm.tir.transform.StorageRewrite()(mod) mod = tvm.tir.transform.RemoveNoOp()(mod) + mod = AnnotateAllocates()(mod) return mod, const_dict diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index cb46ba319edd..41a6832c5953 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ 
b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -488,3 +488,34 @@ def _encode_constants(mod): return new_func, new_const_dict return _encode_constants + + +# This needs to be kept in sync with kDisableLowerTVMBuiltin in include/tvm/tir/transform.h +DISABLE_LOWER_BUILTIN = "disable_lower_builtin" + + +def AnnotateAllocates(): + """ + This is a pass to annotate all allocate + nodes of the PrimFuncs of the microNPU + so that they are not lowered to built-ins. + """ + + def _post_transform(allocate): + return tvm.tir.Allocate( + buffer_var=allocate.buffer_var, + dtype=allocate.dtype, + extents=allocate.extents, + condition=allocate.condition, + body=allocate.body, + annotations={DISABLE_LOWER_BUILTIN: True}, + ) + + def _ftransform(f, mod, ctx): + return f.with_body( + tvm.tir.stmt_functor.ir_transform(f.body, None, _post_transform, ["tir.Allocate"]) + ) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.annotate_allocates" + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index e1af7f1534e2..c8e3d34d1c29 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -173,16 +173,16 @@ def extract_buffer_info( primfunc = mod.functions.items()[0][1] for idx, const_data in param_dict.items(): param = primfunc.params[idx] - buffer_info[primfunc.buffer_map[param].data] = BufferInfo( + buffer_info[param] = BufferInfo( const_data, const_data.shape, const_data.dtype, BufferType.constant ) for param in primfunc.params: - if primfunc.buffer_map[param].data not in buffer_info.keys(): - buffer_info[primfunc.buffer_map[param].data] = BufferInfo( + if param not in buffer_info.keys(): + buffer_info[param] = BufferInfo( + None, + None, None, - primfunc.buffer_map[param].shape, - primfunc.buffer_map[param].dtype, BufferType.input_or_output, ) @@ -253,7 +253,7 @@ def 
replace_npu_fm_with_address(npu_fm): def replace_npu_address_range_with_address(npu_addr_range): assert isinstance(npu_addr_range.address, tvm.tir.Load) buffer = npu_addr_range.address.buffer_var - assert buffer in buffer_addresses.keys() + assert buffer in buffer_addresses.keys(), f"searching for buffer : {buffer}, but not found" address, buffer_type = buffer_addresses[buffer] return vapi.NpuAddressRange(_REGION_MAP[buffer_type], address, npu_addr_range.length) @@ -299,11 +299,6 @@ def classify_io(buffer): size_in_bytes = util.round_up(size_in_bytes, 16) constant_tensor = np.append(constant_tensor, np.resize(info.values, size_in_bytes)) else: - size_in_bytes = int( - (np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape)) - ) - # Every memory address the NPU access have to be 16 byte aligned - size_in_bytes = util.round_up(size_in_bytes, 16) if info.btype == BufferType.input_or_output: buffer_type = classify_io(_buffer) assert buffer_type in (BufferType.input, BufferType.output) @@ -315,6 +310,11 @@ def classify_io(buffer): address = arch_config.lut_start_address buffer_addresses[_buffer] = (address, info.btype) else: + size_in_bytes = int( + (np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape)) + ) + # Every memory address the NPU access have to be 16 byte aligned + size_in_bytes = util.round_up(size_in_bytes, 16) assert info.btype == BufferType.scratch address = scratch_size scratch_size += size_in_bytes diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 589ab21b3998..45a82d5932d6 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -28,6 +28,9 @@ import tvm # type: ignore from tvm import relay +from tvm._ffi import register_object +from tvm.runtime import Object +from . 
import _ffi_api class QConv2DArgs(Enum): @@ -209,3 +212,30 @@ def calculate_size_bytes(expr): element_size = type_info.bits // 8 elements = np.prod(list(expr.checked_type.shape)) return element_size * elements + + +@register_object("relay.ext.ethos-u.CompilationArtifact") +class CompilationArtifact(Object): + """ + This is a structure to hold binary artifacts + for the microNPU. + """ + + def __init__( + self, + command_stream: str, + encoded_constants: str, + scratch_size: int, + input_size: int, + output_size: int, + function_name: str, + ): + self.__init_handle_by_constructor__( + _ffi_api.CompilationArtifact, # type: ignore # pylint: disable=no-member + command_stream, + encoded_constants, + scratch_size, + input_size, + output_size, + function_name, + ) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 974523d1eb1a..7a55d3ef244e 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -68,9 +68,13 @@ def convert_to_object(value, span=None): if isinstance(value, dict): vlist = [] for item in value.items(): - if not isinstance(item[0], ObjectTypes) and not isinstance(item[0], string_types): + if ( + not isinstance(item[0], ObjectTypes) + and not isinstance(item[0], string_types) + and not isinstance(item[0], Number) + ): raise ValueError("key of map must already been a container type") - vlist.append(item[0]) + vlist.append(convert_to_object(item[0])) vlist.append(convert_to_object(item[1])) return _ffi_api.Map(*vlist) if isinstance(value, ObjectGeneric): diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc new file mode 100644 index 000000000000..d618a4971189 --- /dev/null +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/contrib/ethosu/codegen.cc + * + * \brief This file contains the target hooks for Arm(R) Ethos(TM)-U NPU + * Codegen. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../../op/make_op.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace ethosu { + +/*! 
+ * \brief This mutator lowers each external + * relay function to a TIR PrimFunc + */ +class RelayToTIRMutator : public MixedModeMutator { + public: + explicit RelayToTIRMutator(IRModule ir_module) : ir_module_(ir_module) {} + + IRModule operator()() { + GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); + Function main_func = Downcast(ir_module_->Lookup(main_global_var)); + + // Copy everything across and mutate the body + Function mutated_main = + Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, + main_func->type_params, main_func->attrs, main_func->span); + + ir_module_->Update(main_global_var, mutated_main); + ir_module_ = WithAttr(ir_module_, "device_contexts", device_contexts_); + return ir_module_; + } + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + Call call = Downcast(post); + if (call->op->IsInstance()) { + Function func = Downcast(call->op); + auto codegen_name = func->GetAttr(attr::kCompiler); + if (codegen_name.defined() && codegen_name == "ethos-u") { + auto relay_to_tir_func_pf = + tvm::runtime::Registry::Get("relay.ext.ethos-u.relay_to_tir_func"); + ICHECK(relay_to_tir_func_pf); + tir::PrimFunc prim_func = (*relay_to_tir_func_pf)(func); + prim_func = WithAttr(prim_func, tvm::attr::kTarget, Target("ethos-u")); + String symbol_name = prim_func->GetAttr(tvm::attr::kGlobalSymbol).value(); + GlobalVar gv(symbol_name); + Array args = call->args; + gv->checked_type_ = func->checked_type(); + ir_module_->Update(gv, prim_func); + device_contexts_.Set(gv, codegen_name.value()); + return Call(gv, args, call->attrs, call->type_args); + } + } + return post; + } + + private: + IRModule ir_module_; + Map device_contexts_; +}; + +tvm::transform::Pass RelayToTIR() { + runtime::TypedPackedFunc pass_func = + [=](IRModule ir_module, transform::PassContext pass_context) { + return RelayToTIRMutator(ir_module)(); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "relay.contrib.ethos-u.RelayToTIR", 
{}); +} + +/*! + * \brief This function lowers the IRModule with PrimFunc + * with the target of the microNPU to a C-source runtime module + */ +runtime::Module TIRToRuntime(IRModule mod, Target target) { + Array compile_artifacts; + for (const auto& kv : mod->functions) { + const tir::PrimFunc& prim_func = Downcast(kv.second); + Optional> params = + prim_func->GetAttr>("ethos-u.constants"); + ICHECK(params) << "microNPU params should be present"; + auto primfunc_to_artifact_pf = + tvm::runtime::Registry::Get("relay.ext.ethos-u.primfunc_to_artifact"); + ICHECK(primfunc_to_artifact_pf); + CompilationArtifact ca = (*primfunc_to_artifact_pf)(prim_func); + compile_artifacts.push_back(ca); + } + auto ca_to_runtime = tvm::runtime::Registry::Get("runtime.module.ethos-u.create"); + return (*ca_to_runtime)(compile_artifacts); +} + +TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU) + .set_attr("use_device_api", Bool(true)) + .set_attr("RelayToTIR", RelayToTIR()) + .set_attr("TIRToRuntime", TIRToRuntime); + +} // namespace ethosu +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc index b7b359ab4735..f56544aee99a 100644 --- a/src/relay/backend/contrib/ethosu/source_module.cc +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -41,34 +41,35 @@ #include #include "../../../../runtime/file_utils.h" +#include "utils.h" namespace tvm { namespace runtime { +using CompilationArtifact = relay::contrib::ethosu::CompilationArtifact; + // The runtime.Module that contains the host-side c code // required for invoking the NPU with the command stream class EthosUModuleNode : public ModuleNode { public: /*! - * \brief The ethos runtime module. + * \brief The microNPU runtime module. 
* - * \param func_name_ name of the should be codegen'd function - * \param cmms_hex_ command stream for the NPU in hex - * \param weights_bias_hex_ the encoded biases and weights for the NPU in hex - * \param scratch_size_ the size of the scratch memory required for command stream - * \param input_size_ the size (in bytes) for the input tensor - * \param output_size_ the size (in bytes) for the output tensor + * \param compilation_artifacts + * This is an array of CompilationArtifacts that is produced via + * lowering each PrimFunc to command stream. Here, those artifacts + * will be used to create the c-source. */ - explicit EthosUModuleNode(const String& func_name_, const String& cmms_hex_, - const String& weights_bias_hex_, const Integer& scratch_size_, - const Integer& input_size_, const Integer& output_size_) { - func_name = func_name_; - cmms_hex = std::move(cmms_hex_); - weights_bias_hex = std::move(weights_bias_hex_); - scratch_size = scratch_size_->value; - input_size = input_size_->value; - output_size = output_size_->value; - c_source = GenerateSource(); + explicit EthosUModuleNode(Array compilation_artifacts) + : compilation_artifacts_(compilation_artifacts) { + c_source += "#include \n"; + c_source += "#include \n"; + c_source += "#include \n"; + c_source += "#include \n\n"; + for (const CompilationArtifact& compilation_artifact : compilation_artifacts) { + c_source += GenerateSource(compilation_artifact); + c_source += "\n\n"; + } } /*! 
@@ -79,7 +80,6 @@ class EthosUModuleNode : public ModuleNode { */ void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); - LOG(INFO) << "format=" << fmt << ";;\n"; ICHECK_EQ(fmt, "c") << "Can only save to format=" << "c"; std::ofstream out(file_name); @@ -89,7 +89,7 @@ class EthosUModuleNode : public ModuleNode { std::string GetSource(const std::string& format) final { return c_source; } - std::string GetCS() { return cmms_hex; } + Array GetArtifacts() { return compilation_artifacts_; } /*! * \brief Get a PackedFunc from the module. @@ -102,7 +102,11 @@ class EthosUModuleNode : public ModuleNode { PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "get_func_names") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = Array{this->func_name}; + Array func_names; + for (const CompilationArtifact& ca : compilation_artifacts_) { + func_names.push_back(ca->function_name); + } + *rv = func_names; }); } return PackedFunc(); @@ -110,21 +114,14 @@ class EthosUModuleNode : public ModuleNode { const char* type_key() const override { return "c"; } - static Module Create(String func_name, String cmms_hex, String weights_bias_hex, - Integer scratch_size, Integer input_size, Integer output_size) { - auto n = make_object(func_name, cmms_hex, weights_bias_hex, scratch_size, - input_size, output_size); + static Module Create(Array compilation_artifacts) { + auto n = make_object(compilation_artifacts); return Module(n); } private: - String c_source; - String func_name; - String cmms_hex; - String weights_bias_hex; - size_t scratch_size; - size_t input_size; - size_t output_size; + std::string c_source; + Array compilation_artifacts_; int indent_{0}; /*! 
@@ -151,10 +148,10 @@ class EthosUModuleNode : public ModuleNode { * \return string of code that updates the base_addrs array with the base address of the given * array */ - std::string SetBaseAddress(int index, std::string name) { + std::string SetBaseAddress(int index, std::string name, std::string size) { std::stringstream ss; ss << " base_addrs[" << index << "] = (uintptr_t)(" << name << ");\n"; - ss << " base_addrs_size[" << index << "] = " << name << "_size;\n"; + ss << " base_addrs_size[" << index << "] = " << size << ";\n"; return ss.str(); } @@ -211,43 +208,39 @@ class EthosUModuleNode : public ModuleNode { * * \return string of code that offloads a subgraph to the NPU */ - std::string GenerateSource() { - std::string func_no_dashes = func_name; + std::string GenerateSource(relay::contrib::ethosu::CompilationArtifact compilation_artifact) { + std::string func_no_dashes = compilation_artifact->function_name; std::replace(func_no_dashes.begin(), func_no_dashes.end(), '-', '_'); std::stringstream ss; - ss << "#include \n"; - ss << "#include \n"; - ss << "#include \n"; - ss << "#include \n"; - ss << "\n"; - size_t weights_size = (weights_bias_hex.size() / 2); - ss << "static const size_t weights_size = " << std::to_string(weights_size) << ";\n"; - ss << "static const size_t scratch_size = " << std::to_string(scratch_size) << ";\n"; + size_t weights_size = (compilation_artifact->encoded_constants.size() / 2); + size_t scratch_size = compilation_artifact->scratch_size; ss << "// Update linker script to place .rodata.tvm in memory that can be accessed by the " "NPU\n"; if (weights_size > 0) { - ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t weights[" - << weights_size << "] = \""; - ss << GetHexString(weights_bias_hex); + ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t " + << func_no_dashes << "_weights[" << weights_size << "] = \""; + ss << GetHexString(compilation_artifact->encoded_constants); ss << 
"\";\n"; } else { - ss << "static int8_t* weights = NULL;\n"; + ss << "static int8_t* " << func_no_dashes << "_weights = NULL;\n"; } - ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t cms_data_data[" - << cmms_hex.size() / 2 << "] = \""; - ss << GetHexString(cmms_hex); + ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t " << func_no_dashes + << "_cms_data_data[" << compilation_artifact->command_stream.size() / 2 << "] = \""; + ss << GetHexString(compilation_artifact->command_stream); ss << "\";\n"; - ss << "static const size_t cms_data_size = sizeof(cms_data_data);\n"; ss << "\n"; PrintExternCPrefix(ss); ss << "static int32_t " << func_no_dashes + "_(int8_t* in0, " << "size_t in0_size, int8_t* out0, size_t out0_size, void* resource_handle) {\n"; ss << " int num_tensors = 5;\n"; - ss << " void* cms_data = (void*)(cms_data_data);\n"; + ss << " void* cms_data = (void*)(" << func_no_dashes << "_cms_data_data);\n"; ss << " int64_t device_type = kDLCPU;\n"; ss << " int64_t device_id = 0;\n"; + ss << " const size_t weights_size = " << std::to_string(weights_size) << ";\n"; + ss << " const size_t scratch_size = " << std::to_string(scratch_size) << ";\n"; + ss << " const size_t cms_data_size = sizeof(" << func_no_dashes << "_cms_data_data);\n"; if (scratch_size > 0) { ss << " int8_t* scratch = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, " "(uint64_t)scratch_size, 0, 16);\n"; @@ -257,11 +250,11 @@ class EthosUModuleNode : public ModuleNode { ss << " size_t base_addrs_size[num_tensors];\n"; ss << " uint64_t base_addrs[num_tensors];\n"; ss << "\n"; - ss << SetBaseAddress(0, "weights"); - ss << SetBaseAddress(1, "scratch"); - ss << SetBaseAddress(2, "scratch"); - ss << SetBaseAddress(3, "in0"); - ss << SetBaseAddress(4, "out0"); + ss << SetBaseAddress(0, func_no_dashes + "_weights", "weights_size"); + ss << SetBaseAddress(1, "scratch", "scratch_size"); + ss << SetBaseAddress(2, "scratch", "scratch_size"); + ss 
<< SetBaseAddress(3, "in0", "in0_size"); + ss << SetBaseAddress(4, "out0", "out0_size"); ss << "\n"; ss << " int32_t result = TVMEthosULaunch(resource_handle, cms_data, cms_data_size, " "base_addrs, base_addrs_size, num_tensors);\n"; @@ -277,8 +270,8 @@ class EthosUModuleNode : public ModuleNode { ss << "// Wrapper function is provided to allow for easier debugging\n"; ss << "inline static int32_t " + func_no_dashes + "_wrapper_(void* input, void* output, void* resource_handle) {\n"; - ss << " size_t input_data_size = " << input_size << ";\n"; - ss << " size_t output_data_size = " << output_size << ";\n"; + ss << " size_t input_data_size = " << compilation_artifact->input_size << ";\n"; + ss << " size_t output_data_size = " << compilation_artifact->output_size << ";\n"; ss << " return " + func_no_dashes + "_((int8_t*)input, input_data_size, (int8_t*)output, output_data_size, " + "resource_handle);\n"; @@ -286,7 +279,7 @@ class EthosUModuleNode : public ModuleNode { PrintExternCPostfix(ss); ss << "\n"; PrintExternCPrefix(ss); - PrintRuntimeFunctionHeader(ss, func_name); + PrintRuntimeFunctionHeader(ss, func_no_dashes); EnterScope(); PrintIndents(ss); ss << "return " << func_no_dashes << "_wrapper_(input, output, resource_handle);\n"; @@ -313,14 +306,12 @@ inline EthosUModuleNode* EthosUModule::operator->() { } TVM_REGISTER_GLOBAL("runtime.module.ethos-u.create") - .set_body_typed([](String func_name, String cmms_hex, String weights_bias_hex, - Integer scratch_size, Integer input_size, Integer output_size) { - return EthosUModuleNode::Create(func_name, cmms_hex, weights_bias_hex, scratch_size, - input_size, output_size); + .set_body_typed([](Array compilation_artifacts) { + return EthosUModuleNode::Create(compilation_artifacts); }); -TVM_REGISTER_GLOBAL("runtime.module.ethos-u.getcs").set_body_typed([](EthosUModule mod) { - return mod->GetCS(); +TVM_REGISTER_GLOBAL("runtime.module.ethos-u.get_artifacts").set_body_typed([](EthosUModule mod) { + return 
mod->GetArtifacts(); }); } // namespace runtime diff --git a/src/relay/backend/contrib/ethosu/utils.cc b/src/relay/backend/contrib/ethosu/utils.cc new file mode 100644 index 000000000000..7e6c1c2ac840 --- /dev/null +++ b/src/relay/backend/contrib/ethosu/utils.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file relay/backend/contrib/ethosu/utils.cc + * \brief Utilities for microNPU codegen + */ + +#include "utils.h" + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relay { +namespace contrib { +namespace ethosu { + +CompilationArtifact::CompilationArtifact(String command_stream, String encoded_constants, + Integer scratch_size, Integer input_size, + Integer output_size, String function_name) { + auto compilation_artifact_node = make_object(); + compilation_artifact_node->command_stream = command_stream; + compilation_artifact_node->encoded_constants = encoded_constants; + compilation_artifact_node->scratch_size = scratch_size; + compilation_artifact_node->input_size = input_size; + compilation_artifact_node->output_size = output_size; + compilation_artifact_node->function_name = function_name; + data_ = std::move(compilation_artifact_node); +} + +TVM_REGISTER_NODE_TYPE(CompilationArtifactNode); +TVM_REGISTER_GLOBAL("relay.ext.ethos-u.CompilationArtifact") + .set_body_typed([](String command_stream, String encoded_constants, Integer scratch_size, + Integer input_size, Integer output_size, String function_name) { + return CompilationArtifact(command_stream, encoded_constants, scratch_size, input_size, + output_size, function_name); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "CompilationArtifactNode(\n" + << "command_stream=" << node->command_stream + << ",\n encoded_constants=" << node->encoded_constants + << ",\n scratch_size=" << node->scratch_size + << ",\n input_size=" << node->input_size + << ",\n output_size=" << node->output_size + << ",\n function_name=" << node->function_name << ")"; + }); + +} // namespace ethosu +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/ethosu/utils.h b/src/relay/backend/contrib/ethosu/utils.h new file mode 100644 index 
000000000000..5e9e337c3f17 --- /dev/null +++ b/src/relay/backend/contrib/ethosu/utils.h @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/contrib/ethosu/utils.h + * \brief Utilities for microNPU codegen + */ + +#ifndef TVM_RELAY_BACKEND_CONTRIB_ETHOSU_UTILS_H_ +#define TVM_RELAY_BACKEND_CONTRIB_ETHOSU_UTILS_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace contrib { +namespace ethosu { + +/*! + * \brief Captures all the binary artifacts required to create + * the C-source runtime module + */ +struct CompilationArtifactNode : public Object { + /*! \brief The binary command stream (CS) in hex format */ + String command_stream; + /*! \brief The encoded biases and weights in hex format */ + String encoded_constants; + /*! \brief The size of the intermediary scratch area required for the execution of the CS */ + Integer scratch_size; + /*! \brief The size of the input tensor in bytes */ + Integer input_size; + /*! \brief The size of the output tensor in bytes */ + Integer output_size; + /*! 
\brief The name of the function */ + String function_name; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("command_stream", &command_stream); + v->Visit("encoded_constants", &encoded_constants); + v->Visit("scratch_size", &scratch_size); + v->Visit("input_size", &input_size); + v->Visit("output_size", &output_size); + v->Visit("function_name", &function_name); + } + + bool SEqualReduce(const CompilationArtifactNode* other, SEqualReducer equal) const { + return equal(command_stream, other->command_stream) && + equal(encoded_constants, other->encoded_constants) && + equal(scratch_size, other->scratch_size) && equal(input_size, other->input_size) && + equal(output_size, other->output_size) && equal(function_name, other->function_name); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(command_stream); + hash_reduce(encoded_constants); + hash_reduce(scratch_size); + hash_reduce(input_size); + hash_reduce(output_size); + hash_reduce(function_name); + } + + static constexpr const char* _type_key = "relay.ext.ethos-u.CompilationArtifact"; + TVM_DECLARE_FINAL_OBJECT_INFO(CompilationArtifactNode, Object); +}; + +class CompilationArtifact : public ObjectRef { + public: + TVM_DLL CompilationArtifact(String command_stream, String encoded_constants, Integer scratch_size, + Integer input_size, Integer output_size, String function_name); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CompilationArtifact, ObjectRef, CompilationArtifactNode); +}; + +} // namespace ethosu +} // namespace contrib +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_CONTRIB_ETHOSU_UTILS_H_ diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index 75b161ad4499..6f1914eac4c3 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -33,6 +33,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target); } // 
namespace relay TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) + .set_attr("use_device_api", Bool(true)) .set_attr("RelayToTIR", relay::contrib::example_target_hooks::RelayToTIR()) .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 32c896610fdc..b47bc401b37f 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -228,6 +228,9 @@ class TECompilerImpl : public TECompilerNode { } Map GetDeviceContexts() { return device_contexts_; } + void SetDeviceContexts(const Map& device_contexts) { + device_contexts_ = device_contexts; + } void Clear() final { cache_.clear(); } diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index d1094122d39a..60dd5fe2c6b3 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -126,6 +126,7 @@ class TECompilerNode : public Object { * annotated) */ virtual Map GetDeviceContexts() = 0; + virtual void SetDeviceContexts(const Map& device_contexts) = 0; virtual Map GetOpWeights() const = 0; diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 8246d61579f2..5540c35a8f7e 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -402,7 +402,6 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option>("devices"); -TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU).set_attr("use_device_api", Bool(true)); /********** Registry **********/ diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 3343e1062e57..a5ecf4ba8296 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -117,6 +117,11 @@ class BuiltinLower : public StmtExprMutator { // and less than runtime::kMaxStackAlloca heuristic // they are not serviced with TVMBackendWorkspaceAlloc calls // to be 
placed on stack. + if (op->annotations.count(transform::kDisableLowerTVMBuiltin)) { + if (Downcast(op->annotations[transform::kDisableLowerTVMBuiltin])) { + return stmt; + } + } if (device_type_.defined()) { if (const auto* dev_type = device_type_.as()) { auto storage_scope = Downcast(op->buffer_var->type_annotation)->storage_scope; diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 6e8793fbd367..169983a525df 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -64,31 +64,21 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { // Collect variables and buffers to map between Array args; std::vector> var_def; - std::vector> buffer_def; + bool buffer_map_found = false; for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { Var param = func_ptr->params[i]; - Var v_arg = Var("arg" + std::to_string(i), param->dtype); auto it = func_ptr->buffer_map.find(param); if (it != func_ptr->buffer_map.end()) { - buffer_def.emplace_back(v_arg, (*it).second); + args.push_back((*it).second->data); + buffer_map_found = true; } else { - var_def.emplace_back(v_arg, param); + args.push_back(param); } - - args.push_back(v_arg); - } - - // Bind variables then bind buffers to them to ensure correct ordering - for (const auto& kv : var_def) { - binder.Bind(kv.second, kv.first, kv.first->name_hint, true); - } - for (const auto& kv : buffer_def) { - binder.Bind(kv.second->data, kv.first, kv.first->name_hint, true); } - if (buffer_def.size()) { + if (buffer_map_found) { device_init.push_back(AttrStmt(node, attr::device_id, device_id, nop)); device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop)); } diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index e20ab41cb576..f4393d409d04 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -154,14 +154,14 
@@ def create_graph_activation(input_tensor_name, input_tensor_shape, input_tensor_ ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = ( + compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] + ) # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -241,15 +241,12 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -328,15 +325,12 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. 
single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -423,15 +417,12 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -501,15 +492,12 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. 
single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -551,15 +539,12 @@ def create_relay_graph(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -608,15 +593,12 @@ def create_model(): ) # Assumes only two runtime.Modules are created -- i.e. 
single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -705,18 +687,16 @@ def rounding_right_shift(lhs, rhs): [rounding_right_shift(x[0], x[1]) for x in zip(lhs.flat, rhs.flat)] ).astype(ofm_dtype) - compiled_model = infra.build_source(mod, input_data, [output_data], accel_type) - imported_modules = compiled_model[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + compiled_models = infra.build_source(mod, input_data, [output_data], accel_type) + # Assumes only two runtime.Modules are created -- i.e. 
single offload module + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) - infra.verify_source(compiled_model, accel_type) + infra.verify_source(compiled_models, accel_type) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -738,15 +718,13 @@ def test_ethosu_identity_codegen(ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp mod, {"ifm": in_data}, [out_data], accel_type, output_tolerance=1 ) - imported_modules = compiled_model[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + # Assumes only two runtime.Modules are created -- i.e. single offload module + ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_model, accel_type) @@ -786,15 +764,13 @@ def test_relay_reshape_codegen(ifm_shape, new_shape, accel_type): accel_type, ) - imported_modules = compiled_model[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + # Assumes only two runtime.Modules are created -- i.e. 
single offload module + ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_model, accel_type) @@ -831,15 +807,13 @@ def test_relay_strided_slice_codegen(ifm_shape, begin, end, accel_type): accel_type, ) - imported_modules = compiled_model[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + # Assumes only two runtime.Modules are created -- i.e. single offload module + ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_model, accel_type) @@ -907,15 +881,12 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. 
single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -957,16 +928,18 @@ def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtyp ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source source = ethosu_module.get_source() assert ( - '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t cms_data_data' in source + '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t tvmgen_default_ethos_u_main_0_cms_data_data' + in source + ) + assert ( + '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t tvmgen_default_ethos_u_main_0_weights' + in source ) - assert '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t weights' in source @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -990,15 +963,13 @@ def clz_comp(n): compiled_model = infra.build_source(mod, {"ifm": in_data}, [out_data], accel_type) - imported_modules = compiled_model[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + # 
Assumes only two runtime.Modules are created -- i.e. single offload module + ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_model, accel_type) @@ -1057,15 +1028,12 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 91cee81a1565..de8a7f922390 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -48,8 +48,8 @@ def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer_7 = T.match_buffer(placeholder_6, [32], 
dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([128], "uint8", "global") - placeholder_d_global = T.allocate([32], "uint8", "global") + placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6.data, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -122,7 +122,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_2 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_3 = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([4096], "int8", "global") + ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 592, 12, T.load("uint8", buffer_2.data, 0), 160, 0, 0, 0, 0, "NONE", 
0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 160, 12, T.load("uint8", buffer_3.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -190,9 +190,9 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_8 = T.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_9 = T.match_buffer(placeholder_10, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([4096], "int8", "global") - placeholder_global = T.allocate([80], "uint8", "global") - placeholder_d_global = T.allocate([32], "uint8", "global") + ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin": True}) + placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_11.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_6.data, 0), 592, 12, T.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 32, T.load("uint8", placeholder_d_global, 0), 
dtype="handle")) @@ -312,6 +312,10 @@ def get_graph(): # More generally, check compiles successfully to make sure # nothing else was overrwritten. + # With Target Hooks the TIR module needs a target attached + # and lowered via make unpacked API. + tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u")) + tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod) tir_to_cs_translator.translate(tir_mod, params) diff --git a/tests/python/contrib/test_ethosu/test_lookup_table.py b/tests/python/contrib/test_ethosu/test_lookup_table.py index d32b441fd2eb..9485b4f69520 100644 --- a/tests/python/contrib/test_ethosu/test_lookup_table.py +++ b/tests/python/contrib/test_ethosu/test_lookup_table.py @@ -103,16 +103,13 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) - infra.verify_source(compiled_models, accel_type) @@ -162,16 +159,13 @@ def test_random_lut(accel_type): ) # Assumes only two runtime.Modules are created -- i.e. 
single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) - infra.verify_source(compiled_models, accel_type) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 1d3afec30cbc..2f2cd7a483db 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -255,7 +255,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer_3 = T.match_buffer(placeholder_1, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([1024], "int8", "global") + ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", 
ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 304, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -276,7 +276,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle placeholder_5 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([1536], "int8", "global") + ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", 
buffer_3.data, 0), 2608, 12, T.load("uint8", buffer.data, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -297,7 +297,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_3 = T.match_buffer(placeholder_1, [880], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([2560], "int8", "global") + ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer.data, 0), 1744, 12, T.load("uint8", buffer_1.data, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", 
placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -320,7 +320,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_2 = T.match_buffer(placeholder_4, [272], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_3 = T.match_buffer(placeholder_3, [11040], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([2304], "int8", "global") + ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 11040, 12, T.load("uint8", buffer_2.data, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 
1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index b1f923de4646..cce414c4c8f7 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -39,8 +39,8 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_1 = T.match_buffer(placeholder_1, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([304], "uint8", "global") - placeholder_d_global = T.allocate([80], "uint8", "global") + placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -87,8 +87,8 @@ def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, buffer_2 = T.match_buffer(placeholder_3, [272], dtype="uint8") buffer_3 = 
T.match_buffer(placeholder_4, [64], dtype="uint8") # body - placeholder_global = T.allocate([416], "uint8", "global") - placeholder_d_global = T.allocate([112], "uint8", "global") + placeholder_global = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index 94c8f0ddc04e..59b7b2c21723 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -233,9 +233,12 @@ def test_buffer_info_extraction(): }, ] for test_case in test_cases: - buffer_info = tir_to_cs_translator.extract_buffer_info( - test_case["tir_module"], test_case["param_dict"] - ) + # With Target Hooks the TIR module needs a target attached + # and lowered via make unpacked API. 
+ tir_mod = test_case["tir_module"] + tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u")) + tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod) + buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"]) for buffer_var, info in buffer_info.items(): buffer_name = buffer_var.name if buffer_name in test_case["constants"].keys(): @@ -247,8 +250,6 @@ def test_buffer_info_extraction(): ) info.btype == tir_to_cs_translator.BufferType.constant else: - assert list(info.shape) == test_case["data_buffers"][buffer_name][0] - assert info.dtype == test_case["data_buffers"][buffer_name][1] assert info.btype == test_case["data_buffers"][buffer_name][2] @@ -831,10 +832,11 @@ def check_buffer(address, region, length, buffer_var): ) for test_case in test_cases: - buffer_info = tir_to_cs_translator.extract_buffer_info( - test_case["tir_module"], test_case["param_dict"] - ) - extern_calls = extract_call_extern_list(test_case["tir_module"]) + tir_mod = test_case["tir_module"] + tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u")) + tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod) + buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"]) + extern_calls = extract_call_extern_list(tir_mod) _npu_ops = list() for extern_call in extern_calls: _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call)) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 95defff4681d..73f3a0f27eba 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -483,8 +483,8 @@ def test_compile_tflite_module_with_external_codegen_ethosu( tvmc.compiler.compile_model( tvmc_model, target=f"ethos-u -accelerator_config={accel_type}, c -mcpu=cortex-m55", - runtime=Runtime("crt", {"system-lib": True}), - executor=Executor("aot"), + runtime=Runtime("crt"), + 
executor=Executor("aot", {"unpacked-api": True}), output_format="mlf", package_path=output_file_name, pass_context_configs=["tir.disable_vectorize=true"], diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py index 3de4fecf5544..473b8d5ee300 100644 --- a/tests/python/relay/aot/test_c_device_api.py +++ b/tests/python/relay/aot/test_c_device_api.py @@ -92,7 +92,7 @@ def compile_to_main_func(interface_api="c", use_unpacked_api=True): workspace_byte_alignment=16, pass_config=test_runner.pass_config, ) - main_ir_module = list(compiled_models[0].executor_factory.lowered_ir_mods.values())[0] + main_ir_module = compiled_models[0].executor_factory.lowered_ir_mods.items()[0][1] main_func = main_ir_module["run_model"] return main_func @@ -177,6 +177,9 @@ def test_device_api_hooks_unpacked_api(device_api_main_func): ) +@pytest.mark.skip( + "Skipping this test as this is incorrectly using Arm(R) Ethos(TM)-U NPU with packed calling convention which is not supported by the NPU codegen's TIR to Runtime Hook. 
We need to use a different target to test this feature" +) def test_device_api_hooks_packed_api(device_api_main_func): """Check for Device API hooks with packed internal calls""" main_func = device_api_main_func(interface_api="packed", use_unpacked_api=False) diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py index 9d917466758b..e5f41e7b520f 100644 --- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py +++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py @@ -58,7 +58,7 @@ def test_device_setup(mod, target, dev): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 - assert f.params[0].name == "arg0" + assert f.params[0].name == "A" assert f.body.node == "default" assert f.body.attr_key == "device_id" assert f.body.value == 0 @@ -77,16 +77,13 @@ def test_no_buffers_no_device_setup(): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 - assert f.body.var.name == "A" - assert f.body.value.name == "arg0" + assert f.params[0].name == "A" def test_argument_mapping(mod): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 - assert f.params[0].name == "arg0" - assert f.body.body.body.var.name == "A" - assert f.body.body.body.value.name == "arg0" + assert f.params[0].name == "A" def test_argument_mapping_multiple(): @@ -101,12 +98,8 @@ def test_argument_mapping_multiple(): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 2 - assert f.params[0].name == "arg0" - assert f.params[1].name == "arg1" - assert f.body.body.body.var.name == "A" - assert f.body.body.body.value.name == "arg0" - assert f.body.body.body.body.var.name == "B" - assert f.body.body.body.body.value.name == "arg1" + assert f.params[0].name == "A" + assert f.params[1].name == "B" def 
test_argument_mapping_multiple_matching(): @@ -120,12 +113,8 @@ def test_argument_mapping_multiple_matching(): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 2 - assert f.params[0].name == "arg0" - assert f.params[1].name == "arg1" - assert f.body.body.body.var.name == "A" - assert f.body.body.body.value.name == "arg0" - assert f.body.body.body.body.condition.a.name == "A" - assert f.body.body.body.body.condition.b.name == "arg1" + assert f.params[0].name == "A" + assert f.params[1].name == "A" def test_body(): @@ -140,15 +129,9 @@ def test_body(): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 3 - assert f.params[0].name == "arg0" - assert f.params[1].name == "arg1" - assert f.params[2].name == "arg2" - assert f.body.body.body.var.name == "A" - assert f.body.body.body.value.name == "arg2" - assert f.body.body.body.body.var.name == "B" - assert f.body.body.body.body.value.name == "arg1" - assert f.body.body.body.body.body.condition.a.name == "A" - assert f.body.body.body.body.body.condition.b.name == "arg0" + assert f.params[0].name == "A" + assert f.params[1].name == "B" + assert f.params[2].name == "A" if __name__ == "__main__":