Skip to content

Commit

Permalink
test multiple requantize offload
Browse files Browse the repository at this point in the history
Change-Id: I60a3283461a7a7083c05289e84f570698388077b
  • Loading branch information
lhutton1 committed Jan 18, 2022
1 parent 677112f commit 663b1e4
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,5 +1559,49 @@ def verify(ext_func):
verify(mod["tvmgen_default_ethos_u_main_0"])


def test_multiple_requantize_offload():
"""
Testing requantize offload in the case one requauntize operation is part of
an existing pattern (in this case Mean: cast->mean->requantize) and the
other is a stand-alone requantize.
"""

def create_model():
ifm = relay.var("input", shape=(1, 3, 3, 4), dtype="int8")
cast = relay.cast(ifm, dtype="int32")
mean = relay.mean(cast, axis=1, keepdims=True)
requantize = relay.qnn.op.requantize(
mean,
input_scale=relay.const(1.0, dtype="float32"),
input_zero_point=relay.const(0, dtype="int32"),
output_scale=relay.const(1.0, dtype="float32"),
output_zero_point=relay.const(0, dtype="int32"),
)
requantize = relay.qnn.op.requantize(
requantize,
input_scale=relay.const(1.0, dtype="float32"),
input_zero_point=relay.const(0, dtype="int32"),
output_scale=relay.const(1.0, dtype="float32"),
output_zero_point=relay.const(0, dtype="int32"),
)
return tvm.IRModule.from_expr(relay.Function([ifm], requantize))

def verify(ext_func):
# If mean operation and separate requantize were offloaded correctly,
# there should only be a pooling operation followed by an identity
# operation leagalized.
op = ext_func.body
assert op.op.name == "contrib.ethosu.identity"
op = op.args[0]
assert ext_func.body.args[0].op.name == "contrib.ethosu.pooling"
op = op.args[0]
assert isinstance(op, relay.Var)

mod = create_model()
mod = ethosu.partition_for_ethosu(mod)
mod = legalize.LegalizeEthosU()(mod)
verify(mod["tvmgen_default_ethos_u_main_0"])


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 663b1e4

Please sign in to comment.