12
12
from executorch .exir .backend .canonical_partitioners .config_partitioner import (
13
13
format_target_name ,
14
14
)
15
+ from torch .fx .experimental .symbolic_shapes import free_symbols , has_free_symbols
15
16
16
17
_Q_OPS = {
17
18
"quantize_per_tensor.tensor" ,
@@ -126,8 +127,8 @@ def is_affine_qdq(node: torch.fx.Node) -> bool:
126
127
def _get_block_size_input_scale (node : torch .fx .Node ):
127
128
assert is_affine_qdq (node )
128
129
block_size = node .args [1 ]
129
- input_val = node .all_input_nodes [0 ].meta ["val" ]
130
- scale_val = node .all_input_nodes [ 1 ] .meta ["val" ]
130
+ input_val = cast ( torch . fx . Node , node .args [0 ]) .meta ["val" ]
131
+ scale_val = cast ( torch . fx . Node , node .args [ 2 ]) .meta ["val" ]
131
132
return block_size , input_val , scale_val
132
133
133
134
@@ -145,7 +146,21 @@ def is_per_token(node: torch.fx.Node):
145
146
flag &= block_size [i ] == 1
146
147
scale_numel_expected *= input_val .shape [i ]
147
148
148
- flag &= block_size [- 1 ] == input_val .shape [- 1 ]
149
+ ic_block_size = block_size [- 1 ]
150
+ if isinstance (ic_block_size , torch .fx .Node ):
151
+ ic_block_size = ic_block_size .meta ["val" ]
152
+ assert free_symbols (
153
+ ic_block_size
154
+ ), f"block_size: { block_size } given, but { block_size [- 1 ]} is not a dynamic symint"
155
+
156
+ ic_dim = input_val .shape [- 1 ]
157
+ if isinstance (ic_dim , torch .fx .Node ):
158
+ ic_dim = ic_dim .meta ["val" ]
159
+ assert free_symbols (
160
+ ic_dim
161
+ ), f"input_shape: { input_val .shape } given, but { input_val .shape [- 1 ]} is not a dynamic symint"
162
+
163
+ flag &= ic_dim == ic_block_size
149
164
flag &= scale_val .numel () == scale_numel_expected
150
165
return flag
151
166
@@ -160,6 +175,11 @@ def is_per_channel_group(node: torch.fx.Node):
160
175
return True
161
176
elif is_affine_qdq (node ):
162
177
block_size , input_val , scale_val = _get_block_size_input_scale (node )
178
+ # per channel group is only valid on static weights
179
+ # so scales and weights can't have dynamic shape
180
+ if has_free_symbols (input_val .shape ) or has_free_symbols (scale_val .shape ):
181
+ return False
182
+
163
183
flag = True
164
184
flag &= len (block_size ) == 2
165
185
flag &= block_size [0 ] == 1
0 commit comments