Skip to content

Commit afd78fc

Browse files
committed
Update on "Dtype selective build: fail if not xplat, if portable/optimized not in kernel_deps"
#10985 Try to make user error harder for dtype selective build. Emit warning for now, as too many failures when set to failure :( For example: ``` buck2 build //xplat/sgr/resources/tests/handwriting:pkg buck2 build fbsource//xplat/sgr/resources/mwa:main_pkg_libAndroid Differential Revision: [D75027794](https://our.internmc.facebook.com/intern/diff/D75027794/) [ghstack-poisoned]
2 parents 95df96b + 61c41e2 commit afd78fc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1567
-645
lines changed

backends/arm/operators/op_avg_pool2d.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
register_node_visitor,
1818
)
1919
from executorch.backends.arm.operators.operator_validation_utils import (
20+
adjust_pooling_pad_if_needed,
2021
validate_num_inputs,
2122
validate_same_dtype,
2223
)
@@ -63,6 +64,20 @@ def _build_generic_avgpool2d(
6364
except IndexError:
6465
pad_size_list = [0, 0, 0, 0]
6566

67+
# Adjust the padding as necessary
68+
pad_size_list[1] = adjust_pooling_pad_if_needed(
69+
input_tensor.shape[2],
70+
kernel_size_list[0],
71+
stride_size_list[0],
72+
pad_size_list[1],
73+
)
74+
pad_size_list[3] = adjust_pooling_pad_if_needed(
75+
input_tensor.shape[3],
76+
kernel_size_list[1],
77+
stride_size_list[1],
78+
pad_size_list[3],
79+
)
80+
6681
attr = ts.TosaSerializerAttribute()
6782
attr.PoolAttribute(
6883
kernel=kernel_size_list,
@@ -192,6 +207,20 @@ def _build_generic_avgpool2d(
192207
except IndexError:
193208
pad_size_list = [0, 0, 0, 0]
194209

210+
# Adjust the padding as necessary
211+
pad_size_list[1] = adjust_pooling_pad_if_needed(
212+
input_tensor.shape[2],
213+
kernel_size_list[0],
214+
stride_size_list[0],
215+
pad_size_list[1],
216+
)
217+
pad_size_list[3] = adjust_pooling_pad_if_needed(
218+
input_tensor.shape[3],
219+
kernel_size_list[1],
220+
stride_size_list[1],
221+
pad_size_list[3],
222+
)
223+
195224
attr = ts.TosaSerializerAttribute()
196225
attr.AvgPool2dAttribute(
197226
kernel=kernel_size_list,

backends/arm/operators/op_max_pool2d.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,14 @@
1717
register_node_visitor,
1818
)
1919
from executorch.backends.arm.operators.operator_validation_utils import (
20+
adjust_pooling_pad_if_needed,
2021
validate_num_inputs,
2122
validate_same_dtype,
2223
)
2324
from executorch.backends.arm.tosa_mapping import TosaArg
2425
from executorch.backends.arm.tosa_specification import TosaSpecification
2526

2627

27-
# Similarly to Conv2d, the TOSA spec requires that following is exactly divisible:
28-
# `(input + 2 * pad - kernel_size) / stride`
29-
# PyTorch however, does not require this, so as needed, we must adjust the padding.
30-
def adjust_pad_if_needed(
31-
input_size: int, kernel_size: int, stride: int, pad: int
32-
) -> int:
33-
if pad == 0:
34-
return pad
35-
36-
mod_remainder = (input_size + 2 * pad - kernel_size) % stride
37-
38-
# No need to adjust
39-
if mod_remainder == 0:
40-
return pad
41-
42-
return pad - mod_remainder
43-
44-
4528
@register_node_visitor
4629
class MaxPool2dVisitor_0_80(NodeVisitor):
4730
target = "aten.max_pool2d.default"
@@ -82,13 +65,13 @@ def define_node(
8265
pad_size_list = [0, 0, 0, 0]
8366

8467
# Adjust the padding as necessary
85-
pad_size_list[1] = adjust_pad_if_needed(
68+
pad_size_list[1] = adjust_pooling_pad_if_needed(
8669
input_tensor.shape[2],
8770
kernel_size[0],
8871
stride[0],
8972
pad_size_list[1],
9073
)
91-
pad_size_list[3] = adjust_pad_if_needed(
74+
pad_size_list[3] = adjust_pooling_pad_if_needed(
9275
input_tensor.shape[3],
9376
kernel_size[1],
9477
stride[1],
@@ -167,13 +150,13 @@ def define_node(
167150
pad_size_list = [0, 0, 0, 0]
168151

169152
# Adjust the padding as necessary
170-
pad_size_list[1] = adjust_pad_if_needed(
153+
pad_size_list[1] = adjust_pooling_pad_if_needed(
171154
input_tensor.shape[2],
172155
kernel_size[0],
173156
stride[0],
174157
pad_size_list[1],
175158
)
176-
pad_size_list[3] = adjust_pad_if_needed(
159+
pad_size_list[3] = adjust_pooling_pad_if_needed(
177160
input_tensor.shape[3],
178161
kernel_size[1],
179162
stride[1],

backends/arm/operators/operator_validation_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,40 @@ def validate_same_dtype(op_name: str, tensors: List[Any]):
9999
f"{op_name}: Expected all tensors to have dtype {reference_dtype}, but "
100100
f"found inconsistent dtype {tensor.dtype}."
101101
)
102+
103+
104+
def adjust_pooling_pad_if_needed(
105+
input_size: int, kernel_size: int, stride: int, pad: int
106+
) -> int:
107+
"""
108+
Calculates the padding that needs to be removed to a pooling window to make it
109+
divisible by the kernels stride. All inputs should correspond to the same dimension.
110+
111+
Parameters:
112+
-----------
113+
input_size : int
114+
The size of the input to the operator.
115+
116+
kernel_size : int
117+
The size of the kernel.
118+
119+
stride : int
120+
The size of the stride.
121+
122+
pad : int
123+
The amount of padding.
124+
125+
Output:
126+
-------
127+
An int, representing the padding to remove to make the window divisible.
128+
"""
129+
if pad == 0:
130+
return pad
131+
132+
mod_remainder = (input_size + 2 * pad - kernel_size) % stride
133+
134+
# No need to adjust
135+
if mod_remainder == 0:
136+
return pad
137+
138+
return pad - mod_remainder

backends/arm/quantizer/quantization_annotator.py

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
import torch.fx
13+
import torch.nn.functional as F
1314
from executorch.backends.arm.quantizer import QuantizationConfig
1415
from executorch.backends.arm.tosa_utils import get_node_debug_info
1516
from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
@@ -142,29 +143,33 @@ def _match_pattern(
142143
143144
Each 'pattern' element is composed of a list of disjunctive nodes types.
144145
"""
145-
assert len(pattern) == 2, "Only two-nodes patterns supported currently"
146-
147-
if node.target in pattern[0]:
148-
assert len(node.users) != 0
149-
parent = node
150-
child = next(iter(node.users))
151-
elif node.target in pattern[1]:
152-
assert len(node.args) != 0
153-
parent = node.args[0] # type: ignore[assignment]
154-
child = node
155-
else:
156-
return False
157-
158-
if len(parent.users) != 1:
159-
return False
160-
161-
if parent.target not in pattern[0] or child.target not in pattern[1]:
162-
return False
163-
146+
assert len(pattern) > 0, "No pattern provided"
164147
if filter_fn is not None:
165-
return filter_fn(parent) and filter_fn(child)
166-
167-
return True
148+
if not filter_fn(node):
149+
return False
150+
if len(pattern) == 1:
151+
# Base case where it has passed the filter_fn. Simply look if node.target is in pattern.
152+
return node.target in pattern[0]
153+
if node.target not in [op for sub_pattern in pattern for op in sub_pattern]:
154+
# node.target not in pattern. No need to look at the rest of the pattern.
155+
return False
156+
# Find the index of this node's target in pattern
157+
idx = [node.target in sub_pattern for sub_pattern in pattern].index(True)
158+
left_pattern = pattern[:idx]
159+
# Exclude idx as this contains node.target which we have already matched
160+
right_pattern = pattern[idx + 1 :]
161+
left_condition = True
162+
right_condition = True
163+
# Recursively look at the rest of the pattern by calling this function for
164+
# node's input and user node with updated patterns.
165+
if len(left_pattern) > 0:
166+
parent = node.all_input_nodes[0]
167+
if len(parent.users) != 1:
168+
return False
169+
left_condition = _match_pattern(parent, left_pattern, filter_fn)
170+
if len(right_pattern) > 0:
171+
right_condition = _match_pattern(list(node.users)[0], right_pattern, filter_fn)
172+
return left_condition and right_condition
168173

169174

170175
_one_to_one = [
@@ -274,6 +279,58 @@ def any_or_hardtanh_min_zero(n: Node):
274279
return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
275280

276281
if _match_pattern(
282+
node,
283+
[
284+
[
285+
torch.ops.aten.conv1d.default,
286+
torch.ops.aten.conv2d.default,
287+
torch.ops.aten.conv2d.padding,
288+
],
289+
[torch.ops.aten.batch_norm.default, F.batch_norm],
290+
[torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
291+
],
292+
filter_fn=any_or_hardtanh_min_zero,
293+
):
294+
if node.target in (
295+
torch.ops.aten.conv1d.default,
296+
torch.ops.aten.conv2d.default,
297+
torch.ops.aten.conv2d.padding,
298+
):
299+
quant_properties.quant_inputs = [
300+
_QuantProperty(0, input_act_qspec),
301+
_QuantProperty(1, weight_qspec, mark_annotated=True),
302+
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
303+
]
304+
elif node.target in (
305+
torch.ops.aten.relu.default,
306+
torch.ops.aten.hardtanh.default,
307+
):
308+
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
309+
310+
elif _match_pattern(
311+
node,
312+
[
313+
[
314+
torch.ops.aten.conv1d.default,
315+
torch.ops.aten.conv2d.default,
316+
torch.ops.aten.conv2d.padding,
317+
],
318+
[torch.ops.aten.batch_norm.default, F.batch_norm],
319+
],
320+
):
321+
if node.target in (
322+
torch.ops.aten.conv1d.default,
323+
torch.ops.aten.conv2d.default,
324+
torch.ops.aten.conv2d.padding,
325+
):
326+
quant_properties.quant_inputs = [
327+
_QuantProperty(0, input_act_qspec),
328+
_QuantProperty(1, weight_qspec, mark_annotated=True),
329+
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
330+
]
331+
elif node.target in [torch.ops.aten.batch_norm.default, F.batch_norm]:
332+
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
333+
elif _match_pattern(
277334
node,
278335
[
279336
[
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
import torch.nn.functional as F
10+
from executorch.backends.arm.quantizer.arm_quantizer import (
11+
get_symmetric_quantization_config,
12+
TOSAQuantizer,
13+
)
14+
from executorch.backends.arm.test import common, conftest
15+
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI
16+
17+
from executorch.backends.xnnpack.test.tester.tester import Quantize
18+
from torch import nn
19+
20+
21+
input_t1 = Tuple[torch.Tensor] # Input x
22+
23+
24+
class ConvModule(torch.nn.Module):
25+
input_shape = (1, 28, 28)
26+
batch_size = 64
27+
test_data: input_t1 = (torch.randn(batch_size, *input_shape),)
28+
29+
def __init__(self, batch_norm: bool = True) -> None:
30+
super().__init__()
31+
self.conv = torch.nn.Conv2d(1, 16, 3, stride=2)
32+
self.bn = nn.BatchNorm2d(num_features=16) if batch_norm else nn.Identity()
33+
34+
def forward(self, x: torch.Tensor):
35+
x = self.conv(x)
36+
x = self.bn(x)
37+
x = F.relu(x)
38+
39+
return x
40+
41+
42+
models = {
43+
"conv_bn_relu": ConvModule(batch_norm=True),
44+
"conv_relu": ConvModule(batch_norm=False),
45+
}
46+
47+
48+
@common.parametrize("model", models)
49+
def test_qat_tosa_BI(model: torch.nn.Module):
50+
pipeline = TosaPipelineBI[input_t1](model, model.test_data, [], [], qtol=1)
51+
tosa_version = conftest.get_option("tosa_version")
52+
tosa_profiles = {
53+
"0.80": common.TosaSpecification.create_from_string("TOSA-0.80+BI"),
54+
"1.0": common.TosaSpecification.create_from_string("TOSA-1.0+INT"),
55+
}
56+
tosa_spec = tosa_profiles[tosa_version]
57+
quantizer = TOSAQuantizer(tosa_spec)
58+
pipeline.change_args(
59+
"quantize",
60+
Quantize(
61+
quantizer=quantizer,
62+
quantization_config=get_symmetric_quantization_config(is_qat=True),
63+
is_qat=True,
64+
),
65+
)
66+
pipeline.run()

0 commit comments

Comments
 (0)