From 5dc25afc873fd2c1cecabbf87515f29d459986e2 Mon Sep 17 00:00:00 2001 From: Aleksei-grovety <113356454+Aleksei-grovety@users.noreply.github.com> Date: Tue, 4 Jul 2023 11:34:11 +0300 Subject: [PATCH] [microNPU][ETHOSU] Add Vela's logic to select configuration block (#15186) If the dev_force_block_config option is specified, that block configuration is used as-is. Otherwise, when the cascader is enabled, TVM's logic for choosing the optimal block configuration is used; in all other cases, Vela's logic is used. --- .../relay/backend/contrib/ethosu/vela_api.py | 86 ++++++++++++++++++- .../contrib/test_ethosu/test_networks.py | 4 +- .../test_ethosu/test_replace_conv2d.py | 14 +-- .../contrib/test_ethosu/test_vela_api.py | 50 +++++++++++ 4 files changed, 145 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 45c232a4610b..22f5cdd83b04 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -27,7 +27,12 @@ import numpy as np # type: ignore from ethosu.vela import api as vapi # type: ignore +from ethosu.vela.architecture_allocator import find_block_config from ethosu.vela.architecture_features import Accelerator, create_default_arch +from ethosu.vela.operation import NpuBlockType +from ethosu.vela.register_command_stream_generator import resampling_mode_map +from ethosu.vela.register_command_stream_util import to_kernel +from ethosu.vela.shape4d import Shape4D import tvm from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs @@ -56,6 +61,9 @@ def get_optimal_block_config( Therefore, we need to pick an optimal block configuration considering bandwidth to bring IFM blocks and the number of OFM block computes need to happen to cover the OFM as indicated by the npu op. 
+ If the dev_force_block_config option is specified, that block configuration is used as-is. + Otherwise, when the cascader is enabled, TVM's logic for choosing the optimal block + configuration is used; in all other cases, Vela's logic is used. Parameters ---------- @@ -73,8 +81,82 @@ def get_optimal_block_config( if options and options.dev_force_block_config: block_config = [int(v) for v in options.dev_force_block_config.split("x")] return vapi.NpuShape3D(height=block_config[0], width=block_config[1], depth=block_config[2]) - all_valid_block_configs = vapi.npu_find_block_configs(npu_op, accel_config) - return _get_optimal_block_config(all_valid_block_configs) + elif options and options.enable_cascader: + all_valid_block_configs = vapi.npu_find_block_configs(npu_op, accel_config) + return _get_optimal_block_config(all_valid_block_configs) + else: + return _find_block_config_with_vela(npu_op, accel_config) + + +def _find_block_config_with_vela( + npu_op: vapi.NpuOperation, accelerator: vapi.NpuAccelerator +) -> vapi.NpuShape3D: + """An internal function to get block config using Vela's logic. 
+ + Parameters + ---------- + npu_op : ethosu.vela.api.NpuOperation + The NPU operation + accelerator : ethosu.vela.api.NpuAccelerator + The NPU accelerator + + Returns + ------- + ethosu.vela.api.NpuShape3D : + The optimal block config for the operator + """ + if isinstance(npu_op, vapi.NpuConv2DOperation): + block_type = NpuBlockType.ConvolutionMxN + elif isinstance(npu_op, vapi.NpuConvDepthWiseOperation): + block_type = NpuBlockType.ConvolutionDepthWise + elif isinstance(npu_op, vapi.NpuPoolingOperation): + block_type = ( + NpuBlockType.ReduceSum + if npu_op.sub_op_type == vapi.NpuPoolingOp.REDUCE_SUM + else NpuBlockType.Pooling + ) + elif isinstance(npu_op, vapi.NpuElementWiseOperation): + block_type = NpuBlockType.ElementWise + else: + assert 0, "Unsupported operation" + + ifm_shape = Shape4D(1, npu_op.ifm.shape.height, npu_op.ifm.shape.width, npu_op.ifm.shape.depth) + ifm2_shape = None + if npu_op.ifm2: + ifm2_shape = Shape4D( + 1, npu_op.ifm2.shape.height, npu_op.ifm2.shape.width, npu_op.ifm2.shape.depth + ) + ofm_shape = Shape4D(1, npu_op.ofm.shape.height, npu_op.ofm.shape.width, npu_op.ofm.shape.depth) + + ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale] + ifm_bits = npu_op.ifm.data_type.size_in_bits() + lut_banks = 0 + if npu_op.activation: + lut_banks = 2 if npu_op.activation.op_type == vapi.NpuActivationOp.TABLE_LOOKUP else 0 + + has_scaling = True + for tensor in [npu_op.ifm, npu_op.ifm2, npu_op.ofm]: + if tensor and tensor.quantization is None: + has_scaling = False + break + + arch = create_default_arch(Accelerator.from_npu_accelerator(accelerator)) + + cfg = find_block_config( + arch, + block_type, + ofm_shape, + ifm_shape, + ifm2_shape, + npu_op.ifm2_scalar is not None, + ifm_bits, + to_kernel(npu_op.kernel), + lut_banks, + has_scaling, + ifm_resampling_mode, + ) + assert cfg is not None, f"There is no configuration suitable for {accelerator}" + return vapi.NpuShape3D(cfg.ofm_block.height, cfg.ofm_block.width, cfg.ofm_block.depth) def 
_get_optimal_block_config(all_valid_block_configs: List[vapi.NpuShape3D]) -> vapi.NpuShape3D: diff --git a/tests/python/contrib/test_ethosu/test_networks.py b/tests/python/contrib/test_ethosu/test_networks.py index a5490cbe2b1c..308c06f50456 100644 --- a/tests/python/contrib/test_ethosu/test_networks.py +++ b/tests/python/contrib/test_ethosu/test_networks.py @@ -44,9 +44,9 @@ @pytest.mark.parametrize( "accel_type, model_url, workspace_size", [ - ("ethos-u65-256", MOBILENET_V1_URL, 2338848), + ("ethos-u65-256", MOBILENET_V1_URL, 2338864), ("ethos-u65-256", MOBILENET_V2_URL, 2264320), - ("ethos-u55-256", MOBILENET_V1_URL, 1793376), + ("ethos-u55-256", MOBILENET_V1_URL, 1793392), ("ethos-u55-256", MOBILENET_V2_URL, 2217152), ("ethos-u55-128", MOBILENET_V2_URL, 2217152), ("ethos-u55-64", MOBILENET_V2_URL, 2217152), diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 6bcea7008c86..32d1303e124e 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -633,11 +633,15 @@ def _get_func( reference_mod = trial[0] params = trial[1:] - func = _get_func(*params[:-1]) - mod, _ = _lower_to_tir(func, cascader=total_cascader(params[-1])) - script = mod.script() - mod = tvm.script.from_source(script) - tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) + config = { + "enable_cascader": True, + } + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.ethos-u.options": config}): + func = _get_func(*params[:-1]) + mod, _ = _lower_to_tir(func, cascader=total_cascader(params[-1])) + script = mod.script() + mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) # fmt: off diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index 9f95e4b70925..16785e182a49 100644 --- 
a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -222,6 +222,28 @@ def main( __tvm_meta__ = None +# fmt: off +@tvm.script.ir_module +class Module3: + @T.prim_func + def main(ethos_u_0_i0: T.Buffer((1, 299, 299, 2), "int8"), ethosu_write: T.Buffer((1, 299, 299, 3), "int8")): + T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)}) + p2_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + ax0_ax1_fused_ax2_fused_ax3_fused = T.int32() + p2_global_1 = T.Buffer((128,), "uint8", data=p2_global) + with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused, None, "DataPar", ""), "pragma_compute_cycles_hint", 1056): + p1_encoded = T.Buffer((128,), "uint8") + T.call_extern("handle", "ethosu_copy", p1_encoded[0], 128, p2_global_1[0]) + nn = T.int32() + T.attr(T.iter_var(nn, None, "DataPar", ""), "pragma_compute_cycles_hint", T.int64(179570)) + ethos_u_0_i0_1 = T.Buffer((178802,), "int8", data=ethos_u_0_i0.data) + ethosu_write_1 = T.Buffer((268203,), "int8", data=ethosu_write.data) + T.call_extern("handle", "ethosu_conv2d", "int8", 299, 299, 2, 299, 0, 299, ethos_u_0_i0_1[0], 0, 0, 0, T.float32(0.0039215683937072754), -128, "NHWC", 598, 2, 1, "int8", 299, 299, 3, 299, 0, 299, ethosu_write_1[0], 0, 0, 0, T.float32(0.025585981085896492), -128, "NHWC", 897, 3, 1, 2, 3, 1, 1, 1, 2, p2_global_1[0], 96, T.int8(-1), T.int8(-1), 0, p2_global_1[96], 32, T.int8(-1), T.int8(-1), 2, 0, 2, 1, "NONE", 0, 0, "TFL", "NONE", 32, 12, 8) + + __tvm_meta__ = None +# fmt: on + + def test_get_optimal_block_config(): block_configs_cases = [ { @@ -559,5 +581,33 @@ def verify(test_vec, mock_enc_w): verify(_test_vec, _mock_enc_w) +def test_find_block_config_with_vela(): + block_configs_cases = [ + { + "accel_type": vapi.NpuAccelerator.Ethos_U55_256, + "ref": vapi.NpuShape3D(30, 12, 8), + }, + { + "accel_type": vapi.NpuAccelerator.Ethos_U55_128, 
+ "ref": vapi.NpuShape3D(17, 10, 8), + }, + { + "accel_type": vapi.NpuAccelerator.Ethos_U55_64, + "ref": vapi.NpuShape3D(25, 5, 8), + }, + { + "accel_type": vapi.NpuAccelerator.Ethos_U55_32, + "ref": vapi.NpuShape3D(25, 5, 4), + }, + ] + + mod = Module3 + ethosu_conv2d_call = mod["main"].body.body.seq[1].body.value + npu_op, _ = tirtocs.translate_ethosu_conv2d(ethosu_conv2d_call) + + for case in block_configs_cases: + assert vela_api._find_block_config_with_vela(npu_op, case["accel_type"]) == case["ref"] + + if __name__ == "__main__": tvm.testing.main()