Skip to content

Commit 0f68cd2

Browse files
committed
NXP backend: Use zero_point to pad quantized average_pool.
1 parent 8aea9fb commit 0f68cd2

File tree

4 files changed

+73
-5
lines changed

4 files changed

+73
-5
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/avg_pool_2d_converter.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import numpy as np
7+
68
from executorch.backends.nxp.backend.ir.converter.conversion import (
79
aten_translator,
810
common,
911
)
1012
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
13+
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
14+
tf_lite_type_to_numpy,
15+
)
1116
from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter
1217
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
1318
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
@@ -57,9 +62,20 @@ def _convert_2d_avg_pool(
5762
)
5863

5964
if explicit_padding is not None:
60-
# Need to prepend a 'Pad' operator, which adds 0s. But these will be included in the computation!
65+
# Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). But these will
66+
# be included in the computation!
67+
input_quantization = t_op.tmp_inputs[0].quantization
68+
pad_value = (
69+
None
70+
if input_quantization is None
71+
else np.array(input_quantization.zero_point[0]).astype(
72+
tf_lite_type_to_numpy(t_op.tmp_inputs[0].type)
73+
)
74+
)
6175
ops.add_pre(
62-
self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
76+
self.builder.create_pad_operator_before(
77+
t_op, 0, explicit_padding, pad_value
78+
)
6379
)
6480

6581
return ops.flatten()

backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
import numpy as np
77
import torch
8-
from torch.fx import Node
9-
from torch.nn import Parameter
108

119
from executorch.backends.nxp.backend.edge_helper import (
1210
input_tensor,
@@ -42,6 +40,8 @@
4240
conv_2d_options,
4341
depthwise_conv_2d_options,
4442
)
43+
from torch.fx import Node
44+
from torch.nn import Parameter
4545

4646

4747
class ConvolutionConverter(NodeConverter):

backends/nxp/tests/executorch_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]):
4848

4949
def to_quantized_edge_program(
5050
model: torch.nn.Module,
51-
input_shapes: tuple[int] | list[tuple[int]],
51+
input_shapes: tuple[int, ...] | list[tuple[int, ...]],
5252
operators_not_to_delegate: list[str] = None,
5353
target="imxrt700",
5454
neutron_converter_flavor="SDK_25_03",

backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
from executorch.backends.nxp.backend.edge_program_converter import (
1111
EdgeProgramToIRConverter,
1212
)
13+
from executorch.backends.nxp.backend.ir.converter.builder.model_builder import (
14+
ModelBuilder,
15+
)
16+
from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import (
17+
BuiltinOperator,
18+
)
1319
from executorch.backends.nxp.tests.executorch_pipeline import (
1420
to_edge_program,
1521
to_quantized_edge_program,
@@ -156,3 +162,49 @@ def test_avg_pool_2d_quant_conversion(mocker, input_shape, padding, count_includ
156162
tflite_output_preprocess=ToNCHWPreprocess(),
157163
input_data=input_data,
158164
)
165+
166+
167+
def test_avg_pool_2d_quant_conversion__padded(mocker):
168+
input_shape = (1, 8, 8, 8)
169+
model = AvgPool2dModule(True, 1)
170+
171+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
172+
ops_spy = mocker.spy(ModelBuilder, "finish")
173+
174+
# Run conversion
175+
_ = to_quantized_edge_program(model, input_shape)
176+
177+
# Capture the converter operators.
178+
ops = ops_spy.spy_return.sub_graphs[0].operators.vector
179+
180+
# Capture generated model
181+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
182+
183+
# Capture converted program
184+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
185+
186+
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
187+
188+
convert_run_compare(
189+
exported_program,
190+
tflite_input_preprocess=ToNHWCPreprocess(),
191+
tfl_model=tflite_flatbuffers_model,
192+
tflite_output_preprocess=ToNCHWPreprocess(),
193+
input_data=input_data,
194+
)
195+
196+
assert len(ops) == 2
197+
assert ops[0].builtin_options.operator_type == BuiltinOperator.PADV2
198+
assert ops[1].builtin_options.operator_type == BuiltinOperator.AVERAGE_POOL_2D
199+
200+
# Make sure the padding used the `zero-point`.
201+
pad_value = ops[0].tmp_inputs[2].tmp_buffer.data.item()
202+
assert (
203+
pad_value == ops[0].tmp_inputs[0].quantization.zero_point[0]
204+
) # `Pad` input zp.
205+
assert (
206+
pad_value == ops[0].tmp_outputs[0].quantization.zero_point[0]
207+
) # `Pad` output zp.
208+
assert (
209+
pad_value == ops[1].tmp_inputs[0].quantization.zero_point[0]
210+
) # `AvgPool` input zp.

0 commit comments

Comments
 (0)