diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index b74a09c4fef2..c28e97b0e9d3 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -79,6 +79,27 @@ def check_quantized_softmax(extract): and dequantize_call.args[0].checked_type.dtype == "int8" ) + def mul_pattern(): + """Matcher for QNN multiplication""" + return is_op("qnn.mul")( + wildcard(), + wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + ) + + def check_quantized_mul(extract): + """Check if multiply is supported by CMSIS-NN.""" + return ( + extract.args[0].checked_type.dtype == "int8" + and extract.args[1].checked_type.dtype == "int8" + ) + return [ ("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax), + ("cmsisnn.quantized_mul", mul_pattern(), check_quantized_mul), ] diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 7c1728ce0ed5..bcb171ca25f8 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -32,17 +32,37 @@ namespace relay { namespace contrib { namespace cmsisnn { -class RelayToTIR : public MixedModeVisitor { +class RelayToTIRVisitor : public MixedModeVisitor { public: - explicit RelayToTIR(String func_name) : func_name_(func_name) {} + explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {} + + tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; } private: - void emit_softmax_tir(const Expr& expr) { + template + const T ArgumentToConstantValue(const Expr& arg) { + const ConstantNode* constant_node = arg.as(); + return static_cast(constant_node->data->data)[0]; + } + + void CreatePrimFuncForExtern(Array func_signature, + tvm::Array call_extern_args) { + Map dict_attrs; + dict_attrs.Set("global_symbol", func_name_); + dict_attrs.Set("tir.noalias", Bool(true)); + + tir::Stmt body = tir::Evaluate( + tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args)); + + primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map(), + DictAttrs(dict_attrs)); + } + + void EmitSoftMax(const Expr& expr) { auto* quantize_call = expr.as(); auto* softmax_call = quantize_call->args[0].as(); auto* dequant_call = softmax_call->args[0].as(); - auto* scale_const = dequant_call->args[1].as(); - const float quant_scale = static_cast(scale_const->data->data)[0]; + const float quant_scale = ArgumentToConstantValue(dequant_call->args[1]); // assuming layout as NHWC auto shape = quantize_call->type_as()->shape; @@ -79,15 +99,51 @@ class RelayToTIR : public MixedModeVisitor { IntImm(DataType::Int(32), num_rows), IntImm(DataType::Int(32), row_size), IntImm(DataType::Int(32), mult), IntImm(DataType::Int(32), shift), IntImm(DataType::Int(32), diff_min), out_var}; - tir::Stmt body = - tir::Evaluate(tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), args)); - Map dict_attrs; - dict_attrs.Set("global_symbol", func_name_); - dict_attrs.Set("tir.noalias", Bool(true)); + CreatePrimFuncForExtern(func_signature, args); + } - primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map(), - DictAttrs(dict_attrs)); + void EmitMul(const Expr& expr) { + auto* mul_call = expr.as(); + + const float input_0_scale = ArgumentToConstantValue(mul_call->args[2]); + const int32_t input_0_zero_point = ArgumentToConstantValue(mul_call->args[3]); + const float input_1_scale = ArgumentToConstantValue(mul_call->args[4]); + const int32_t input_1_zero_point = ArgumentToConstantValue(mul_call->args[5]); + const float output_scale = ArgumentToConstantValue(mul_call->args[6]); + const int32_t output_zero_point = ArgumentToConstantValue(mul_call->args[7]); + + double quantized_multiplier = static_cast(input_0_scale) * + static_cast(input_1_scale) / + static_cast(output_scale); + auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier); + int32_t output_multiplier = std::get<0>(mult_shift_pair); + int32_t output_shift = std::get<1>(mult_shift_pair); + + PrimExpr tensor_size = mul_call->type_as()->Size(); + + tir::Var input_0("input_0", DataType::Handle(8)); + tir::Var input_1("input_1", DataType::Handle(8)); + tir::Var output("output", DataType::Handle(8)); + + Array func_signature{input_0, input_1, output}; + + tvm::Array args = { + tir::StringImm("arm_elementwise_mul_s8"), + input_0, + input_1, + IntImm(DataType::Int(32), -input_0_zero_point), + IntImm(DataType::Int(32), -input_1_zero_point), + output, + IntImm(DataType::Int(32), output_zero_point), + IntImm(DataType::Int(32), output_multiplier), + IntImm(DataType::Int(32), output_shift), + IntImm(DataType::Int(32), std::numeric_limits::min()), + IntImm(DataType::Int(32), std::numeric_limits::max()), + tensor_size, + }; + + CreatePrimFuncForExtern(func_signature, args); } void VisitExpr_(const CallNode* call) final { @@ -98,7 +154,10 @@ class RelayToTIR : public MixedModeVisitor { auto comp_name = func->GetAttr(attr::kComposite); if (comp_name.defined() && comp_name == "cmsisnn.quantized_softmax") { - emit_softmax_tir(func->body); + EmitSoftMax(func->body); + } + if (comp_name.defined() && comp_name == "cmsisnn.quantized_mul") { + EmitMul(func->body); } } @@ -119,12 +178,12 @@ IRModule GenerateTIR(IRModule mod) { } // Prepare PrimFunc from Relay Function - auto relay_to_tir = RelayToTIR(func_name); + auto relay_to_tir = RelayToTIRVisitor(func_name); relay_to_tir.VisitExpr(func->body); // Build the TIR IRModule from the generated PrimFunc Map var_func_map; - var_func_map.Set(GlobalVar(func_name), relay_to_tir.primfunc_); + var_func_map.Set(GlobalVar(func_name), relay_to_tir.GetReplacementPrimFunc()); return IRModule(var_func_map); } diff --git a/tests/python/contrib/test_cmsisnn/test_mul.py b/tests/python/contrib/test_cmsisnn/test_mul.py new file mode 100644 index 000000000000..88fbeb2dfcfe --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_mul.py @@ -0,0 +1,154 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN integration tests: mul""" + +import sys + +import numpy as np +import pytest + +from tvm import relay +from tvm.relay.op.contrib import cmsisnn + +from utils import skip_if_no_reference_system, make_module, count_num_calls, get_range_for_dtype_str +from tests.python.relay.aot.aot_test_utils import ( + AOTTestModel, + AOT_CORSTONE300_RUNNER, + generate_ref_data, + compile_and_run, +) + + +def make_model( + shape, + input_0_dtype, + input_1_dtype, + input_0_scale, + input_0_zero_point, + input_1_scale, + input_1_zero_point, + out_scale=1.0 / 256, + out_zero_point=-128, +): + """Create a Relay Function / network model""" + + return relay.qnn.op.mul( + relay.var("input_0", shape=shape, dtype=input_0_dtype), + relay.var("input_1", shape=shape, dtype=input_1_dtype), + relay.const(input_0_scale, "float32"), + relay.const(input_0_zero_point, "int32"), + relay.const(input_1_scale, "float32"), + relay.const(input_1_zero_point, "int32"), + relay.const(out_scale, "float32"), + relay.const(out_zero_point, "int32"), + ) + + +@skip_if_no_reference_system +@pytest.mark.parametrize( + [ + "input_0_scale", + "input_0_zero_point", + "input_1_scale", + "input_1_zero_point", + "output_tolerance", + ], + [[0.256, 33, 0.256, 33, 0], [0.0128, -64, 0.0128, -64, 1], [0.0128, -64, 0.256, 33, 0]], +) +def test_mul_int8( + input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point, output_tolerance +): + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_CORSTONE300_RUNNER + + dtype = "int8" + shape = [1, 16, 16, 3] + model = make_model( + shape, dtype, dtype, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point + ) + orig_mod = make_module(model) + + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + + # validate pattern matching + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert any(attrs), "At least one function with external attributes was expected." + + compilers = [ + key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items() + ] + assert any(compilers), "Module does not contain function for cmsisnn target." + + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" + + # validate the output + in_min, in_max = get_range_for_dtype_str(dtype) + inputs = { + "input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype), + "input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype), + } + output_list = generate_ref_data(orig_mod["main"], inputs) + compile_and_run( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + output_tolerance=output_tolerance, + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + +@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]]) +def test_invalid_parameters( + input_dtype, +): + input_scale = 0.256 + input_zero_point = 33 + model = make_model( + [1, 16, 16, 3], + input_dtype, + input_dtype, + input_scale, + input_zero_point, + input_scale, + input_zero_point, + ) + + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert not any(attrs), "No function should have an external attribute." + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py index 1f6e0e711f0c..b14a15c60c8b 100644 --- a/tests/python/contrib/test_cmsisnn/test_networks.py +++ b/tests/python/contrib/test_cmsisnn/test_networks.py @@ -17,18 +17,16 @@ """CMSIS-NN: testing with networks""" -import platform import sys -import os -import pathlib -import tvm + +import numpy as np +import pytest + from tvm import relay from tvm.contrib.download import download_testdata from tvm.relay.op.contrib import cmsisnn -import numpy as np -import pytest -import itertools +from utils import skip_if_no_reference_system, get_range_for_dtype_str from tests.python.relay.aot.aot_test_utils import ( AOTTestModel, AOT_CORSTONE300_RUNNER, @@ -37,30 +35,6 @@ ) -def get_range_for_dtype_str(dtype): - """ - Produce the min,max for a give data type. - - Parameters - ---------- - dtype : str - a type string (e.g., int8) - - Returns - ------- - type_info.min : int - the minimum of the range - type_info.max : int - the maximum of the range - """ - - try: - type_info = np.iinfo(dtype) - except ValueError: - type_info = np.finfo(dtype) - return type_info.min, type_info.max - - def convert_to_relay( tflite_model_buf, input_data, @@ -99,9 +73,7 @@ def convert_to_list(x): return mod, params -@pytest.mark.skipif( - platform.machine() == "i686", reason="Reference system unavailable in i386 container" -) +@skip_if_no_reference_system def test_cnn_small(): # download the model base_url = "https://github.com/ARM-software/ML-zoo/raw/master/models/keyword_spotting/cnn_small/tflite_int8" diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py index c1951d1f2ce5..12e11c381c4f 100644 --- a/tests/python/contrib/test_cmsisnn/test_softmax.py +++ b/tests/python/contrib/test_cmsisnn/test_softmax.py @@ -17,17 +17,21 @@ """CMSIS-NN integration tests: softmax""" -import platform import sys -import os -import pathlib -import tvm -from tvm import relay -from tvm.relay.op.contrib import cmsisnn +import itertools + import numpy as np import pytest -import itertools +from tvm import relay +from tvm.relay.op.contrib import cmsisnn + +from utils import ( + skip_if_no_reference_system, + make_module, + count_num_calls, + get_range_for_dtype_str, +) from tests.python.relay.aot.aot_test_utils import ( AOTTestModel, AOT_CORSTONE300_RUNNER, @@ -36,61 +40,9 @@ ) -def get_range_for_dtype_str(dtype): - """ - Produce the min,max for a give data type. - - Parameters - ---------- - dtype : str - a type string (e.g., int8) - - Returns - ------- - type_info.min : int - the minimum of the range - type_info.max : int - the maximum of the range - """ - - try: - type_info = np.iinfo(dtype) - except ValueError: - type_info = np.finfo(dtype) - return type_info.min, type_info.max - - -def count_num_calls(mod): - """Count number of CallNode in the IRModule""" - - class CallCounter(relay.ExprVisitor): - def __init__(self): - super().__init__() - self.count = 0 - - def visit_call(self, call): - if isinstance(call.op, tvm.ir.Op): - self.count += 1 - - super().visit_call(call) - - counter = CallCounter() - for var in mod.get_global_vars(): - counter.visit(mod[var.name_hint]) - return counter.count - - -def make_module(func): - """Create IRModule from Function""" - func = relay.Function(relay.analysis.free_vars(func), func) - mod = tvm.IRModule.from_expr(func) - return relay.transform.InferType()(mod) - - def make_model( shape, in_dtype, out_dtype, in_zero_point, in_scale, out_zero_point=-128, out_scale=1.0 / 256 ): - """Create a Relay Function / network model""" a = relay.var("in0", shape=shape, dtype=in_dtype) dequantize = relay.qnn.op.dequantize( @@ -108,9 +60,7 @@ def make_model( return model -@pytest.mark.skipif( - platform.machine() == "i686", reason="Reference system unavailable in i386 container" -) +@skip_if_no_reference_system @pytest.mark.parametrize(["zero_point", "scale"], [[33, 0.256], [-64, 0.0128]]) def test_softmax_int8(zero_point, scale): interface_api = "c" diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py new file mode 100644 index 000000000000..3fd12efd3367 --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/utils.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN functions for testing networks""" + +import platform + +import numpy as np +import pytest + +import tvm +from tvm import relay + + +def skip_if_no_reference_system(func): + return pytest.mark.skipif( + platform.machine() == "i686", reason="Reference system unavailable in i386 container" + )(func) + + +def count_num_calls(mod): + """Count number of CallNode in the IRModule""" + + class CallCounter(relay.ExprVisitor): + def __init__(self): + super().__init__() + self.count = 0 + + def visit_call(self, call): + if isinstance(call.op, tvm.ir.Op): + self.count += 1 + + super().visit_call(call) + + counter = CallCounter() + for var in mod.get_global_vars(): + counter.visit(mod[var.name_hint]) + return counter.count + + +def get_range_for_dtype_str(dtype): + """ + Produce the min,max for a give data type. + + Parameters + ---------- + dtype : str + a type string (e.g., int8) + + Returns + ------- + type_info.min : int + the minimum of the range + type_info.max : int + the maximum of the range + """ + + try: + type_info = np.iinfo(dtype) + except ValueError: + type_info = np.finfo(dtype) + return type_info.min, type_info.max + + +def make_module(func): + """Create IRModule from Function""" + func = relay.Function(relay.analysis.free_vars(func), func) + mod = tvm.IRModule.from_expr(func) + return relay.transform.InferType()(mod)