Commit 109f5b0

sergio-grovety and arina-grovety authored and committed
[ETHOSU][MicroNPU][Pass] Add a pass to replicate pads (apache#14909)

Added a pass to handle the situation when an nn.pad operator has more than one qnn.conv2d consumer:

       pad
      /   \
 Conv2D   Conv2D

In this case, because of the peculiarities of pattern parsing, conv2d does not get into the composite for the NPU. Therefore, pads are replicated so that each has only one consumer.

---------

Co-authored-by: Sergey Smirnov <89378719+sergey-grovety@users.noreply.github.com>
Co-authored-by: Arina <117634809+arina-grovety@users.noreply.github.com>
Co-authored-by: arina.naumova <naumova@grovety.com>
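For illustration, a minimal sketch of the graph shape the new pass targets. This example is not part of the commit; the shapes and placeholder quantization parameters are assumptions chosen to make the snippet self-contained:

import numpy as np
import tvm
from tvm import relay

# One nn.pad shared by two qnn.conv2d consumers -- the pattern the pass rewrites.
data = relay.var("data", shape=(1, 56, 56, 3), dtype="int8")
weight = relay.const(np.zeros((3, 3, 3, 3), dtype="int8"))
pad = relay.nn.pad(data, pad_width=[(0, 0), (1, 1), (1, 1), (0, 0)])

def qconv(inp):
    # Placeholder quantization parameters; a real graph carries frontend-derived values.
    return relay.qnn.op.conv2d(
        inp,
        weight,
        input_zero_point=relay.const(0, "int32"),
        kernel_zero_point=relay.const(0, "int32"),
        input_scale=relay.const(1.0, "float32"),
        kernel_scale=relay.const(1.0, "float32"),
        kernel_size=(3, 3),
        channels=3,
        data_layout="NHWC",
        kernel_layout="HWIO",
    )

# Both convolutions consume the same pad, so neither can be folded into an NPU composite.
mod = tvm.IRModule.from_expr(relay.add(qconv(pad), qconv(pad)))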
1 parent c1071b9 commit 109f5b0

File tree: 4 files changed, +285 -3 lines changed


python/tvm/relay/backend/contrib/ethosu/codegen.py

Lines changed: 88 additions & 1 deletion
@@ -32,7 +32,8 @@
 )
 from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
 from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator, util, vela_api
-from tvm.relay.expr_functor import ExprMutator, ExprVisitor
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor, Call
+from tvm.relay import expr as _expr

 # pylint: disable=unused-import
 from tvm.relay.backend.contrib.ethosu.op import op_attrs
@@ -357,6 +358,92 @@ def __call__(self, *args, **kwargs):
         pass


+class PadsWithMultipleConsumersReplicator(ExprMutator):
+    """A pass to handle the situation when an nn.pad operator has
+    more than one qnn.conv2d consumer.
+
+           pad
+          /   \
+     Conv2D   Conv2D
+
+    In this case, because of the peculiarities of pattern parsing,
+    conv2d does not get into the composite for the NPU.
+    Therefore, pads are replicated so that each has only one consumer.
+    """
+
+    def __init__(self):
+        super().__init__()
+        # A set recording hashes of pads which already have one qnn.conv2d consumer
+        self.hashes = set()
+
+    def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
+        if (
+            isinstance(call.op, tvm.ir.Op)
+            and isinstance(call.args[0], Call)
+            and isinstance(call.args[0].op, tvm.ir.Op)
+            and call.op == relay.op.get("qnn.conv2d")
+            and call.args[0].op == relay.op.get("nn.pad")
+        ):
+            if tvm.ir.structural_hash(call.args[0]) not in self.hashes:
+                # Record the hash of an nn.pad seen for the first time
+                self.hashes.add(tvm.ir.structural_hash(call.args[0]))
+            else:
+                # If this pad already has a conv2d consumer, duplicate the pad
+                # and make the copy the input of the current conv2d
+                used_pad = self.visit(call.args[0])
+                used_pad_args = [self.visit(arg) for arg in used_pad.args]
+                new_pad = Call(
+                    used_pad.op, used_pad_args, used_pad.attrs, used_pad.type_args, used_pad.span
+                )
+                new_conv2d_args = []
+                for i, arg in enumerate(call.args):
+                    if i == 0:
+                        new_conv2d_args.append(self.visit(new_pad))
+                    else:
+                        new_conv2d_args.append(self.visit(arg))
+                new_conv2d_op = self.visit(call.op)
+                expr__ = _expr.CallWithFields(
+                    call,
+                    new_conv2d_op,
+                    new_conv2d_args,
+                    call.attrs,
+                    call.type_args,
+                    None,
+                    call.span,
+                )
+                return expr__
+
+        new_args = [self.visit(arg) for arg in call.args]
+        new_op = self.visit(call.op)
+        expr__ = _expr.CallWithFields(
+            call, new_op, new_args, call.attrs, call.type_args, None, call.span
+        )
+        return expr__
+
+
+def replicate_pads(mod):
+    """Traverses the Relay graph and replicates nn.pad operators that have
+    multiple qnn.conv2d consumers. This removes the situation where
+    e.g. pad+conv2d matches qnn_conv2d_pattern but cannot be grouped
+    because several conv2d operators use the same pad operation.
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        The IRModule that gets generated from a relay frontend.
+
+    Returns
+    -------
+    tvm.ir.IRModule
+        The IRModule without nn.pad operators with multiple consumers.
+    """
+    replicator = PadsWithMultipleConsumersReplicator()
+    for global_var, func in mod.functions.items():
+        func = replicator.visit(func)
+        mod.update_func(global_var, func)
+    return mod
+
+
 def IdentityOptimizer():  # pylint: disable=invalid-name
     """Pass that removes redundant identities
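To see the effect of the pass in isolation, a quick check (a sketch, not from the commit, reusing the shared-pad module built in the earlier example; count_pads is a hypothetical helper):

from tvm import relay
from tvm.relay.backend.contrib.ethosu import codegen

def count_pads(mod):
    # Hypothetical helper: count distinct nn.pad call nodes in main.
    pads = []

    def fvisit(node):
        if isinstance(node, relay.Call) and node.op == relay.op.get("nn.pad"):
            pads.append(node)

    relay.analysis.post_order_visit(mod["main"].body, fvisit)
    return len(pads)

print(count_pads(mod))  # 1: a single pad shared by both convolutions
mod = codegen.replicate_pads(mod)
print(count_pads(mod))  # expected 2: one structurally identical pad per qnn.conv2d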

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 3 additions & 1 deletion
@@ -2341,13 +2341,15 @@ def partition_for_ethosu(
     mod : IRModule
         The partitioned IRModule with external global functions
     """
-    from tvm.relay.backend.contrib.ethosu import preprocess
+    from tvm.relay.backend.contrib.ethosu import preprocess, codegen

     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)

     pattern = relay.op.contrib.get_pattern_table("ethos-u")
     mod = relay.transform.InferType()(mod)
+    mod = codegen.replicate_pads(mod)
+    mod = relay.transform.InferType()(mod)
     mod = relay.transform.MergeComposite(pattern)(mod)
     mod = relay.transform.AnnotateTarget("ethos-u")(mod)
     mod = relay.transform.MergeCompilerRegions()(mod)
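In normal use the pass is not invoked directly: it now runs inside partition_for_ethosu, after the first InferType and before MergeComposite, so the user-facing flow is unchanged. A sketch, assuming mod and params come from a Relay frontend such as relay.frontend.from_tflite:

from tvm.relay.op.contrib.ethosu import partition_for_ethosu

# Shared pads are replicated internally before composite matching.
mod = partition_for_ethosu(mod, params)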

tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 63 additions & 0 deletions
@@ -157,6 +157,69 @@ def conv2d_double(x):
     infra.compare_tvm_with_tflite(conv2d_double, [ifm_shape], accel_type)


+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize(
+    "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")]
+)
+def test_tflite_shared_pad(
+    accel_type,
+    op_pairs,
+):
+    np.random.seed(0)
+
+    ifm_shape = (1, 55, 32, 3)
+    kernel_shape = (3, 3)
+    strides = (3, 2)
+    dilation = (1, 1)
+    activation_function = "RELU"
+    op_padding = "SAME"
+    sep_padding = (0, 0, 1, 1)
+
+    @tf.function
+    def tf_function(x):
+        def make_depthwise_or_conv2d(pair_idx, x):
+            # The input strides to the TensorFlow API need to be of shape 1x4
+            tf_strides = [1, strides[0], strides[1], 1]
+            if op_pairs[pair_idx] == "depthwise":
+                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                op = tf.nn.depthwise_conv2d(
+                    x, weight, strides=tf_strides, padding=op_padding, dilations=dilation
+                )
+            else:
+                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                op = tf.nn.conv2d(
+                    x,
+                    weight,
+                    strides=tf_strides,
+                    padding=op_padding,
+                    dilations=dilation,
+                )
+            if activation_function == "RELU":
+                op = tf.nn.relu(op)
+            return op
+
+        x = tf.pad(
+            x,
+            [
+                [0, 0],
+                [sep_padding[0], sep_padding[2]],
+                [sep_padding[1], sep_padding[3]],
+                [0, 0],
+            ],
+            "CONSTANT",
+        )
+
+        x1 = make_depthwise_or_conv2d(0, x)
+        x2 = make_depthwise_or_conv2d(1, x)
+
+        x3 = tf.math.add(x1, x2)
+        return x3
+
+    infra.compare_tvm_with_tflite(tf_function, [ifm_shape], accel_type)
+
+
 @pytest.mark.parametrize("weight_min, weight_max", [(0.0, 1e-11), (-1e10, 1e10)])
 def test_out_of_range_scaling(weight_min, weight_max):
     np.random.seed(0)

tests/python/contrib/test_ethosu/test_legalize.py

Lines changed: 131 additions & 1 deletion
@@ -31,7 +31,7 @@
 from tvm.relay.backend.contrib.ethosu import legalize, preprocess
 from tvm.relay import dataflow_pattern
 from tvm.relay.op.contrib import ethosu
-from tvm.relay.backend.contrib.ethosu import util
+from tvm.relay.backend.contrib.ethosu import util, codegen
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.frontend.tflite import get_pad_value
 from tvm.relay.expr_functor import ExprVisitor
@@ -44,6 +44,8 @@ def partition_ethosu_by_table(mod, pattern_table):
     want to add the operator's pattern to the pattern table so that the compiler
     wouldn't attempt to offload an operator without full stack support."""
     mod = relay.transform.InferType()(mod)
+    mod = codegen.replicate_pads(mod)
+    mod = relay.transform.InferType()(mod)
     mod = relay.transform.MergeComposite(pattern_table)(mod)
     mod = relay.transform.AnnotateTarget("ethos-u")(mod)
     mod = relay.transform.MergeCompilerRegions()(mod)
@@ -3676,5 +3678,133 @@ def _visit(stmt):
     verify(mod["tvmgen_default_ethos_u_main_0"])


+@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3)])
+@pytest.mark.parametrize("kernel_shape", [(3, 3)])
+@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))])
+@pytest.mark.parametrize("op_padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("sep_padding", [(0, 0, 1, 1), (7, 5, 4, 5)])
+@pytest.mark.parametrize(
+    "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")]
+)
+def test_tflite_shared_pad_legalize(
+    ifm_shape,
+    kernel_shape,
+    strides,
+    dilation,
+    op_padding,
+    sep_padding,
+    op_pairs,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                def make_depthwise_or_conv2d(pair_idx):
+                    if op_pairs[pair_idx] == "depthwise":
+                        weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+                        weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                        return tf.nn.depthwise_conv2d(
+                            x, weight, strides=tf_strides, padding=op_padding, dilations=dilation
+                        )
+                    weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
+                    weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                    return tf.nn.conv2d(
+                        x,
+                        weight,
+                        strides=tf_strides,
+                        padding=op_padding,
+                        dilations=dilation,
+                    )
+
+                x = tf.pad(
+                    x,
+                    [
+                        [0, 0],
+                        [sep_padding[0], sep_padding[2]],
+                        [sep_padding[1], sep_padding[3]],
+                        [0, 0],
+                    ],
+                    "CONSTANT",
+                )
+
+                # The input strides to the TensorFlow API need to be of shape 1x4
+                tf_strides = [1, strides[0], strides[1], 1]
+
+                x1 = make_depthwise_or_conv2d(0)
+                x2 = make_depthwise_or_conv2d(1)
+
+                x3 = tf.math.add(x1, x2)
+                return x3
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]

+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    conv2d_pattern_table = [
+        (
+            ethosu.QnnConv2DParams.composite_name,
+            ethosu.qnn_conv2d_pattern(),
+            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+        ),
+        (
+            ethosu.QnnDepthwiseConv2DParams.composite_name,
+            ethosu.qnn_depthwise_conv2d_pattern(),
+            lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        [legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
+        mod["tvmgen_default_ethos_u_main_0"],
+    )
+    mod["tvmgen_default_ethos_u_main_1"] = dataflow_pattern.rewrite(
+        [legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
+        mod["tvmgen_default_ethos_u_main_1"],
+    )
+
+    if op_pairs[0] == "depthwise":
+        assert (
+            mod["tvmgen_default_ethos_u_main_0"].body.op.name == "contrib.ethosu.depthwise_conv2d"
+        )
+    else:
+        assert mod["tvmgen_default_ethos_u_main_0"].body.op.name == "contrib.ethosu.conv2d"
+
+    if op_pairs[1] == "depthwise":
+        assert (
+            mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.depthwise_conv2d"
+        )
+    else:
+        assert mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.conv2d"
+
+
 if __name__ == "__main__":
     tvm.testing.main()
