
Commit d3e3389

vkuzo authored and pytorchmergebot committed
ns for fx: skip shadowing for torch.cat, and also for nodes with only kwargs (pytorch#76561)
Summary:

Pull Request resolved: pytorch#76561

User model had syntax like `torch.cat(tensors=[x])`. This PR fixes two errors to unbreak this in the NS shadow model:

1. skip nodes which only have kwargs (instead of throwing an exception)
2. explicitly skip shadowing of `torch.cat` (since it is not supported anyway)

Test Plan:

```
python test/test_quantization.py -k test_op_with_only_kwargs_skips_shadowing
python test/test_quantization.py -k test_op_mul_add_cat_skips_shadowing
```

Reviewed By: hx89

Differential Revision: D36017356

Pulled By: vkuzo

fbshipit-source-id: 0da4840a62c2dac183f8294c2cec4fce262474b3

(cherry picked from commit 88409c1)
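For context, a minimal sketch (not part of the PR) of why kwargs-only calls need special handling: `torch.fx` stores keyword arguments separately from positional ones, so a call like `torch.cat(tensors=[x])` traces to a node whose `args` tuple is empty.

```python
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.cat(tensors=[x])

gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == 'call_function':
        # Expected output: args=() kwargs={'tensors': [x]}
        # Any pass that indexes node.args[0] would raise IndexError here,
        # which is the crash this PR guards against.
        print(f'args={node.args} kwargs={node.kwargs}')
```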
1 parent 73b33de commit d3e3389

3 files changed: +39 −44 lines changed

test/quantization/fx/test_numeric_suite_fx.py

Lines changed: 17 additions & 30 deletions
```diff
@@ -1171,33 +1171,6 @@ def forward(self, x):
             # that shadowing models with method calls does not crash.
             results_len=0)
 
-    @skipIfNoFBGEMM
-    def test_add_shadow_loggers_multiple_dtype_casts(self):
-        """
-        Verifies that for nodes where the first input arg is a list,
-        such as `cat`, we insert an individual dtype cast for each
-        arg of the list.
-        """
-        class M(nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                x = torch.cat([x, x, x], dim=0)
-                return x
-
-        m = M().eval()
-        expected_occurrence = {
-            # 3 dequantize function calls from the 3 dtype casts for [x, x, x]
-            ns.call_module(torch.nn.Identity): 3,
-            # 1 dequantize method call for module output
-            ns.call_method("dequantize"): 1,
-        }
-        self._test_match_shadow_activations(
-            m, (torch.randn(4, 4),),
-            prepared_expected_node_occurrence=expected_occurrence,
-            results_len=1, compare_fp32_vs_fp32_prepared=False)
-
     @skipIfNoFBGEMM
     def test_shadow_activations_fqn(self):
         m = nn.Sequential(
@@ -1237,7 +1210,7 @@ def forward(self, x):
         m = M().eval()
         self._test_match_shadow_activations(
             m, (torch.randn(1, 1, 4, 4),),
-            results_len=2,
+            results_len=1,
             should_log_inputs=True)
 
     @skipIfNoFBGEMM
@@ -1954,13 +1927,27 @@ def test_fp16_shadows_fp32(self):
         mq = convert_fx(mp, is_reference=True)
         mq_shadows_m = add_shadow_loggers('a', mq, 'b', m, OutputLogger)
 
-    def test_mul_add_skips_shadowing(self):
+    def test_mul_add_cat_stack_skips_shadowing(self):
         class M(nn.Module):
             def forward(self, x):
                 x = x * x
                 x = torch.mul(x, x)
                 x = x + x
                 x = torch.add(x, x)
+                x = torch.cat([x])
+                x = torch.stack([x])
+                return x
+
+        m = M().eval()
+        self._test_match_shadow_activations(
+            m, (torch.randn(1, 1, 4, 4),),
+            results_len=0)
+
+    def test_op_with_only_kwargs_skips_shadowing(self):
+        class M(nn.Module):
+            def forward(self, x):
+                x = torch.cat(tensors=[x])
+                x = torch.stack(tensors=[x])
                 return x
 
         m = M().eval()
@@ -2101,7 +2088,7 @@ def test_sparsenn_shadow(self):
         x = torch.randn(2, 4)
         self._test_match_shadow_activations(
             sparse_nn, (idx, offsets, x),
-            results_len=4,
+            results_len=3,
             should_log_inputs=should_log_inputs)
 
     @skip_if_no_torchvision
```

torch/ao/ns/fx/graph_passes.py

Lines changed: 20 additions & 12 deletions
```diff
@@ -602,6 +602,26 @@ def load_arg(a):
             subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                 end_node_b_to_matched_subgraph_a_and_name[node_b]
 
+            if len(node_b.args) == 0:
+                print(
+                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
+                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
+                    ', kwargs-only node not handled yet')
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                continue
+
+            all_op_types_support_shadowing = (
+                op_type_supports_shadowing(subgraph_a.start_node) and
+                op_type_supports_shadowing(node_b)
+            )
+            if not all_op_types_support_shadowing:
+                print(
+                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
+                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
+                    ', unsupported')
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                continue
+
             # For both start_node and end_node verify that we know how to do
             # the dtype cast. If we do not, skip.
             node_input_type_a, node_output_type_a = \
@@ -626,18 +646,6 @@ def load_arg(a):
                 env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                 continue
 
-            all_op_types_support_shadowing = (
-                op_type_supports_shadowing(subgraph_a.start_node) and
-                op_type_supports_shadowing(node_b)
-            )
-            if not all_op_types_support_shadowing:
-                print(
-                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
-                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
-                    ', unsupported')
-                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
-                continue
-
             # If we are shadowing from fp32 to int8, we need to insert
             # quantize_per_tensor call with qparams from the previous node.
             # Only do this if we are able to infer these qparams from the graph.
```
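Note that the guards run before the dtype-cast inference, so `node_b.args` is never indexed for a kwargs-only node; a skipped node is instead copied into the output graph untouched. Below is a standalone sketch (not the pass itself) of the `env`/`node_copy` pass-through pattern the skip branches fall back to, applied to a whole toy graph:

```python
import torch
import torch.fx

def f(x):
    return torch.cat(tensors=[x])

gm = torch.fx.symbolic_trace(f)

# Copy every node through unchanged -- the same env/node_copy fallback
# the shadow pass uses when it decides to skip a node.
new_graph = torch.fx.Graph()
env = {}
for node in gm.graph.nodes:
    env[node.name] = new_graph.node_copy(node, lambda n: env[n.name])
new_gm = torch.fx.GraphModule(gm, new_graph)

x = torch.randn(4)
assert torch.equal(gm(x), new_gm(x))  # behavior is preserved
```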

torch/ao/ns/fx/utils.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -492,7 +492,7 @@ def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
 def op_type_supports_shadowing(node: Node) -> bool:
     if node.op == 'call_function':
-        if node.target in (torch.add, torch.mul, operator.add, operator.mul):
-            # shadowing for ops with two inputs is not implemented yet
+        if node.target in (torch.add, torch.mul, operator.add, operator.mul, torch.cat, torch.stack):
+            # shadowing for ops with multiple tensor inputs is not implemented yet
             return False
     return True
```
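A quick check of the updated predicate against a traced graph (hedged illustration; `torch.ao.ns.fx.utils` is the internal module changed in this diff, so the import path may shift across releases):

```python
import torch
import torch.fx
from torch.ao.ns.fx.utils import op_type_supports_shadowing

def f(x):
    y = torch.cat([x, x])
    return torch.sigmoid(y)

gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == 'call_function':
        # After this PR: torch.cat -> False (shadowing skipped),
        # torch.sigmoid -> True (shadowing allowed).
        print(node.target, op_type_supports_shadowing(node))
```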
