Skip to content

Commit 05277dd

Browse files
authored
[ExecuTorch][XNNPACK] validate quant params before lowering
There have been some invalid scales/zero points (zp) sneaking into our serialized quant tensors. This has been annoying because they seemingly pass the export stage. Adding some asserts to help identify the node, the buggy scale/zp value, and the index into that tensor if applicable. Differential Revision: D70775365
1 parent 2155284 commit 05277dd

File tree

9 files changed

+374
-181
lines changed

9 files changed

+374
-181
lines changed

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
op_ceiling,
1616
op_clamp,
1717
op_conv2d,
18-
op_dequantize_per_tensor,
1918
op_div,
2019
op_dynamic_dequantize_ops,
2120
op_dynamic_quantize_ops,
@@ -35,7 +34,7 @@
3534
op_negate,
3635
op_permute,
3736
op_prelu,
38-
op_quantize_per_tensor,
37+
op_quant_dequant,
3938
op_relu,
4039
op_rsqrt,
4140
op_sdpa,

backends/xnnpack/operators/op_dequantize_per_tensor.py

Lines changed: 0 additions & 70 deletions
This file was deleted.
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict
8+
9+
import torch
10+
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
11+
TagImplicitQDqPass,
12+
)
13+
from executorch.backends.xnnpack.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.xnnpack.operators.quant_params import QuantParams
18+
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
19+
XNNConvert,
20+
XNNGraph,
21+
XNode,
22+
)
23+
from executorch.backends.xnnpack.utils.quant_utils import (
24+
is_per_channel_group,
25+
validate_quant_scales,
26+
validate_quant_zeropoints,
27+
)
28+
from executorch.backends.xnnpack.utils.utils import get_input_node, get_param_tensor
29+
30+
31+
class OpStaticQDQNode(NodeVisitor):
    """
    Shared base class for the static quantize/dequantize node visitors.

    Its ``define_node`` serializes nothing; it only validates the scales and
    zero points attached to the node. Subclasses that serialize a node call
    ``super().define_node(...)`` first so an invalid value is reported
    together with the node it came from.
    """

    def check_scales_zeropoints(self, node: torch.fx.Node) -> None:
        """
        Validate the scales and zero points of a q/dq node.

        Scales are read from ``node.args[1]`` and zero points from
        ``node.args[2]``. If either is itself a graph node, it is resolved
        with ``get_param_tensor`` against ``self.exported_program`` before
        being validated.

        Args:
            node: the quantize/dequantize node to check.

        Raises:
            ValueError: if ``validate_quant_scales`` or
                ``validate_quant_zeropoints`` rejects the values. The message
                names ``node``; the validator's error is chained as the cause.
        """
        scales = node.args[1]
        zero_points = node.args[2]
        is_groupwise = is_per_channel_group(node)

        # dtype is read from the last positional argument, or from the third
        # to last one for per-channel-group ops.
        # NOTE(review): this assumes every group-wise op has exactly two
        # positional arguments after dtype - confirm this also holds for
        # quantize_per_channel_group, not only dequantize_per_channel_group.
        dtype = node.args[-3] if is_groupwise else node.args[-1]

        if isinstance(scales, torch.fx.Node):
            scales = get_param_tensor(self.exported_program, scales)

        if isinstance(zero_points, torch.fx.Node):
            zero_points = get_param_tensor(self.exported_program, zero_points)

        try:
            validate_quant_scales(scales)
            validate_quant_zeropoints(zero_points, dtype, is_groupwise)
        except ValueError as e:
            # Chain explicitly so the validator's traceback is kept as the
            # direct cause of the node-level error.
            raise ValueError(
                f"Invalid quantization scale or zero point for {node}: {e}"
            ) from e

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """
        Validate the node's quantization parameters; nothing is serialized.

        ``xnn_graph``, ``vals_to_ids`` and ``debug_handle`` are accepted for
        signature compatibility and are not used here.

        Raises:
            ValueError: propagated from ``check_scales_zeropoints``.
        """
        # check scales and zp are valid
        self.check_scales_zeropoints(node)
