Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CMSIS-NN] Initial operator support for Mul #9163

Merged
merged 1 commit into from
Oct 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,27 @@ def check_quantized_softmax(extract):
and dequantize_call.args[0].checked_type.dtype == "int8"
)

def mul_pattern():
    """Build the dataflow pattern that matches a QNN multiplication.

    The first two arguments of ``qnn.mul`` may be any expression; the six
    arguments that follow must all be constants.
    """
    tensor_operands = [wildcard(), wildcard()]
    constant_operands = [is_constant() for _ in range(6)]
    return is_op("qnn.mul")(*tensor_operands, *constant_operands)

def check_quantized_mul(extract):
    """Check if multiply is supported by CMSIS-NN.

    Parameters
    ----------
    extract : object
        The matched ``qnn.mul`` call. ``extract.args[0]`` and
        ``extract.args[1]`` are the two tensor operands and must each expose
        a ``checked_type`` with ``dtype`` and ``shape``.

    Returns
    -------
    bool
        True only when both operands are int8 and have identical, fully
        static shapes.
    """
    lhs_type = extract.args[0].checked_type
    rhs_type = extract.args[1].checked_type
    if lhs_type.dtype != "int8" or rhs_type.dtype != "int8":
        return False

    def static_shape(tensor_type):
        """Return the shape as a tuple of ints, or None if a dim is not a concrete integer."""
        try:
            return tuple(int(dim) for dim in tensor_type.shape)
        except (TypeError, ValueError):
            return None

    # The operands are lowered to a flat element-wise kernel call that takes a
    # single element count for both inputs, so a multiply that relies on
    # broadcasting (operand shapes differ) must not be offloaded.
    lhs_shape = static_shape(lhs_type)
    return lhs_shape is not None and lhs_shape == static_shape(rhs_type)

return [
("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
("cmsisnn.quantized_mul", mul_pattern(), check_quantized_mul),
]
89 changes: 74 additions & 15 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,37 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

class RelayToTIR : public MixedModeVisitor {
class RelayToTIRVisitor : public MixedModeVisitor {
public:
explicit RelayToTIR(String func_name) : func_name_(func_name) {}
explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {}

// Returns the PrimFunc currently held in primfunc_.
tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; }

private:
void emit_softmax_tir(const Expr& expr) {
template <typename T>
const T ArgumentToConstantValue(const Expr& arg) {
Mousius marked this conversation as resolved.
Show resolved Hide resolved
const ConstantNode* constant_node = arg.as<ConstantNode>();
return static_cast<const T*>(constant_node->data->data)[0];
}

void CreatePrimFuncForExtern(Array<tir::Var> func_signature,
Mousius marked this conversation as resolved.
Show resolved Hide resolved
tvm::Array<PrimExpr> call_extern_args) {
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", func_name_);
dict_attrs.Set("tir.noalias", Bool(true));

tir::Stmt body = tir::Evaluate(
tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args));

primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
}

void EmitSoftMax(const Expr& expr) {
auto* quantize_call = expr.as<CallNode>();
auto* softmax_call = quantize_call->args[0].as<CallNode>();
auto* dequant_call = softmax_call->args[0].as<CallNode>();
auto* scale_const = dequant_call->args[1].as<ConstantNode>();
const float quant_scale = static_cast<const float*>(scale_const->data->data)[0];
const float quant_scale = ArgumentToConstantValue<float>(dequant_call->args[1]);
Mousius marked this conversation as resolved.
Show resolved Hide resolved

// assuming layout as NHWC
auto shape = quantize_call->type_as<TensorTypeNode>()->shape;
Expand Down Expand Up @@ -79,15 +99,51 @@ class RelayToTIR : public MixedModeVisitor {
IntImm(DataType::Int(32), num_rows), IntImm(DataType::Int(32), row_size),
IntImm(DataType::Int(32), mult), IntImm(DataType::Int(32), shift),
IntImm(DataType::Int(32), diff_min), out_var};
tir::Stmt body =
tir::Evaluate(tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), args));

Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", func_name_);
dict_attrs.Set("tir.noalias", Bool(true));
CreatePrimFuncForExtern(func_signature, args);
}

primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
void EmitMul(const Expr& expr) {
auto* mul_call = expr.as<CallNode>();

const float input_0_scale = ArgumentToConstantValue<float>(mul_call->args[2]);
const int32_t input_0_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[3]);
const float input_1_scale = ArgumentToConstantValue<float>(mul_call->args[4]);
const int32_t input_1_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[5]);
const float output_scale = ArgumentToConstantValue<float>(mul_call->args[6]);
const int32_t output_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[7]);

double quantized_multiplier = static_cast<double>(input_0_scale) *
static_cast<double>(input_1_scale) /
static_cast<double>(output_scale);
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
int32_t output_multiplier = std::get<0>(mult_shift_pair);
int32_t output_shift = std::get<1>(mult_shift_pair);

PrimExpr tensor_size = mul_call->type_as<TensorTypeNode>()->Size();

tir::Var input_0("input_0", DataType::Handle(8));
Mousius marked this conversation as resolved.
Show resolved Hide resolved
tir::Var input_1("input_1", DataType::Handle(8));
tir::Var output("output", DataType::Handle(8));

Array<tir::Var> func_signature{input_0, input_1, output};

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_elementwise_mul_s8"),
input_0,
input_1,
IntImm(DataType::Int(32), -input_0_zero_point),
IntImm(DataType::Int(32), -input_1_zero_point),
output,
IntImm(DataType::Int(32), output_zero_point),
IntImm(DataType::Int(32), output_multiplier),
IntImm(DataType::Int(32), output_shift),
IntImm(DataType::Int(32), std::numeric_limits<int8_t>::min()),
IntImm(DataType::Int(32), std::numeric_limits<int8_t>::max()),
tensor_size,
};

CreatePrimFuncForExtern(func_signature, args);
}

void VisitExpr_(const CallNode* call) final {
Expand All @@ -98,7 +154,10 @@ class RelayToTIR : public MixedModeVisitor {

auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined() && comp_name == "cmsisnn.quantized_softmax") {
emit_softmax_tir(func->body);
EmitSoftMax(func->body);
}
if (comp_name.defined() && comp_name == "cmsisnn.quantized_mul") {
EmitMul(func->body);
}
}

Expand All @@ -119,12 +178,12 @@ IRModule GenerateTIR(IRModule mod) {
}

// Prepare PrimFunc from Relay Function
auto relay_to_tir = RelayToTIR(func_name);
auto relay_to_tir = RelayToTIRVisitor(func_name);
relay_to_tir.VisitExpr(func->body);

// Build the TIR IRModule from the generated PrimFunc
Map<GlobalVar, BaseFunc> var_func_map;
var_func_map.Set(GlobalVar(func_name), relay_to_tir.primfunc_);
var_func_map.Set(GlobalVar(func_name), relay_to_tir.GetReplacementPrimFunc());
return IRModule(var_func_map);
}

