Skip to content

Commit 5e281f8

Browse files
jainapurva authored and facebook-github-bot committed
Updates to use torchao's updated choose_qparams_affine and quantize/dequantize_affine (#11070)
Summary: Pull Request resolved: #11070. Updates to use torchao's updated choose_qparams_affine and quantize/dequantize_affine without the zero_point_domain arg. Differential Revision: D75228037
1 parent 6357580 commit 5e281f8

File tree

3 files changed

+6
-13
lines changed

3 files changed

+6
-13
lines changed

backends/xnnpack/utils/quant_utils.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,23 +52,22 @@ def is_dynamic_qdq(node: torch.fx.Node) -> bool:
5252
if not (is_quant(node) or is_dequant(node)):
5353
return False
5454

55-
# check scales and zp are dynamically chosen
55+
# check scales are dynamically chosen
5656
node_input_args = node.args
5757
if is_affine_qdq(node):
5858
node_input_args = extract_qdq_affine_op_args_for_decomposed_ops(node)
5959

6060
scale = node_input_args[1]
61-
zp = node_input_args[2]
62-
if not (isinstance(scale, torch.fx.Node) and isinstance(zp, torch.fx.Node)):
61+
62+
if not (isinstance(scale, torch.fx.Node)):
6363
return False
6464

65-
if not (scale.target == operator.getitem and zp.target == operator.getitem):
65+
if not (scale.target == operator.getitem):
6666
return False
6767

6868
scale_choose_qparam = scale.all_input_nodes[0]
69-
zp_choose_qparam = zp.all_input_nodes[0]
7069

71-
if not (is_qparam(scale_choose_qparam) and is_qparam(zp_choose_qparam)):
70+
if not (is_qparam(scale_choose_qparam)):
7271
return False
7372

7473
return True
@@ -222,9 +221,6 @@ def extract_qdq_affine_op_args_for_decomposed_ops(node: torch.fx.Node):
222221

223222
# add target_dtype_node after quant_min/quant_max
224223
args.append(target_dtype)
225-
# zero_point_domain
226-
if len(node.args) > 7 and node.args[7] != "INT":
227-
return None, None
228224

229225
if is_per_channel_group(node):
230226
block_sizes = cast(list[int], node.args[1])

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,6 @@ def embedding_byte_dtype_pattern(
10171017
torch.int8,
10181018
-128,
10191019
127,
1020-
"INT",
10211020
output_dtype,
10221021
)
10231022
return torch.ops.aten.embedding.default(dq, indices)
@@ -1062,7 +1061,6 @@ def embedding_2bit_dtype_pattern(
10621061
torch.int8,
10631062
-2,
10641063
1,
1065-
"INT",
10661064
output_dtype,
10671065
)
10681066
return torch.ops.aten.embedding.default(dq, indices)
@@ -1110,7 +1108,6 @@ def embedding_4bit_dtype_pattern(
11101108
torch.int8,
11111109
-8,
11121110
7,
1113-
"INT",
11141111
output_dtype,
11151112
)
11161113
return torch.ops.aten.embedding.default(dq, indices)

0 commit comments

Comments
 (0)