Skip to content

Commit

Permalink
[microNPU] Add support for requantize (apache#9910)
Browse files Browse the repository at this point in the history
* [microNPU] Add support for requantize

Adds support for stand-alone requantize operation which is legalized to
an identity operation on the NPU.

Change-Id: Ie2450c5fc72f405eddf517593236074aa4716c3b

* fix concatenate tests failing due to not being bit exact

Since requantize is now offloaded, concatenate tests were failing
due a reference not being used.

Change-Id: I44b26b5daecfefb776ca19e6646f3690f5570f52

* test multiple requantize offload

Change-Id: I60a3283461a7a7083c05289e84f570698388077b

* address comments

Change-Id: I7196a0fa468eb7c6a96f2b8a68f3a2dcf5a5693c
  • Loading branch information
lhutton1 authored and ylc committed Feb 16, 2022
1 parent 0c91640 commit 2a068d3
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 1 deletion.
44 changes: 44 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,49 @@ def __call__(self, *args, **kwargs):
pass


class RequantizeRewriter(DFPatternCallback):
"""Convert ethos-u.requantize composite function to an identity operation."""

def __init__(self):
super().__init__(require_type=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.RequantizeParams.composite_name})
)(wildcard())

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = ethosu_patterns.RequantizeParams(post.op.body)
params.ifm.tensor = post.args[0]

lut = relay.const([], "int8")

