Skip to content

Commit eb44899

Browse files
mcr229 authored and facebook-github-bot committed
per_channel_group can't be dynamic
Summary: There are some dynamism issues that arise when checking the semantics of quantize_affine nodes. We avoid them by accounting for free_symbols. Differential Revision: D72488540
1 parent 2d56897 commit eb44899

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

backends/xnnpack/utils/quant_utils.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
1313
format_target_name,
1414
)
15+
from torch.fx.experimental.symbolic_shapes import free_symbols, has_free_symbols
1516

1617
_Q_OPS = {
1718
"quantize_per_tensor.tensor",
@@ -126,8 +127,8 @@ def is_affine_qdq(node: torch.fx.Node) -> bool:
def _get_block_size_input_scale(node: torch.fx.Node):
    """Pull (block_size, input fake-val, scale fake-val) out of an affine q/dq node.

    Assumes the quantize_affine / dequantize_affine argument layout:
    args[0] is the input node, args[1] is the block_size list, and
    args[2] is the scale node. Reads the FakeTensor values from each
    node's ``meta["val"]`` rather than relying on ``all_input_nodes``
    ordering, which can differ from positional-arg order.
    """
    assert is_affine_qdq(node)
    block_size = node.args[1]
    input_node = cast(torch.fx.Node, node.args[0])
    scale_node = cast(torch.fx.Node, node.args[2])
    return block_size, input_node.meta["val"], scale_node.meta["val"]
132133

133134

@@ -145,7 +146,21 @@ def is_per_token(node: torch.fx.Node):
145146
flag &= block_size[i] == 1
146147
scale_numel_expected *= input_val.shape[i]
147148

148-
flag &= block_size[-1] == input_val.shape[-1]
149+
ic_block_size = block_size[-1]
150+
if isinstance(ic_block_size, torch.fx.Node):
151+
ic_block_size = ic_block_size.meta["val"]
152+
assert free_symbols(
153+
ic_block_size
154+
), f"block_size: {block_size} given, but {block_size[-1]} is not a dynamic symint"
155+
156+
ic_dim = input_val.shape[-1]
157+
if isinstance(ic_dim, torch.fx.Node):
158+
ic_dim = ic_dim.meta["val"]
159+
assert free_symbols(
160+
ic_dim
161+
), f"input_shape: {input_val.shape} given, but {input_val.shape[-1]} is not a dynamic symint"
162+
163+
flag &= ic_dim == ic_block_size
149164
flag &= scale_val.numel() == scale_numel_expected
150165
return flag
151166

@@ -160,6 +175,11 @@ def is_per_channel_group(node: torch.fx.Node):
160175
return True
161176
elif is_affine_qdq(node):
162177
block_size, input_val, scale_val = _get_block_size_input_scale(node)
178+
# per channel group is only valid on static weights
179+
# so scales and weights can't have dynamic shape
180+
if has_free_symbols(input_val.shape) or has_free_symbols(scale_val.shape):
181+
return False
182+
163183
flag = True
164184
flag &= len(block_size) == 2
165185
flag &= block_size[0] == 1

0 commit comments

Comments
 (0)