63+
64+
65+
@register_node_visitor
class OpDeQuantizePerTensor(OpStaticQDQNode):
    """
    Node visitor for dequantize_per_tensor
    """

    target = "quantized_decomposed.dequantize_per_tensor.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """
        Serialize an XNNConvert for this dq node unless it is tagged as an
        implicit dq node, in which case it just reuses its input's id.
        """
        # Scales and zero points are validated before anything is defined.
        super().define_node(node, xnn_graph, vals_to_ids, debug_handle)

        source = get_input_node(node, 0)

        if TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node):
            # An ignored node shares the id of its parent, when the parent
            # already has one.
            if source in vals_to_ids:
                vals_to_ids[node] = vals_to_ids[source]
            return

        source_quant_params = QuantParams.from_q_dq_node(node)

        # The dequantized (fp32) result is defined without quant params.
        self.define_tensor(node, xnn_graph, vals_to_ids)
        converted_id = vals_to_ids[node]

        # The quantized (qint8) source carries this node's quant params.
        source_quant_params.is_output = False
        self.define_tensor(
            source, xnn_graph, vals_to_ids, quant_params=source_quant_params
        )
        source_id = vals_to_ids[source]

        convert = XNNConvert(input_id=source_id, output_id=converted_id, flags=0)
        xnn_graph.xnodes.append(
            XNode(xnode_union=convert, debug_handle=debug_handle)
        )
113+
114+
115+
@register_node_visitor
class OpQuantizePerTensor(OpStaticQDQNode):
    """
    Node visitor for quantize_per_tensor
    """

    target = "quantized_decomposed.quantize_per_tensor.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """
        Serialize an XNNConvert for this q node unless it is tagged as an
        implicit q node, in which case it just reuses its input's id.
        """
        # Scales and zero points are validated before anything is defined.
        super().define_node(node, xnn_graph, vals_to_ids, debug_handle)

        source = get_input_node(node, 0)

        if TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node):
            # An ignored node shares the id of its parent, when the parent
            # already has one.
            if source in vals_to_ids:
                vals_to_ids[node] = vals_to_ids[source]
            return

        result_quant_params = QuantParams.from_q_dq_node(node)

        # The fp32 source is defined without quant params.
        self.define_tensor(source, xnn_graph, vals_to_ids)
        source_id = vals_to_ids[source]

        # The quantized (qint8) result carries this node's quant params.
        result_quant_params.q_input = node
        result_quant_params.is_input = False
        self.define_tensor(
            node, xnn_graph, vals_to_ids, quant_params=result_quant_params
        )
        result_id = vals_to_ids[node]

        convert = XNNConvert(input_id=source_id, output_id=result_id, flags=0)
        xnn_graph.xnodes.append(
            XNode(xnode_union=convert, debug_handle=debug_handle)
        )
163+
164+
165+
@register_node_visitor
class OpDequantizePerChannelDefault(OpStaticQDQNode):
    """
    Visitor for dequantize_per_channel.default: defines no define_node of its
    own, so only the scale/zero-point check inherited from OpStaticQDQNode runs
    """

    target = "quantized_decomposed.dequantize_per_channel.default"
172+
173+
174+
@register_node_visitor
class OpQuantizePerChannelDefault(OpStaticQDQNode):
    """
    Visitor for quantize_per_channel.default: defines no define_node of its
    own, so only the scale/zero-point check inherited from OpStaticQDQNode runs
    """

    target = "quantized_decomposed.quantize_per_channel.default"
181+
182+
183+
@register_node_visitor
class OpQuantizePerChannelGroupDefault(OpStaticQDQNode):
    """
    Visitor for quantize_per_channel_group.default: defines no define_node of
    its own, so only the scale/zero-point check inherited from OpStaticQDQNode
    runs
    """

    target = "quantized_decomposed.quantize_per_channel_group.default"
190+
191+
192+
@register_node_visitor
class OpDequantizePerChannelGroupDefault(OpStaticQDQNode):
    """
    Visitor for dequantize_per_channel_group.default: defines no define_node
    of its own, so only the scale/zero-point check inherited from
    OpStaticQDQNode runs
    """

    target = "quantized_decomposed.dequantize_per_channel_group.default"

backends/xnnpack/operators/op_quantize_per_tensor.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

0 commit comments

Comments
 (0)