Skip to content

Commit 7eed325

Browse files
asparkhiylc
authored and committed
[1/10] CMSIS-NN graph partitioner for softmax (apache#8653)
* cmsis graph partitioner for softmax Change-Id: I80ecd7bc5351f241b4674ef53b36e4398c8adb83 * Updated docstring in the partitioning function Change-Id: Ieb4b623e5929cfdb6aa0235db64c825fac8d7055
1 parent 0074ec9 commit 7eed325

File tree

2 files changed

+187
-0
lines changed

2 files changed

+187
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, unused-argument
18+
"""Arm(R) CMSIS-NN supported operators for Cortex-M."""
19+
import tvm.ir
20+
from tvm.relay import transform
21+
from tvm.relay.build_module import bind_params_by_name
22+
23+
from ...dataflow_pattern import is_constant, is_op, wildcard
24+
from .register import register_pattern_table
25+
26+
27+
def partition_for_cmsisnn(mod, params=None, **opts):
    """Annotate and partition a module for the "cmsisnn" external target.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters. When truthy, they are bound by name to
        ``mod["main"]`` before the passes run.
    **opts
        Accepted for call compatibility; not read by this function.

    Returns
    -------
    ret : Module
        The module produced by running the pass sequence on ``mod``.
    """
    if params:
        # Note: this rebinds "main" on the module object that was passed in.
        mod["main"] = bind_params_by_name(mod["main"], params)

    # Passes are applied in list order.
    passes = [
        transform.InferType(),
        transform.MergeComposite(pattern_table()),
        transform.AnnotateTarget("cmsisnn"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    pipeline = tvm.transform.Sequential(passes)
    return pipeline(mod)
57+
58+
59+
@register_pattern_table("cmsisnn")
def pattern_table():
    """Get the cmsisnn compiler pattern table.

    Returns
    -------
    table : list of tuple
        A single ``(name, pattern, check)`` entry for quantized softmax.
    """

    def softmax_pattern():
        """Build the qnn.dequantize -> nn.softmax -> qnn.quantize pattern."""
        dequantized = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
        softmaxed = is_op("nn.softmax")(dequantized)
        return is_op("qnn.quantize")(softmaxed, is_constant(), is_constant())

    def check_quantized_softmax(extract):
        """Check if softmax is supported by CMSIS-NN.

        Both the ``out_dtype`` attribute of ``extract`` and the dtype of the
        expression three ``args[0]`` hops below it must be "int8".
        """
        # presumably `extract` is the outer qnn.quantize call of the pattern
        # above -- verify against the MergeComposite contract.
        output_is_int8 = extract.attrs.out_dtype == "int8"
        if not output_is_int8:
            # Skip the input lookup when the output dtype already rules it out.
            return output_is_int8

        softmax_call = extract.args[0]
        dequantize_call = softmax_call.args[0]
        quantized_input = dequantize_call.args[0]
        return quantized_input.checked_type.dtype == "int8"

    table = [("cmsisnn.qnn_softmax", softmax_pattern(), check_quantized_softmax)]
    return table
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""CMSIS-NN integration tests: softmax"""
19+
20+
import pytest
21+
import sys
22+
23+
import tvm
24+
from tvm import relay
25+
from tvm.relay.op.contrib import cmsisnn
26+
import numpy as np
27+
28+
29+
def count_num_calls(mod):
    """Count the calls in ``mod`` whose ``op`` is a ``tvm.ir.Op``.

    Every global variable of the module is looked up by name hint and the
    function it refers to is visited; the counts are summed across all of them.
    """

    class _OpCallTally(relay.ExprVisitor):
        """Visitor that tallies call nodes whose ``op`` is a ``tvm.ir.Op``."""

        def __init__(self):
            super().__init__()
            self.num_calls = 0

        def visit_call(self, call):
            op_is_ir_op = isinstance(call.op, tvm.ir.Op)
            if op_is_ir_op:
                self.num_calls += 1

            # Hand the call to the base-class implementation afterwards.
            super().visit_call(call)

    tally = _OpCallTally()
    for global_var in mod.get_global_vars():
        tally.visit(mod[global_var.name_hint])
    return tally.num_calls
45+
46+
47+
def make_module(func):
    """Wrap an expression in a function and return it as a type-inferred module.

    The free variables of ``func`` become the parameters of the new
    ``relay.Function``; the resulting module is run through ``InferType``.
    """
    free_params = relay.analysis.free_vars(func)
    wrapped = relay.Function(free_params, func)
    module = tvm.IRModule.from_expr(wrapped)
    infer_type = relay.transform.InferType()
    return infer_type(module)
51+
52+
53+
def make_model(shape, zero_point, scale, in_dtype, out_dtype):
    """Build a dequantize -> softmax -> quantize expression.

    The same ``scale`` (as float32) and ``zero_point`` (as int32) are used on
    both the dequantize and the quantize side. The input variable is named
    "a" and has the given ``shape`` and ``in_dtype``; the final quantize
    emits ``out_dtype``.
    """
    quantized_input = relay.var("a", shape=shape, dtype=in_dtype)

    in_scale = relay.const(scale, "float32")
    in_zero_point = relay.const(zero_point, "int32")
    real_valued = relay.qnn.op.dequantize(
        quantized_input, input_scale=in_scale, input_zero_point=in_zero_point
    )

    probabilities = relay.nn.softmax(real_valued)

    # Fresh constant nodes for the output side, mirroring the input values.
    out_scale = relay.const(scale, "float32")
    out_zero_point = relay.const(zero_point, "int32")
    return relay.qnn.op.quantize(
        probabilities,
        output_scale=out_scale,
        output_zero_point=out_zero_point,
        out_dtype=out_dtype,
    )
68+
69+
70+
def test_softmax_int8():
    """An int8-in / int8-out softmax must be partitioned for cmsisnn.

    Checks that at least one function in the partitioned module carries
    attributes, that one of them has Compiler == "cmsisnn", and that the
    number of operator calls is the same before and after partitioning.
    """
    model = make_model([1, 16, 16, 3], 64, 0.02, "int8", "int8")
    orig_mod = make_module(model)
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

    # Gather the non-empty attribute sets of every global function.
    external_attrs = []
    for global_var in cmsisnn_mod.get_global_vars():
        func_attrs = cmsisnn_mod[global_var.name_hint].attrs
        if func_attrs:
            external_attrs.append(func_attrs)
    assert any(external_attrs), "At least one function with external attributes was expected."

    # One flag per attribute entry: does it name cmsisnn as the compiler?
    targets_cmsisnn = []
    for func_attrs in external_attrs:
        for key, value in func_attrs.items():
            targets_cmsisnn.append(key == "Compiler" and value == "cmsisnn")
    assert any(targets_cmsisnn), "Module does not contain function for cmsisnn target."

    calls_before = count_num_calls(orig_mod)
    calls_after = count_num_calls(cmsisnn_mod)
    assert calls_before == calls_after, "Number of calls changed during partitioning"
90+
91+
92+
@pytest.mark.parametrize("in_dtype,out_dtype", [["uint8", "int8"], ["int8", "uint8"]])
def test_softmax_not_int8(in_dtype, out_dtype):
    """A softmax whose input or output dtype is uint8 must not be partitioned.

    After partitioning, no global function in the module may carry a
    non-empty attribute set.
    """
    model = make_model([1, 16, 16, 3], 64, 0.02, in_dtype, out_dtype)
    orig_mod = make_module(model)
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

    # Gather the non-empty attribute sets of every global function.
    external_attrs = []
    for global_var in cmsisnn_mod.get_global_vars():
        func_attrs = cmsisnn_mod[global_var.name_hint].attrs
        if func_attrs:
            external_attrs.append(func_attrs)
    assert not any(external_attrs), "No function should have an external attribute."
104+
105+
106+
if __name__ == "__main__":
    # Run this file under pytest, forwarding any extra command-line arguments,
    # and use pytest's result as the process exit status.
    exit_code = pytest.main([__file__, *sys.argv[1:]])
    sys.exit(exit_code)

0 commit comments

Comments
 (0)