return ethosu_ops.ethosu_identity(
ifm=params.ifm.tensor,
lut=lut,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeRequantize:
"""This is the pass that wraps RequantizeRewriter."""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(RequantizeRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
Expand Down Expand Up @@ -1255,6 +1298,7 @@ def transform_module(
mod = LegalizeMean()(mod)
mod = LegalizeConcat()(mod)
mod = LegalizeSigmoid()(mod)
mod = LegalizeRequantize()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
Expand Down
59 changes: 59 additions & 0 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,60 @@ def split_pattern():
return split


class RequantizeParams:
"""
This class will parse a call to ethos-u.requantize composite function
and extract the parameter information.
"""

composite_name = "ethos-u.requantize"

def __init__(self, func_body: Call):
from tvm.relay.backend.contrib.ethosu.util import RequantArgs

layout = "NHWC"
in_var = func_body.args[0]
requantize = func_body

self.ifm = TensorParams(
in_var,
layout=layout,
scale=requantize.args[RequantArgs.IFM_SCALE.value],
zero_point=requantize.args[RequantArgs.IFM_ZERO_POINT.value],
)
self.ofm = TensorParams(
requantize,
layout=layout,
scale=requantize.args[RequantArgs.OFM_SCALE.value],
zero_point=requantize.args[RequantArgs.OFM_ZERO_POINT.value],
)

attrs = requantize.attrs
self.out_dtype = attrs.out_dtype

def is_valid(self) -> bool:
"""
Checks whether qnn.requantize has compatible attributes with HW.
"""
tensor_params = [self.ifm, self.ofm]
if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
return False
if not check_dimensions(self.ifm) or not check_dimensions(self.ofm):
return False
if self.out_dtype and self.out_dtype != "int8":
return False
return True


def requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for qnn.requantize.
"""
return is_op("qnn.requantize")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant()
)


@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
return [
Expand Down Expand Up @@ -1230,6 +1284,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
split_pattern(),
lambda pat: SplitParams(pat).is_valid(),
),
(
RequantizeParams.composite_name,
requantize_pattern(),
lambda pat: RequantizeParams(pat).is_valid(),
),
]


Expand Down
35 changes: 34 additions & 1 deletion tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,10 @@ def concat_func(*inputs):
op = tf.concat(list(inputs), axis)
return op

_compare_tvm_with_tflite(concat_func, shapes, accel_type)
# TODO(lhutton1) For now output is not bit exact with TFLite.
# This is because TFLite reference kernels are not being used.
# For this, TFLite will need upgrading to 2.6.
_compare_tvm_with_tflite(concat_func, shapes, accel_type, output_tolerance=1)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
Expand Down Expand Up @@ -987,5 +990,35 @@ def split_func(x):
_compare_tvm_with_tflite(split_func, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape,ifm_scale,ifm_zp,ofm_scale,ofm_zp",
[
[(1, 8, 8, 3), 1.0, 0, 1.0, 0],
[(1, 20, 30, 3), 1.345, 34, 0.32, -23],
],
)
def test_ethosu_requantize(accel_type, ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp):
dtype = "int8"

def create_model():
ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
requantize = relay.qnn.op.requantize(
ifm,
relay.const(ifm_scale, dtype="float32"),
relay.const(ifm_zp, dtype="int32"),
relay.const(ofm_scale, dtype="float32"),
relay.const(ofm_zp, dtype="int32"),
)
return tvm.IRModule.from_expr(relay.Function([ifm], requantize))

cpu_mod = create_model()
input_data = {"ifm": np.random.randint(-128, high=127, size=ifm_shape, dtype=dtype)}
output_data = generate_ref_data(cpu_mod, input_data)
ethosu_mod = partition_for_ethosu(cpu_mod)

_compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)


if __name__ == "__main__":
pytest.main([__file__])
100 changes: 100 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
pytest.importorskip("ethosu.vela")

import math

import numpy as np
import tensorflow as tf
import tflite.Model
Expand Down Expand Up @@ -1502,5 +1503,104 @@ def verify(ext_func):
verify(mod["tvmgen_default_ethos_u_main_0"])


@pytest.mark.parametrize(
"ifm_shape,ifm_scale,ifm_zp,ofm_scale,ofm_zp",
[[(1, 8, 8, 3), 1.0, 0, 1.0, 0], [(1, 20, 30, 3), 1.345, 34, 0.32, -23]],
)
def test_ethosu_requantize(ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp):
dtype = "int8"

def create_model():
ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
requantize = relay.qnn.op.requantize(
ifm,
relay.const(ifm_scale, dtype="float32"),
relay.const(ifm_zp, dtype="int32"),
relay.const(ofm_scale, dtype="float32"),
relay.const(ofm_zp, dtype="int32"),
)
return tvm.IRModule.from_expr(relay.Function([ifm], requantize))

def verify(ext_func):
op = ext_func.body

# Check IFM
ifm = op.args[0].checked_type
assert list(ifm.shape) == list(ifm_shape)
assert str(ifm.dtype) == dtype

# Check OFM
ofm = op.checked_type
assert list(ofm.shape) == list(ifm_shape)
assert str(ofm.dtype) == dtype

# Check quantization params
assert math.isclose(op.attrs.ifm_scale, ifm_scale, abs_tol=1e-7)
assert op.attrs.ifm_zero_point == ifm_zp
assert math.isclose(op.attrs.ofm_scale, ofm_scale, abs_tol=1e-7)
assert op.attrs.ofm_zero_point == ofm_zp

rewriter = legalize.RequantizeRewriter()
pattern_table = [
(
ethosu.RequantizeParams.composite_name,
ethosu.requantize_pattern(),
lambda pat: ethosu.RequantizeParams(pat).is_valid(),
),
]

mod = create_model()
mod = partition_ethosu_by_table(mod, pattern_table)

mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
rewriter, mod["tvmgen_default_ethos_u_main_0"]
)
verify(mod["tvmgen_default_ethos_u_main_0"])


def test_multiple_requantize_offload():
"""
Testing requantize offload in the case one requantize operation is part of
an existing pattern (in this case Mean: cast->mean->requantize) and the
other is a stand-alone requantize.
"""

def create_model():
ifm = relay.var("input", shape=(1, 3, 3, 4), dtype="int8")
cast = relay.cast(ifm, dtype="int32")
mean = relay.mean(cast, axis=1, keepdims=True)
requantize = relay.qnn.op.requantize(
mean,
input_scale=relay.const(1.0, dtype="float32"),
input_zero_point=relay.const(0, dtype="int32"),
output_scale=relay.const(1.0, dtype="float32"),
output_zero_point=relay.const(0, dtype="int32"),
)
requantize = relay.qnn.op.requantize(
requantize,
input_scale=relay.const(1.0, dtype="float32"),
input_zero_point=relay.const(0, dtype="int32"),
output_scale=relay.const(1.0, dtype="float32"),
output_zero_point=relay.const(0, dtype="int32"),
)
return tvm.IRModule.from_expr(relay.Function([ifm], requantize))

def verify(ext_func):
# If mean operation and separate requantize were offloaded correctly,
# there should only be a pooling operation followed by an identity
# operation leagalized.
op = ext_func.body
assert op.op.name == "contrib.ethosu.identity"
op = op.args[0]
assert ext_func.body.args[0].op.name == "contrib.ethosu.pooling"
op = op.args[0]
assert isinstance(op, relay.Var)

mod = create_model()
mod = ethosu.partition_for_ethosu(mod)
mod = legalize.LegalizeEthosU()(mod)
verify(mod["tvmgen_default_ethos_u_main_0"])


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 2a068d3

Please sign in to comment.