Skip to content

Commit 3a8752d

Browse files
vkuzopytorchmergebot
authored andcommitted
ns for fx: skip shadowing ops if copy subgraph is not implemented (pytorch#76663)
Summary: Pull Request resolved: pytorch#76663 Subgraph copy does not handle all edge cases. It's high eng time to handle them all, and currently an unhandled edge case crashes the script. This PR adds a function to check if the subgraph copy is supported, and skips shadowing if it is not supported. This way the model can still go through the shadowing APIs without an exception. Test Plan: ``` python test/test_quantization.py -k FXNumericSuite ``` Reviewed By: hx89 Differential Revision: D36069304 Pulled By: vkuzo fbshipit-source-id: 6b38b8d8e43396a4cf2373b247223a19d451d096 (cherry picked from commit e2322ca)
1 parent d3e3389 commit 3a8752d

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

test/quantization/fx/test_numeric_suite_fx.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,6 +1955,23 @@ def forward(self, x):
19551955
m, (torch.randn(1, 1, 4, 4),),
19561956
results_len=0)
19571957

1958+
def test_unsupported_op_copy_skips_shadowing(self):
1959+
"""
1960+
Copying a `call_function` node is not implemented, test that this
1961+
does not crash shadowing but instead skips the node.
1962+
"""
1963+
class M(nn.Module):
1964+
def forward(self, x):
1965+
# the second argument leads to attempting to copy a
1966+
# call_function node
1967+
x = F.layer_norm(x, x.shape[1:])
1968+
return x
1969+
1970+
m = M().eval()
1971+
self._test_match_shadow_activations(
1972+
m, (torch.randn(1, 1, 4, 4),),
1973+
results_len=0)
1974+
19581975

19591976
class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
19601977
"""

torch/ao/ns/fx/graph_passes.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,53 @@ def _copy_node_from_a_to_c(
360360
raise AssertionError(
361361
f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented")
362362

363+
def _can_insert_copy_of_subgraph_a(
364+
subgraph_a: NSSubgraph,
365+
gm_a: GraphModule,
366+
num_non_param_args_node_a: int,
367+
) -> bool:
368+
"""
369+
This function returns `False` if the input subgraph cannot be copied by
370+
`_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
371+
that there is a corner case logic for which copy is not yet implemented.
372+
"""
373+
# populate the list of nodes we need to check
374+
nodes = []
375+
cur_node = subgraph_a.end_node
376+
while cur_node != subgraph_a.start_node:
377+
nodes.append(cur_node)
378+
cur_node = cur_node.args[0] # type: ignore[assignment]
379+
nodes.append(cur_node)
380+
nodes.reverse()
381+
382+
def _can_insert(node_a_arg, gm_a):
383+
if isinstance(node_a_arg, Node):
384+
arg_a = return_first_non_observer_node(node_a_arg, gm_a)
385+
if arg_a.op == 'call_method':
386+
return arg_a.target in ('dequantize', 'to')
387+
elif arg_a.op == 'get_attr':
388+
return True
389+
else:
390+
return False
391+
elif isinstance(node_a_arg, (list, tuple)):
392+
for el in node_a_arg:
393+
if not isinstance(el, Node):
394+
return False
395+
return True
396+
397+
# For each node, check if we handle the copy behavior. This follows the
398+
# logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
399+
for node_a_arg in nodes[0].args[num_non_param_args_node_a:]:
400+
if not _can_insert(node_a_arg, gm_a):
401+
return False
402+
403+
for node in nodes[1:]:
404+
for node_a_arg in node.args[1:]:
405+
if not _can_insert(node_a_arg, gm_a):
406+
return False
407+
408+
return True
409+
363410
def _insert_copy_of_subgraph_a_after_input_node_c(
364411
input_node_c: Union[Node, List[Node]],
365412
input_node_c_2: Optional[Union[Node, List[Node]]],
@@ -663,6 +710,16 @@ def load_arg(a):
663710
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
664711
continue
665712

713+
num_non_param_args_node_a = \
714+
get_number_of_non_param_args(subgraph_a.start_node, gm_a)
715+
if not _can_insert_copy_of_subgraph_a(subgraph_a, gm_a, num_non_param_args_node_a):
716+
print(
717+
f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
718+
f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
719+
', unhandled logic in subgraph copy')
720+
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
721+
continue
722+
666723
fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
667724
fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)
668725

0 commit comments

Comments
 (0)