Expand Down
154 changes: 154 additions & 0 deletions tests/python/contrib/test_cmsisnn/test_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""CMSIS-NN integration tests: mul"""

import sys

import numpy as np
import pytest

from tvm import relay
from tvm.relay.op.contrib import cmsisnn

from utils import skip_if_no_reference_system, make_module, count_num_calls, get_range_for_dtype_str
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
generate_ref_data,
compile_and_run,
)


def make_model(
    shape,
    input_0_dtype,
    input_1_dtype,
    input_0_scale,
    input_0_zero_point,
    input_1_scale,
    input_1_zero_point,
    out_scale=1.0 / 256,
    out_zero_point=-128,
):
    """Build a Relay expression holding a single ``qnn.mul`` of two input variables.

    Both variables ("input_0" and "input_1") share ``shape``; scales are
    wrapped as float32 constants and zero points as int32 constants.
    """
    lhs = relay.var("input_0", shape=shape, dtype=input_0_dtype)
    rhs = relay.var("input_1", shape=shape, dtype=input_1_dtype)

    def as_scale(value):
        return relay.const(value, "float32")

    def as_zero_point(value):
        return relay.const(value, "int32")

    return relay.qnn.op.mul(
        lhs,
        rhs,
        as_scale(input_0_scale),
        as_zero_point(input_0_zero_point),
        as_scale(input_1_scale),
        as_zero_point(input_1_zero_point),
        as_scale(out_scale),
        as_zero_point(out_zero_point),
    )


@skip_if_no_reference_system
@pytest.mark.parametrize(
    [
        "input_0_scale",
        "input_0_zero_point",
        "input_1_scale",
        "input_1_zero_point",
        "output_tolerance",
    ],
    [[0.256, 33, 0.256, 33, 0], [0.0128, -64, 0.0128, -64, 1], [0.0128, -64, 0.256, 33, 0]],
)
def test_mul_int8(
    input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point, output_tolerance
):
    """Partition an int8 qnn.mul for CMSIS-NN, then compare its output with the reference."""
    dtype = "int8"
    shape = [1, 16, 16, 3]

    orig_mod = make_module(
        make_model(
            shape, dtype, dtype, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point
        )
    )
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

    # validate pattern matching: collect the non-empty attrs of every global function
    attrs = []
    for var in cmsisnn_mod.get_global_vars():
        func_attrs = cmsisnn_mod[var.name_hint].attrs
        if func_attrs:
            attrs.append(func_attrs)
    assert any(attrs), "At least one function with external attributes was expected."

    compilers = []
    for attr in attrs:
        for key, value in attr.items():
            compilers.append(key == "Compiler" and value == "cmsisnn")
    assert any(compilers), "Module does not contain function for cmsisnn target."

    calls_before = count_num_calls(orig_mod)
    calls_after = count_num_calls(cmsisnn_mod)
    assert calls_before == calls_after, "Number of calls changed during partitioning"

    # validate the output against reference data produced from the unpartitioned module
    in_min, in_max = get_range_for_dtype_str(dtype)
    inputs = {
        name: np.random.randint(in_min, high=in_max, size=shape, dtype=dtype)
        for name in ("input_0", "input_1")
    }
    output_list = generate_ref_data(orig_mod["main"], inputs)

    test_model = AOTTestModel(
        module=cmsisnn_mod,
        inputs=inputs,
        outputs=output_list,
        output_tolerance=output_tolerance,
    )
    # Positional arguments after the model: runner, interface API, use_unpacked_api.
    compile_and_run(test_model, AOT_CORSTONE300_RUNNER, "c", True)


@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]])
def test_invalid_parameters(
    input_dtype,
):
    """A qnn.mul whose inputs are not int8 must not be partitioned for CMSIS-NN."""
    scale, zero_point = 0.256, 33
    model = make_model(
        [1, 16, 16, 3],
        input_dtype,
        input_dtype,
        scale,
        zero_point,
        scale,
        zero_point,
    )

    orig_mod = make_module(model)
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

    # Gather the non-empty attrs of every global function in the partitioned module.
    attrs = []
    for var in cmsisnn_mod.get_global_vars():
        func_attrs = cmsisnn_mod[var.name_hint].attrs
        if func_attrs:
            attrs.append(func_attrs)
    assert not any(attrs), "No function should have an external attribute."


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
40 changes: 6 additions & 34 deletions tests/python/contrib/test_cmsisnn/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@

"""CMSIS-NN: testing with networks"""

import platform
import sys
import os
import pathlib
import tvm

import numpy as np
import pytest

from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.relay.op.contrib import cmsisnn
import numpy as np
import pytest
import itertools

from utils import skip_if_no_reference_system, get_range_for_dtype_str
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
Expand All @@ -37,30 +35,6 @@
)


def get_range_for_dtype_str(dtype):
    """
    Produce the (min, max) pair for a given data type.

    Parameters
    ----------
    dtype : str
        a type string (e.g., int8)

    Returns
    -------
    type_info.min : int
        the minimum of the range
    type_info.max : int
        the maximum of the range
    """
    try:
        limits = np.iinfo(dtype)
    except ValueError:
        # np.iinfo rejected the dtype; use the floating-point limits instead.
        limits = np.finfo(dtype)

    lower, upper = limits.min, limits.max
    return lower, upper


def convert_to_relay(
tflite_model_buf,
input_data,
Expand Down Expand Up @@ -99,9 +73,7 @@ def convert_to_list(x):
return mod, params


@pytest.mark.skipif(
platform.machine() == "i686", reason="Reference system unavailable in i386 container"
)
@skip_if_no_reference_system
def test_cnn_small():
# download the model
base_url = "https://github.com/ARM-software/ML-zoo/raw/master/models/keyword_spotting/cnn_small/tflite_int8"
Expand Down
Loading