From 4f6b478336042583d369cf7489baa671cde25f3f Mon Sep 17 00:00:00 2001 From: mbaret <55580676+mbaret@users.noreply.github.com> Date: Tue, 12 Oct 2021 21:50:34 +0100 Subject: [PATCH] Address review comments on Arm(R) Ethos(TM)-U PR 3/6 (#9159) * Address review comments on Arm(R) Ethos(TM)-U PR 3/6 Change-Id: I22961885a503be31f6a72622ae0b5f874cc6f463 * Fix rebasing error Change-Id: I3e2fde786096ea331fcb366080fa779ec4ea4a5d * Fix more rebasing problems Change-Id: I1026e3ccee33a3fdec9ebbf6456bae244ad4f1d5 --- .../backend/contrib/ethosu/tir/compiler.py | 20 +- .../backend/contrib/ethosu/tir/convolution.py | 2 +- .../relay/backend/contrib/ethosu/tir/dma.py | 2 +- .../backend/contrib/ethosu/tir/passes.py | 6 +- .../backend/contrib/ethosu/tir/scheduler.py | 36 +-- .../backend/contrib/ethosu/tir/transform.py | 2 +- .../relay/backend/contrib/ethosu/tir/utils.py | 2 +- .../contrib/ethosu/tir_to_cs_translator.py | 164 ++++++------ .../relay/backend/contrib/ethosu/vela_api.py | 50 ++-- .../backend/contrib/ethosu/to_te_graph.cc | 234 ------------------ src/relay/backend/te_compiler_cache.cc | 40 +-- .../contrib/test_ethosu/test_attr_passing.py | 8 +- .../test_ethosu/test_encode_constants.py | 16 +- .../contrib/test_ethosu/test_scheduler.py | 14 +- .../contrib/test_ethosu/test_vela_api.py | 5 +- 15 files changed, 197 insertions(+), 404 deletions(-) delete mode 100644 src/relay/backend/contrib/ethosu/to_te_graph.cc diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index c59a386fefbb..3283e0515c72 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""The integration of Arm(R) Ethos(TM)-U NPU TIR compiler""" +"""The integration of the Arm(R) Ethos(TM)-U NPU TIR compiler.""" import tvm from tvm import relay from tvm.relay.expr_functor import ExprMutator @@ -29,7 +29,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target. The resulting TIR module will contain a single function - that comprises of a sequence of tir.extern_calls to NPU + that consists of a sequence of tir.extern_calls to NPU operations. Parameters @@ -96,20 +96,20 @@ def lower_ethosu(sch, args, const_dict, name="main"): def lower_to_te(prim_func): - """Lower a Relay primitive function to a Tensor Expression graph. + """Lower a Relay primitive function to a Tensor Expression in an unscheduled CachedFunc. Parameters ---------- prim_func : tvm.relay.Function - The Relay function to lowerethosu_runtime([]). + The Relay function to lower. Returns ------- - out : TEGraph - The lowered Tensor Expression graph. + out : CachedFunc + The lowered Tensor Expression as part of a CachedFunc. 
""" - f = tvm._ffi.get_global_func("relay.backend.contrib.ethosu.LowerToTE") + f = tvm._ffi.get_global_func("relay.backend.LowerToTE") return f(prim_func) @@ -193,7 +193,7 @@ def lower_to_tir(func, cascader=None): func, consts = extract_constants(func) mod = tvm.IRModule.from_expr(func) func = relay.transform.InferType()(mod)["main"] - te_graph = lower_to_te(func) - s = schedule(te_graph, consts, cascader) - mod, consts = lower_ethosu(s, te_graph, consts) + cached_func = lower_to_te(func) + s = schedule(cached_func, consts, cascader) + mod, consts = lower_ethosu(s, cached_func, consts) return mod, consts diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 33fbdcd2b24f..fd7fa293ccfb 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Extract information from the convolution operators in TIR.""" +"""Extract parameters from the convolution operators in TIR.""" import tvm from ..vela_api import SCALE_BIAS_LENGTH from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py index ecd402d63309..a116e51c5b7c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Extract information from the DMA operators in TIR.""" +"""Extract parameters from the DMA operators in TIR.""" import tvm from .utils import get_outer_loops, get_base_address, get_strides, get_op_attrs from .spec import SerialFeatureMap, SerialPadding diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 8bb410e986c7..761c8aad7bb1 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=invalid-name, unused-argument -"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler""" +"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler.""" import numpy as np # type: ignore import tvm @@ -301,7 +301,7 @@ def EncodeConstants(const_dict): pointer_to_buffer = {} rewrite_buffer = {} rewrite_pointer = {} - accel_type = vela_api.get_target_accel_type() # type: ignore + accel_config = vela_api.get_accelerator_config() def _align_scale_bias(tir_extern_call, bias): """Align the scale_bias to 16 bytes.""" @@ -316,7 +316,7 @@ def _align_scale_bias(tir_extern_call, bias): def _encode_weights(tir_extern_call, weights): """Encode the weights for a TIR extern call.""" - value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_type) + value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config) value = np.frombuffer(value_bytes, dtype="uint8") return value diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py index 5d9027bf2078..7f892d0c602a 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -15,17 +15,17 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Different schedulers for Arm(R) Ethos(TM)-U NPU""" +"""Scheduling for Arm(R) Ethos(TM)-U NPU.""" import tvm -def schedule(te_graph, const_dict, cascader=None): - """Schedule a TE graph for NPU compilation. +def schedule(cached_func, const_dict, cascader=None): + """Schedule a CachedFunc for NPU compilation. Parameters ---------- - te_graph - The TE graph to schedule. + cached_func : CachedFunc + The CachedFunc to schedule. const_dict : dict of int to numpy.ndarray The constant dictionary. cascader : callable, optional @@ -38,10 +38,10 @@ def schedule(te_graph, const_dict, cascader=None): The completed schedule for the graph. """ - s = tvm.te.create_schedule([t.op for t in te_graph.outputs]) + s = tvm.te.create_schedule([t.op for t in cached_func.outputs]) if cascader: - cascader(te_graph, const_dict, s) - inline_no_ops(te_graph, s) + cascader(cached_func, const_dict, s) + inline_no_ops(cached_func, s) schedule_pragmas(s) schedule_cache_reads(s) return s @@ -96,7 +96,7 @@ def total_cascader(stripe_size): """ - def _cascader(te_graph, const_dict, sch): + def _cascader(cached_func, const_dict, sch): scheduled = set() def _visit(tensor, stage, ax): @@ -106,8 +106,8 @@ def _visit(tensor, stage, ax): for input_tensor in tensor.op.input_tensors: _visit(input_tensor, stage, ax) - assert len(te_graph.outputs) == 1 - out = te_graph.outputs[0] + assert len(cached_func.outputs) == 1 + out = cached_func.outputs[0] oi, _ = tile_nd(sch, out, stripe_size) for ax in oi: sch[out].unroll(ax) @@ -126,14 +126,14 @@ def copy_constants(): The planning function. 
""" - def _planner(te_graph, const_dict, sch): + def _planner(cached_func, const_dict, sch): planned = set() # type: ignore def _visit(tensor, reader): if tensor is not planned: planned.add(tensor) if isinstance(tensor.op, tvm.te.PlaceholderOp): - index = list(te_graph.inputs).index(tensor) + index = list(cached_func.inputs).index(tensor) if index in const_dict: sch.cache_read(tensor, "global", [reader]) @@ -141,7 +141,7 @@ def _visit(tensor, reader): for input_tensor in tensor.op.input_tensors: _visit(input_tensor, tensor) - for output_tensor in te_graph.outputs: + for output_tensor in cached_func.outputs: _visit(output_tensor, None) return _planner @@ -216,7 +216,7 @@ def _detect_cache_read(stage): stage.pragma(fax, "op", "ethosu_copy") -def inline_no_ops(te_graph, sch): +def inline_no_ops(cached_func, sch): """Inline 'no-ops' - operations that in principle do nothing. Modifies the schedule in-place. For now we inline reshape and @@ -224,8 +224,8 @@ def inline_no_ops(te_graph, sch): Parameters ---------- - te_graph - The TE graph. + cached_func : CachedFunc + The cached func. sch : tvm.te.Schedule The schedule. @@ -241,7 +241,7 @@ def _visit(tensor): for input_tensor in tensor.op.input_tensors: _visit(input_tensor) - for out in te_graph.outputs: + for out in cached_func.outputs: _visit(out) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py index 0403ce2c7e8f..f50975c83838 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Extract information from the transform operators in TIR.""" +"""Extract parameters from the transform operators in TIR.""" import tvm from .spec import SerialCopy from .utils import get_base_address, get_op_attrs diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index ccfc2dfbfc48..de1c0ab19f6e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name -"""Helper utility functions used by the TIR compiler""" +"""Helper utility functions used by the NPU TIR compiler""" import tvm from tvm import arith diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 408eab6427ca..bcae01a10214 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -18,7 +18,7 @@ the Relay to TIR compilation process, to Vela API calls to generate command stream. 
""" -from typing import NamedTuple +from typing import Dict, NamedTuple, Tuple, Union from enum import auto from enum import Enum import numpy as np # type: ignore @@ -32,7 +32,7 @@ class BufferType(Enum): - """The buffer types the codegen supports""" + """The type of information that a buffer contains.""" constant = auto() input_or_output = auto() @@ -50,7 +50,7 @@ class BufferType(Enum): class BufferInfo(NamedTuple): - """A data structure to hold metadata of the buffer""" + """A data structure to hold metadata of the buffer.""" # If the buffer holds constants, the values will contain that otherwise None values: np.ndarray @@ -90,9 +90,9 @@ def translate(tir_module, params): for extern_call in extern_calls: _npu_ops.append(translate_ethosu_tir_extern_call(extern_call)) _npu_ops, constant_tensor, scratch_size = assign_addresses(buffer_info, _npu_ops) - target_accel_type = vela_api.get_target_accel_type() - cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_type) - payload = vapi.npu_create_driver_payload(cmds, target_accel_type) + target_accel_config = vela_api.get_accelerator_config() + cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config) + payload = vapi.npu_create_driver_payload(cmds, target_accel_config) hex_value = "" if constant_tensor is None else constant_tensor.tobytes().hex() return payload.hex(), hex_value, scratch_size @@ -125,9 +125,10 @@ def populate_extern_calls(stmt): return extern_calls -def extract_buffer_info(mod, param_dict): - """ - This function is to read the tvm.IRModule that +def extract_buffer_info( + mod: tvm.IRModule, param_dict: Dict[int, np.ndarray] +) -> Dict[str, BufferInfo]: + """This function is to read the tvm.IRModule that contains Relay to TIR compiled IRModule. Thereafter, this will extract the buffer information as the shape and constant data (if any). @@ -136,12 +137,14 @@ def extract_buffer_info(mod, param_dict): ---------- mod : tvm.IRModule The NPU TIR IRModule. - param_dict : dict + param_dict : Dict[int, np.ndarray] A dictionary containing param idx --> const numpy.NDArray + Returns ------- - dict - a dictionary of buffer names --> BufferInfo + dict : Dict[str, BufferInfo] + A dictionary of buffer names --> BufferInfo + """ buffer_info = dict() # There should only be a single function @@ -328,14 +331,15 @@ def translate_ethosu_copy(tir_extern_call): return _create_npu_dma_op(serial_object) -def _convert_clip_bounds(npu_op): - """ - This function will convert the min and max value +def _convert_clip_bounds(npu_op: vapi.NpuBlockOperation): + """This function will convert the min and max value of clip activations to non quantized floats as expected by the API. + Parameters ---------- - npu_op : ethosu.vela.api.NpuBlockOperation + npu_op : vapi.NpuBlockOperation + """ clip_min_quant = npu_op.activation.min clip_max_quant = npu_op.activation.max @@ -349,13 +353,14 @@ def _convert_clip_bounds(npu_op): npu_op.activation.max = clip_max_actual -def translate_ethosu_conv2d(tir_extern_call): - """This function will translate a tir extern_call - as produced by Relay to TIR compilation. +def translate_ethosu_conv2d(tir_call_extern: tvm.tir.Call) -> Tuple[vapi.NpuConv2DOperation, int]: + """This function will translate a TIR call_extern + as produced by NPU Relay to TIR compilation. 
+ Parameters ---------- - tir_extern_call : tvm.tir.Call - This should be an tir external call that has a agreed upon ordering + tir_call_extern : tvm.tir.Call + This should be a TIR call_extern that has a agreed upon ordering for TIR Compiler. See Serial2DConvolution in tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. @@ -365,15 +370,18 @@ def translate_ethosu_conv2d(tir_extern_call): The vela object containing the params of ethosu_conv2d weights_zero_point : int The zero point of the weights + """ - # We skip the first element as it is the extern_call function name - serial_object = spec.create_serial_object(spec.Serial2DConvolution, tir_extern_call.args[1:]) + # We skip the first element as it is the call_extern function name + serial_object = spec.create_serial_object(spec.Serial2DConvolution, tir_call_extern.args[1:]) return _create_npu_op_conv2d(serial_object) -def _create_npu_op_conv2d(serial_2d_convolution): +def _create_npu_op_conv2d( + serial_2d_convolution: spec.Serial2DConvolution, +) -> Tuple[vapi.NpuConv2DOperation, int]: """This is a helper function to capture a list - of arguments to create Vela NpuConv2DOperation object + of arguments to create Vela NpuConv2DOperation object. """ npu_conv2d_op = vapi.NpuConv2DOperation() npu_conv2d_op.ifm = _create_npu_feature_map(serial_2d_convolution.ifm) @@ -392,8 +400,8 @@ def _create_npu_op_conv2d(serial_2d_convolution): _convert_clip_bounds(npu_conv2d_op) npu_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_convolution.upscale) - target_accel_type = vela_api.get_target_accel_type() # type: ignore - block_config = vela_api.get_optimal_block_config(npu_conv2d_op, target_accel_type) + accel_config = vela_api.get_accelerator_config() + block_config = vela_api.get_optimal_block_config(npu_conv2d_op, accel_config) npu_conv2d_op.block_config = block_config weights_shape_ohwi = [ npu_conv2d_op.ofm.shape.depth, @@ -450,16 +458,16 @@ def _create_npu_op_depthwise_conv2d(serial_2d_depthwise): _convert_clip_bounds(npu_depthwise_conv2d_op) npu_depthwise_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_depthwise.upscale) - target_accel_type = vela_api.get_target_accel_type() - block_config = vela_api.get_optimal_block_config(npu_depthwise_conv2d_op, target_accel_type) + target_accel_config = vela_api.get_accelerator_config() + block_config = vela_api.get_optimal_block_config(npu_depthwise_conv2d_op, target_accel_config) npu_depthwise_conv2d_op.block_config = block_config return npu_depthwise_conv2d_op, weights_zero_point -def _create_npu_feature_map(serial_feature_map): +def _create_npu_feature_map(serial_feature_map: spec.SerialFeatureMap) -> vapi.NpuFeatureMap: """This is a helper function to capture a list - of arguments to create Vela NpuFeatureMap object + of arguments to create Vela NpuFeatureMap object. 
""" layout_map = {"NHWC": vapi.NpuLayout.NHWC, "NHCWB16": vapi.NpuLayout.NHCWB16} datatype_map = { @@ -476,14 +484,14 @@ def _create_npu_feature_map(serial_feature_map): nfm = vapi.NpuFeatureMap() nfm.data_type = datatype_map[data_type] nfm.shape = vapi.NpuShape3D( - int(serial_feature_map.height.value), - int(serial_feature_map.width.value), - int(serial_feature_map.channels.value), + int(serial_feature_map.height), + int(serial_feature_map.width), + int(serial_feature_map.channels), ) nfm.tiles = vapi.NpuTileBox( - int(serial_feature_map.tile_height_0.value), - int(serial_feature_map.tile_height_1.value), - int(serial_feature_map.tile_width_0.value), + int(serial_feature_map.tile_height_0), + int(serial_feature_map.tile_height_1), + int(serial_feature_map.tile_width_0), [ serial_feature_map.tile_address_0, serial_feature_map.tile_address_1, @@ -496,81 +504,75 @@ def _create_npu_feature_map(serial_feature_map): ) nfm.layout = layout_map[layout] nfm.strides = vapi.NpuShape3D( - int(serial_feature_map.stride_h.value), - int(serial_feature_map.stride_w.value), - int(serial_feature_map.stride_c.value), + int(serial_feature_map.stride_h), + int(serial_feature_map.stride_w), + int(serial_feature_map.stride_c), ) return nfm -def _create_npu_kernel(serial_kernel): +def _create_npu_kernel(serial_kernel: spec.SerialKernel) -> vapi.NpuKernel: """This is a helper function to capture a list - of arguments to create Vela NpuKernel object + of arguments to create Vela NpuKernel object. """ nknl = vapi.NpuKernel( - w=int(serial_kernel.width.value), - h=int(serial_kernel.height.value), - stride_x=int(serial_kernel.stride_w.value), - stride_y=int(serial_kernel.stride_h.value), - dilation_x=int(serial_kernel.dilation_w.value), - dilation_y=int(serial_kernel.dilation_h.value), + w=int(serial_kernel.width), + h=int(serial_kernel.height), + stride_x=int(serial_kernel.stride_w), + stride_y=int(serial_kernel.stride_h), + dilation_x=int(serial_kernel.dilation_w), + dilation_y=int(serial_kernel.dilation_h), ) return nknl -def _create_npu_address_range(serial_address_range): +def _create_npu_address_range( + serial_address_range: spec.SerialAddressRange, +) -> vapi.NpuAddressRange: """This is a helper function to capture a list - of arguments to create Vela NpuAddressRange object + of arguments to create Vela NpuAddressRange object. """ addr_range = vapi.NpuAddressRange( # region will be updated later region=0, address=serial_address_range.address, - length=int(serial_address_range.length.value), + length=int(serial_address_range.length), ) return addr_range def _create_npu_quantization( - scale, - zero_point, -): + scale: Union[tvm.tir.FloatImm, float], + zero_point: Union[tvm.tir.IntImm, int], +) -> vapi.NpuQuantization: """This is a helper function to capture a list - of arguments to create Vela NpuQuantization object + of arguments to create Vela NpuQuantization object. 
""" - # Scale could be an ndarray if per-channel quantization is available - if not isinstance(scale, tvm.tir.expr.Load): - if isinstance(scale.value, float): - scale = np.single(scale.value) - else: - assert isinstance(scale.value.value, float) - scale = np.single(scale.value.value) - q_params = vapi.NpuQuantization(scale_f32=scale, zero_point=zero_point.value) - return q_params + return vapi.NpuQuantization(scale_f32=float(scale), zero_point=int(zero_point)) def _create_npu_weights_zero_point( - zero_point, -): - """This is a helper function to capture the weights zero point""" - return zero_point.value + zero_point: Union[int, tvm.tir.IntImm], +) -> int: + """This is a helper function to capture the weights zero point.""" + return int(zero_point) -def _create_npu_padding(serial_padding): +def _create_npu_padding(serial_padding: spec.SerialPadding) -> vapi.NpuPadding: """This is a helper function to capture a list - of arguments to create Vela NpuPadding object""" + of arguments to create Vela NpuPadding object.""" padding = vapi.NpuPadding( - top=int(serial_padding.top.value), - left=int(serial_padding.left.value), - bottom=int(serial_padding.bottom.value), - right=int(serial_padding.right.value), + top=int(serial_padding.top), + left=int(serial_padding.left), + bottom=int(serial_padding.bottom), + right=int(serial_padding.right), ) return padding -def _create_npu_activation(serial_activation): +def _create_npu_activation(serial_activation: spec.SerialActivation) -> vapi.NpuActivation: """This is a helper function to capture a list - of arguments to create Vela NpuActivation object""" + of arguments to create Vela NpuActivation object.""" if serial_activation.op == "NONE": return None if ( @@ -587,16 +589,16 @@ def _create_npu_activation(serial_activation): op = str(serial_activation.op.value) assert op in op_map.keys() act_op = vapi.NpuActivation(op_map[op]) - act_op.min = int(serial_activation.clip_min.value) - act_op.max = int(serial_activation.clip_max.value) + act_op.min = int(serial_activation.clip_min) + act_op.max = int(serial_activation.clip_max) return act_op def _create_npu_resampling_mode( - mode, -): + mode: str, +) -> vapi.NpuResamplingMode: """This is a helper function to capture a list - of arguments to create Vela NpuResamplingMode object""" + of arguments to create Vela NpuResamplingMode object.""" mode_map = { "NONE": vapi.NpuResamplingMode.NONE, "NEAREST": vapi.NpuResamplingMode.NEAREST, diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 6523352a0eea..69095e43416e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -27,6 +27,7 @@ import numpy as np # type: ignore from ethosu.vela import api as vapi # type: ignore +import tvm from tvm.relay.backend.contrib.ethosu import util # type: ignore from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs @@ -45,7 +46,7 @@ def get_optimal_block_config( - npu_op: vapi.NpuOperation, accel_type: vapi.NpuAccelerator + npu_op: vapi.NpuOperation, accel_config: vapi.NpuAccelerator ) -> vapi.NpuShape3D: """ "The NPU's unit of work is known as a block. 
It will fetch block(s) from Input @@ -58,15 +59,15 @@ def get_optimal_block_config( ---------- npu_op : ethosu.vela.api.NpuOperation The NPU operation and its params - accel_type : ethosu.vela.api.NpuAccelerator - The NPU accelerator variant + accel_config : ethosu.vela.api.NpuAccelerator + The NPU accelerator config Returns ------- ethosu.vela.api.NpuShape3D : The optimal block config for the operator """ - all_valid_block_configs = vapi.npu_find_block_configs(npu_op, accel_type) + all_valid_block_configs = vapi.npu_find_block_configs(npu_op, accel_config) return _get_optimal_block_config(all_valid_block_configs) @@ -112,7 +113,9 @@ def _get_optimal_block_config(all_valid_block_configs: List[vapi.NpuShape3D]) -> return max_area_depth_block_configs[0] -def encode_weights(tir_extern_call, values, accel_type): +def encode_weights( + tir_extern_call: tvm.tir.Call, values: np.ndarray, accel_config: vapi.NpuAccelerator +): """This is an API function to compress weights by passing a tir_extern_call to NPU Convolution operation and values. @@ -122,8 +125,8 @@ def encode_weights(tir_extern_call, values, accel_type): tir_extern_call to NPU Convolution operation values : numpy.ndarray The constant flattened weight data in OHWI layout - accel_type : ethosu.vela.api.NpuAccelerator - The NPU accelerator variant + accel_config : ethosu.vela.api.NpuAccelerator + The NPU accelerator config Returns ------- @@ -137,7 +140,7 @@ def encode_weights(tir_extern_call, values, accel_type): op = str(tir_extern_call.args[0].value) assert op in supported_ops.keys() npu_op, weights_zero_point = supported_ops[op](tir_extern_call) - block_config = get_optimal_block_config(npu_op, accel_type) + block_config = get_optimal_block_config(npu_op, accel_config) # The weight layout is assumed to be flat OHWI, always. assert len(values.shape) == 1 is_depthwise = op == "ethosu_depthwise_conv2d" @@ -157,7 +160,7 @@ def encode_weights(tir_extern_call, values, accel_type): ifm_bitdepth=npu_op.ifm.data_type.size_in_bits(), block_depth=block_config.depth, dilation=(npu_op.kernel.dilation_x, npu_op.kernel.dilation_y), - accel_type=accel_type, + accel_config=accel_config, is_depthwise=is_depthwise, ) @@ -169,7 +172,7 @@ def compress_weights( ifm_bitdepth: int, block_depth: int, dilation: Tuple[int, int], - accel_type: vapi.NpuAccelerator, + accel_config: vapi.NpuAccelerator, is_depthwise: Optional[bool] = False, ) -> bytearray: """The NPU requires the weights to be compressed @@ -191,8 +194,8 @@ def compress_weights( The depth of the optimal block config for the operator dilation : tuple A tuple of 2 elements indicating dilation in h and w - accel_type : ethosu.vela.api.NpuAccelerator - The NPU accelerator variant + accel_config : ethosu.vela.api.NpuAccelerator + The NPU accelerator config is_depthwise : bool, Optional This indicates whether the weights are compressed for depthwise convolution @@ -215,7 +218,7 @@ def compress_weights( ] block_traversal = calculate_block_traversal_mode(is_depthwise, shape_ohwi, ifm_bitdepth) compressed_weights = vapi.npu_encode_weights( - accelerator=accel_type, + accelerator=accel_config, weights_volume=weights_ohwi, dilation_xy=dilation, ifm_bitdepth=ifm_bitdepth, @@ -361,15 +364,24 @@ def _calculate_hw_bias_scales( return hw_bias_scales -def get_target_accel_type(): - """This is a helper function to convert cli accelerator type str argument - to NpuAccelerator""" +def get_accelerator_config() -> vapi.NpuAccelerator: + """Get the configuration of the NPU accelerator. 
+ + The configuration string provided as a compiler option is converted into + an NpuAccelerator object. Valid configuration strings: + - 'ethos-u55-256' + - 'ethos-u55-128' + - 'ethos-u55-64' + - 'ethos-u55-32' + + """ npu_accel_str_map = { "ethos-u55-256": vapi.NpuAccelerator.Ethos_U55_256, "ethos-u55-128": vapi.NpuAccelerator.Ethos_U55_128, "ethos-u55-64": vapi.NpuAccelerator.Ethos_U55_64, "ethos-u55-32": vapi.NpuAccelerator.Ethos_U55_32, } - accel_type_str = util.get_accelerator_config() - assert accel_type_str in npu_accel_str_map.keys(), f"{accel_type_str} is not supported" - return npu_accel_str_map[accel_type_str] + compiler_attrs = tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")() + accel_config_str = compiler_attrs.accelerator_config + assert accel_config_str in npu_accel_str_map.keys(), f"{accel_config_str} is not supported" + return npu_accel_str_map[accel_config_str] diff --git a/src/relay/backend/contrib/ethosu/to_te_graph.cc b/src/relay/backend/contrib/ethosu/to_te_graph.cc deleted file mode 100644 index 9646c39da089..000000000000 --- a/src/relay/backend/contrib/ethosu/to_te_graph.cc +++ /dev/null @@ -1,234 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/contrib/ethosu/to_te_graph.cc - * \brief Lower a Relay function to a TE graph. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "../../compile_engine.h" -#include "../../utils.h" - -namespace tvm { -namespace relay { -namespace contrib { -namespace ethosu { - -/*! \brief Node container to represent a Tensor Expression graph. */ -class TEGraphNode : public Object { - public: - /* \brief The inputs to the graph */ - tvm::Array inputs; - /* \brief The outputs to the graph */ - tvm::Array outputs; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("inputs", &inputs); - v->Visit("outputs", &outputs); - } - - static constexpr const char* _type_key = "relay.TEGraph"; - TVM_DECLARE_FINAL_OBJECT_INFO(TEGraphNode, Object); -}; - -class TEGraph : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TEGraph, ObjectRef, TEGraphNode); -}; - -TVM_REGISTER_NODE_TYPE(TEGraphNode); - -Array GetShape(const Array& shape) { - // for now, we always use int32 shape when possible - // even if the result of shape inference becomes int64. 
- Array res; - for (IndexExpr val : shape) { - const int64_t* pval = tir::as_const_int(val); - if (pval != nullptr) { -#ifndef TVM_INDEX_DEFAULT_I64 - ICHECK_LE(pval[0], std::numeric_limits::max()); - ICHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(IntImm(DataType::Int(32), *pval)); -#else - res.push_back(val); -#endif // TVM_INDEX_DEFAULT_I64 - } else if (val->IsInstance()) { - res.push_back(val.as()->ToVar()); - } else { - res.push_back(val); - } - } - return res; -} - -class RelayToTE : public backend::MemoizedExprTranslator> { - public: - RelayToTE() = default; - - TEGraph Lower(const Function& prim_func) { - auto graph_node = make_object(); - for (Var param : prim_func->params) { - Array inputs; - if (const auto* ttype = param->checked_type().as()) { - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - graph_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - ICHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - graph_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } - } - memo_[param] = inputs; - } - graph_node->outputs = this->VisitExpr(prim_func->body); - return TEGraph(graph_node); - } - - Array VisitExpr_(const VarNode* op) final { - LOG(FATAL) << "Free variable " << op->name_hint(); - return {}; - } - - Array VisitExpr_(const ConstantNode* op) final { - using tir::make_const; - ICHECK(op->is_scalar()); - void* data = op->data->data; - DataType dtype = DataType(op->data->dtype); - auto value = te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "compile_engine_const", topi::kBroadcast); - return {value}; - } - - Array VisitExpr_(const CallNode* call_node) final { - static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); - ICHECK(flower_call) << "relay.backend.lower_call is not registered."; - - Array inputs; - int count_tuple = 0; - for (Expr arg : call_node->args) { - if (arg->checked_type().as()) { - ++count_tuple; - } - for (te::Tensor tensor : VisitExpr(arg)) { - inputs.push_back(tensor); - } - } - if (count_tuple) { - ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; - } - - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; - Op op = Downcast(call_node->op); - - Array outputs; - LoweredOutput lowered_out = - (*flower_call)(GetRef(call_node), inputs, tvm::Target("llvm")); - outputs = lowered_out->outputs; - - if (outputs.size() != 1) { - const auto* tuple_type = call_node->checked_type().as(); - ICHECK(tuple_type) << "Expect output to be a tuple type"; - ICHECK_EQ(tuple_type->fields.size(), outputs.size()); - } - return outputs; - } - - Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; - return Array(); - } - - 
Array VisitExpr_(const LetNode* op) final { - Array val = VisitExpr(op->value); - ICHECK(!memo_.count(op->var)); - memo_[op->var] = val; - return VisitExpr(op->body); - } - - Array VisitExpr_(const TupleNode* op) final { - Array fields; - for (Expr field : op->fields) { - ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; - Array res = VisitExpr(field); - ICHECK_EQ(res.size(), 1); - fields.push_back(res[0]); - } - return fields; - } - - Array VisitExpr_(const TupleGetItemNode* op) final { - const auto* tuple_type = op->tuple->type_as(); - Array tuple = VisitExpr(op->tuple); - ICHECK_EQ(tuple_type->fields.size(), tuple.size()); - ICHECK_GE(op->index, 0); - ICHECK_LT(static_cast(op->index), tuple.size()); - return {tuple[op->index]}; - } -}; - -TVM_REGISTER_GLOBAL("relay.backend.contrib.ethosu.LowerToTE") - .set_body_typed([](Function prim_func) { return RelayToTE().Lower(prim_func); }); - -} // namespace ethosu -} // namespace contrib -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index d0e83765928a..ec87cfc98931 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -111,8 +111,10 @@ Array GetShape(const Array& shape) { // Construct a schedule for a given Relay primitive function and target. class ScheduleBuilder : public backend::MemoizedExprTranslator> { public: - explicit ScheduleBuilder(Target target) - : target_(target), device_copy_op_(Op::Get("device_copy")) { + explicit ScheduleBuilder(Target target, bool create_schedule = true) + : target_(target), + device_copy_op_(Op::Get("device_copy")), + create_schedule_(create_schedule) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); } @@ -149,7 +151,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator auto prim_fn_var = GlobalVar(prim_fn_name); prim_fn_var->checked_type_ = prim_func->checked_type(); - ICHECK(anchor_op_.defined()); // Fusion over tupled results may leave identity relationships // between inputs and outputs, and those should not be scheduled. // Hence schedule only non PlaceholderOp outputs. @@ -162,7 +163,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator te::Schedule schedule; // No need to register schedule for device copy op. 
- if (anchor_attrs_.as() == nullptr) { + if (anchor_attrs_.as() == nullptr && create_schedule_) { if (use_auto_scheduler_) { const auto* fauto_schedule = runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); @@ -259,17 +260,19 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator impl = lowered_out->implementation; } - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern >= anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - anchor_implementation_ = impl; + if (create_schedule_) { + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; + anchor_implementation_ = impl; + } } if (outputs.size() != 1) { const auto* tuple_type = call_node->checked_type().as(); @@ -334,6 +337,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. const Op& device_copy_op_; + bool create_schedule_; }; /*! @@ -667,6 +671,12 @@ std::string GetUniqueName(std::string name, std::unordered_map return name; } +TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { + return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) { + return name; + }); +}); + } // namespace tec } // namespace relay } // namespace tvm diff --git a/tests/python/contrib/test_ethosu/test_attr_passing.py b/tests/python/contrib/test_ethosu/test_attr_passing.py index a2fbe1888d2a..6b99a5c1e540 100644 --- a/tests/python/contrib/test_ethosu/test_attr_passing.py +++ b/tests/python/contrib/test_ethosu/test_attr_passing.py @@ -28,7 +28,9 @@ def test_compiler_attr(): } with tvm.transform.PassContext(opt_level=3, config={"relay.ext.ethosu.options": config}): with tvm.target.Target("c -device=micro_dev"): - assert util.get_accelerator_config() == config["accelerator_config"] + compiler_attrs = tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")() + accel_config_str = compiler_attrs.accelerator_config + assert accel_config_str == config["accelerator_config"] def test_compiler_attr_default(): @@ -37,7 +39,9 @@ def test_compiler_attr_default(): } with tvm.transform.PassContext(opt_level=3): with tvm.target.Target("c -device=micro_dev"): - assert util.get_accelerator_config() == default_config["accelerator_config"] + compiler_attrs = tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")() + accel_config_str = compiler_attrs.accelerator_config + assert accel_config_str == default_config["accelerator_config"] if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 60ed352edcfd..5b60102162be 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ 
b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -64,10 +64,10 @@ def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, def test_weight_stream_only(): - def _planner(te_graph, const_dict, sch): - weights = te_graph.inputs[1] - bias = te_graph.inputs[2] - out = te_graph.outputs[0] + def _planner(cached_func, const_dict, sch): + weights = cached_func.inputs[1] + bias = cached_func.inputs[2] + out = cached_func.outputs[0] conv_compute = Convolution2DCompute.from_output(out) co = conv_compute.split(sch, 3, 2) cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d]) @@ -208,10 +208,10 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle def test_mixed_read(): - def _planner(te_graph, const_dict, sch): - weight = te_graph.inputs[4] - scale_bias = te_graph.inputs[5] - out = te_graph.outputs[0] + def _planner(cached_func, const_dict, sch): + weight = cached_func.inputs[4] + scale_bias = cached_func.inputs[5] + out = cached_func.outputs[0] conv_compute = Convolution2DCompute.from_output(out) co = conv_compute.split(sch, 3, 2) cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d]) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 8077271ed496..b04059011e8e 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -81,10 +81,10 @@ def test_inline_no_ops(): func = relay.Function(relay.analysis.free_vars(relu2), relu2) func = run_opt_pass(func, relay.transform.InferType()) - te_graph = lower_to_te(func) - sch = te.create_schedule([te_graph.outputs[0].op]) - inline_no_ops(te_graph, sch) - reshape_tensor = te_graph.outputs[0].op.input_tensors[0] + cached_func = lower_to_te(func) + sch = te.create_schedule([cached_func.outputs[0].op]) + inline_no_ops(cached_func, sch) + reshape_tensor = cached_func.outputs[0].op.input_tensors[0] slice_tensor = reshape_tensor.op.input_tensors[0].op.input_tensors[0] assert sch[reshape_tensor].attach_type == AttachType.kInline assert sch[slice_tensor].attach_type == AttachType.kInline @@ -114,11 +114,11 @@ def test_copy_constants(): func = run_opt_pass(func, relay.transform.InferType()) func, const_dict = extract_constants(func) - te_graph = lower_to_te(func) + cached_func = lower_to_te(func) - sch = te.create_schedule([te_graph.outputs[0].op]) + sch = te.create_schedule([cached_func.outputs[0].op]) planner = copy_constants() - planner(te_graph, const_dict, sch) + planner(cached_func, const_dict, sch) assert len(sch.stages) == 21 assert ".global" in sch.stages[5].op.name assert ".global" in sch.stages[7].op.name diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index 02c305387d45..cf845db2b43b 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -354,18 +354,17 @@ def create_mock(test_vec): max = np.iinfo(ifm_dtype).max min = np.iinfo(ifm_dtype).min values = np.random.randint(min, max, test_vec["shape"], ifm_dtype) - compressed_weights = vela_api.compress_weights( + vela_api.compress_weights( weights=values, weights_zp=test_vec["zero_point"], weights_layout=test_vec["layout"], ifm_bitdepth=ifm_bitdepth, block_depth=test_vec["block_depth"], dilation=test_vec["dilation"], - accel_type=test_vec["accel"], + accel_config=test_vec["accel"], is_depthwise=test_vec["is_depthwise"], ) return mock_npu_encode_weights 
- return None for tv in test_vecs: mock_obj = create_mock(tv)
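
A minimal usage sketch of the post-patch flow, mirroring `lower_to_tir` in `compiler.py` above: `lower_to_te` now goes through the `relay.backend.LowerToTE` global registered in `te_compiler_cache.cc` and returns an unscheduled `CachedFunc` (rather than the removed `TEGraph`), whose `inputs`/`outputs` are what the schedulers and cascaders consume. Import paths and the example wrapper name below are taken from or assumed from the modules touched in this patch; treat it as illustrative, not a supported API.

```python
# Sketch of the lowering flow after this patch, following lower_to_tir in
# python/tvm/relay/backend/contrib/ethosu/tir/compiler.py. The input `func`
# is assumed to be an NPU-partitioned Relay primitive function.
import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu.tir.compiler import (
    extract_constants,
    lower_ethosu,
    lower_to_te,
)
from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, schedule


def lower_npu_func_to_tir(func: relay.Function):
    # Pull compile-time constants out of the Relay function first.
    func, consts = extract_constants(func)
    mod = tvm.IRModule.from_expr(func)
    func = relay.transform.InferType()(mod)["main"]
    # lower_to_te returns an unscheduled CachedFunc via relay.backend.LowerToTE.
    cached_func = lower_to_te(func)
    # Schedule it; the copy_constants() cascader cache_reads constant inputs
    # into "global" buffers, as exercised by test_copy_constants above.
    s = schedule(cached_func, consts, cascader=copy_constants())
    # Lower the schedule to a TIR module of NPU extern calls.
    return lower_ethosu(s, cached_func, consts)
```

The updated tests (`test_scheduler.py`, `test_encode_constants.py`) exercise exactly this `CachedFunc` convention, indexing `cached_func.inputs` and `cached_func.outputs` where they previously used `te_graph`.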