Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[microNPU] Add rescale parameters for binary elementwise #13890

Merged
merged 3 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[microNPU] Add rescale parameters for binary elementwise
Rescale parameters have been added to the binary elementwise operation in accordance with the Vela API (see the rescale field of NpuElementWiseOperation: https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/tags/3.5.0/ethosu/vela/api.py#381). This PR is preparation for implementing the softmax operation.
  • Loading branch information
Aleksei-grovety committed Feb 1, 2023
commit eebad4bd7eb6d132d024090f4ea57df4fee56074
18 changes: 18 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def _extract_ethosu_binary_elementwise_params(attrs, args):
ifm2_layout = attrs.ifm2_layout
ofm_layout = attrs.ofm_layout
ofm_dtype = attrs.ofm_dtype
use_rescale = attrs.use_rescale
rescale_scale = attrs.rescale_scale
rescale_shift = attrs.rescale_shift

return (
ifm,
Expand All @@ -73,6 +76,9 @@ def _extract_ethosu_binary_elementwise_params(attrs, args):
ifm2_layout,
ofm_layout,
ofm_dtype,
use_rescale,
rescale_scale,
rescale_shift,
)


Expand Down Expand Up @@ -117,6 +123,9 @@ def ethosu_binary_elementwise(
ifm_layout: Optional[str] = "NHWC",
ifm2_layout: Optional[str] = "NHWC",
ofm_layout: Optional[str] = "NHWC",
use_rescale: bool = False,
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved
rescale_scale: int = 0,
rescale_shift: int = 0,
) -> tvm.relay.Call:
"""This is a quantized binary elementwise operation as supported by
the NPU. It accepts either NHWC or NHCWB16 format
Expand Down Expand Up @@ -193,6 +202,12 @@ def ethosu_binary_elementwise(
The layout of the Input Feature Map tensor 2. Can be "NHWC" or "NHCWB16".
ofm_layout : str, optional
The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
use_rescale : bool, optional
True if explicit scaling should be used.
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved
rescale_scale : int, optional
Scale value for rescale. For 32-bit operations scale is not applied but shift is.
rescale_shift : int, optional
Shift value for rescale.

Returns
-------
Expand Down Expand Up @@ -221,4 +236,7 @@ def ethosu_binary_elementwise(
ifm2_layout,
ofm_layout,
ofm_dtype,
use_rescale,
rescale_scale,
rescale_shift,
)
12 changes: 12 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def binary_elementwise_compute(
ifm2_layout: str,
ofm_layout: str,
ofm_dtype: str,
use_rescale: bool,
rescale_scale: int,
rescale_shift: int,
) -> te.Tensor:
"""A compute operator representing the capabilities of binary_elementwise for the NPU.

Expand Down Expand Up @@ -121,6 +124,12 @@ def binary_elementwise_compute(
{int32}->{int8, uint8, int32}, any pairing"
SHL:
{int32}->{int32} only
use_rescale : bool
True if explicit scaling should be used.
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved
rescale_scale : int
Scale value for rescale. For 32-bit operations scale is not applied but shift is.
rescale_shift : int
Shift value for rescale.

Returns
-------
Expand Down Expand Up @@ -153,6 +162,9 @@ def binary_elementwise_compute(
"clip_min": clip_min,
"clip_max": clip_max,
"rounding_mode": rounding_mode,
"use_rescale": use_rescale,
"rescale_scale": rescale_scale,
"rescale_shift": rescale_shift,
}

operators = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import tvm
from .utils import get_outer_loops, get_op_attrs
from .dma import get_ifm_params, get_ofm_params
from .spec import SerialActivation, SerialBinaryElementwise
from .spec import SerialActivation, SerialBinaryElementwise, SerialRescaleConfig
from .producers_consumers import ProducersConsumers


Expand Down Expand Up @@ -89,6 +89,9 @@ def get_binary_elementwise_params(
serial_activation = SerialActivation(
op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
)
rescale_config = SerialRescaleConfig(
use_rescale=attrs["use_rescale"], scale=attrs["rescale_scale"], shift=attrs["rescale_shift"]
)
return (
SerialBinaryElementwise(
ifm=serial_ifm,
Expand All @@ -99,6 +102,7 @@ def get_binary_elementwise_params(
activation=serial_activation,
rounding_mode=attrs["rounding_mode"],
block_config=serial_block_config,
rescale_config=rescale_config,
),
output_pointer,
replace_pointer,
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,16 @@ def __init__(self, height: int, width: int, depth: int):
self.depth = depth


class SerialRescaleConfig(SerializableFormat):
    """Serializable container for the rescale parameters of an elementwise
    operation, stored in a predefined order.

    The values are used to fill in the rescale field of the Vela
    NpuElementWiseOperation.

    Parameters
    ----------
    use_rescale : bool
        True if explicit scaling should be used.
    scale : int
        Scale value for rescale.
    shift : int
        Shift value for rescale.
    """

    def __init__(self, use_rescale: bool, scale: int, shift: int):
        # Targets are bound left to right, so the attributes are created in
        # the same order as the parameters: use_rescale, scale, shift.
        self.use_rescale, self.scale, self.shift = use_rescale, scale, shift


class Serial2DConvolution(SerializableFormat):
"""Specialization class to retrieve arguments of
a ethosu.conv2d tir extern call on a predefined ordering"""
Expand Down Expand Up @@ -306,6 +316,7 @@ def __init__(
activation: SerialActivation,
rounding_mode: str,
block_config: SerialBlockConfig,
rescale_config: SerialRescaleConfig,
):
self.ifm = ifm
self.ifm2 = ifm2
Expand All @@ -315,6 +326,7 @@ def __init__(
self.activation = activation
self.rounding_mode = rounding_mode
self.block_config = block_config
self.rescale_config = rescale_config


class SerialUnaryElementwise(SerializableFormat):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,11 @@ def _create_npu_op_binary_elementwise(serial_binary_elementwise: spec.SerialBina
npu_binary_elementwise_op.ifm2 = _create_npu_feature_map(serial_binary_elementwise.ifm2)
npu_binary_elementwise_op.ofm = _create_npu_feature_map(serial_binary_elementwise.ofm)
npu_binary_elementwise_op.reversed_operands = serial_binary_elementwise.reversed_operands
if serial_binary_elementwise.rescale_config.use_rescale:
npu_binary_elementwise_op.rescale = (
serial_binary_elementwise.rescale_config.scale.value,
serial_binary_elementwise.rescale_config.shift.value,
)

npu_binary_elementwise_op.activation = _create_npu_activation(
serial_binary_elementwise.activation
Expand Down
6 changes: 5 additions & 1 deletion src/relay/op/contrib/ethosu/binary_elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ Expr MakeEthosuBinaryElementwise(Expr ifm, Expr ifm2, Expr lut, String operator_
IndexExpr ifm_channels, IndexExpr ifm2_channels,
bool reversed_operands, String activation, int clip_min,
int clip_max, String rounding_mode, String ifm_layout,
String ifm2_layout, String ofm_layout, String ofm_dtype) {
String ifm2_layout, String ofm_layout, String ofm_dtype,
bool use_rescale, int rescale_scale, int rescale_shift) {
auto attrs = make_object<EthosuBinaryElementwiseAttrs>();

attrs->operator_type = std::move(operator_type);
Expand All @@ -113,6 +114,9 @@ Expr MakeEthosuBinaryElementwise(Expr ifm, Expr ifm2, Expr lut, String operator_
attrs->ifm2_layout = std::move(ifm2_layout);
attrs->ofm_layout = std::move(ofm_layout);
attrs->ofm_dtype = std::move(ofm_dtype);
attrs->use_rescale = use_rescale;
attrs->rescale_scale = rescale_scale;
attrs->rescale_shift = rescale_shift;

static const Op& op = Op::Get("contrib.ethosu.binary_elementwise");
return Call(op, {ifm, ifm2, lut}, Attrs(attrs), {});
Expand Down
10 changes: 10 additions & 0 deletions src/relay/op/contrib/ethosu/op_attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ struct EthosuBinaryElementwiseAttrs : public tvm::AttrsNode<EthosuBinaryElementw
String ifm2_layout;
String ofm_layout;
String ofm_dtype;
bool use_rescale;
int rescale_scale;
int rescale_shift;

TVM_DECLARE_ATTRS(EthosuBinaryElementwiseAttrs, "relay.attrs.EthosuBinaryElementwiseAttrs") {
TVM_ATTR_FIELD(operator_type)
Expand Down Expand Up @@ -125,6 +128,13 @@ struct EthosuBinaryElementwiseAttrs : public tvm::AttrsNode<EthosuBinaryElementw
" {int32}->{int8, uint8, int32}, any pairing"
"SHL:"
" {int32}->{int32} only");
TVM_ATTR_FIELD(use_rescale).describe("True if use explicit scaling.").set_default(false);
TVM_ATTR_FIELD(rescale_scale)
.describe(
"Scale value for rescale. "
"For 32-bit operations scale is not applied but shift is.")
.set_default(0);
TVM_ATTR_FIELD(rescale_shift).describe("Shift value for rescale.").set_default(0);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def test_ethosu_binary_elementwise_matcher(
ifm2_layout=ifm2_layout,
ofm_layout=ofm_layout,
ofm_dtype="int8",
use_rescale=False,
rescale_scale=0,
rescale_shift=0,
)
ifm_propagator = out.op.attrs["ifm_propagator"]
ifm2_propagator = out.op.attrs["ifm2_propagator"]
Expand Down
6 changes: 6 additions & 0 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,9 @@ def make_ethosu_binary_elementwise(
ifm2_layout="NHWC",
ofm_layout="NHWC",
rounding_mode="TFL",
use_rescale: bool = False,
rescale_scale: int = 0,
rescale_shift: int = 0,
):
ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
ifm=ifm,
Expand All @@ -714,6 +717,9 @@ def make_ethosu_binary_elementwise(
ifm_layout=ifm_layout,
ifm2_layout=ifm2_layout,
ofm_layout=ofm_layout,
use_rescale=use_rescale,
rescale_scale=rescale_scale,
rescale_shift=rescale_shift,
)
return ethosu_binary_elementwise

Expand Down
65 changes: 65 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,71 @@ def rounding_right_shift(lhs, rhs):
infra.compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)


@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
@pytest.mark.parametrize(
    "ifm_shape, ifm2_shape, scale, shift, dtype",
    [
        ([1, 1, 1, 16], [1, 1, 1, 16], 5, 2, "int8"),
        ([1, 2, 3, 1], [1, 1, 3, 1], 2, 1, "int8"),
        ([1, 5, 1, 8], [1, 1, 1, 8], 1, 2, "int32"),
    ],
)
def test_ethosu_rescale_mul_binary_elemwise(ifm_shape, ifm2_shape, scale, shift, accel_type, dtype):
    """Compare an NPU MUL that uses explicit rescale with a NumPy reference.

    The reference multiplies the two inputs as int32, applies ``scale``
    (skipped for int32 inputs), performs a rounding right shift by ``shift``
    and casts the result back to ``dtype``.
    """
    np.random.seed(0)

    def build_module():
        # Relay module holding a single ethosu binary elementwise MUL with
        # the rescale attributes enabled.
        ifm_var = relay.var("ifm", shape=ifm_shape, dtype=dtype)
        ifm2_var = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
        mul_call = infra.make_ethosu_binary_elementwise(
            ifm_var,
            ifm2_var,
            ifm_shape[3],
            ifm2_shape[3],
            "MUL",
            dtype,
            use_rescale=True,
            rescale_scale=scale,
            rescale_shift=shift,
        )
        return tvm.IRModule.from_expr(relay.Function([ifm_var, ifm2_var], mul_call))

    def reference_output(inputs):
        # Broadcast the second operand to the first operand's shape and
        # multiply in int32.
        first = inputs["ifm"].astype("int32")
        second = np.broadcast_to(inputs["ifm2"], ifm_shape).astype("int32")
        product = np.multiply(first, second)

        # For 32-bit operations scale is not applied but shift is
        scaled = product if dtype == "int32" else product * scale

        # Rounding right shift: add half of the divisor before shifting.
        rounding_offset = 1 << (shift - 1)
        return ((scaled + rounding_offset) >> shift).astype(dtype)

    cpu_mod = build_module()

    # Generate reference data. Dict values are evaluated in order, so "ifm"
    # is drawn from the seeded generator before "ifm2".
    input_data = {
        "ifm": np.random.randint(low=-10, high=15, size=ifm_shape, dtype=dtype),
        "ifm2": np.random.randint(low=1, high=5, size=ifm2_shape, dtype=dtype),
    }
    output_data = {"output": reference_output(input_data)}
    ethosu_mod = infra.create_ethosu_partition(cpu_mod)

    infra.compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("ifm_shape", [(3, 2), (1, 15, 11, 7), (3, 1, 12), (400,)])
@pytest.mark.parametrize("ifm_scale, ifm_zp, ofm_scale, ofm_zp", [(1, 0, 1, 0), (0.015, 3, 0.2, 5)])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def _visit(stmt):
),
rounding_mode=rounding_mode,
block_config=spec.SerialBlockConfig(0, 0, 0),
rescale_config=spec.SerialRescaleConfig(False, 0, 0),
)

assert data[0] == ["ethosu_binary_elementwise"] + list(serial_binary_elementwise)
Expand Down Expand Up @@ -335,6 +336,7 @@ def _visit(stmt):
),
rounding_mode=rounding_mode,
block_config=spec.SerialBlockConfig(0, 0, 0),
rescale_config=spec.SerialRescaleConfig(False, 0, 0),
)

assert data[0] == ["ethosu_binary_elementwise"] + list(serial_binary_elementwise)
Expand Down
Loading