
Commit 59000cf

jerryzh168 authored and facebook-github-bot committed
[quant][fx][graphmode] Add support for conv add pattern in backend_config_dict (pytorch#69778)
Summary:
Pull Request resolved: pytorch#69778

This PR extends fusion pattern support from a simple sequence of ops to a simple subgraph like conv - add:

```
x - conv ---\
y ---------add ---- output
```

where the inputs x, y and the output are observed/quantized.

Test Plan:

```
python test/fx2trt/test_quant_trt.py TestQuantizeFxTRTOps.test_conv_add
```

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D33024528

fbshipit-source-id: 5c770c82c8f693fabdac5c69343942a9dfda84ef
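For context on the pattern format used in the new test below: FX graph mode quantization fusion patterns are written output-first, so the conv - add pattern names operator.add (the subgraph output) before the nn.Conv2d that feeds it. A minimal sketch of the pattern key, as it appears in the test diff that follows:

```
import operator

import torch
from torch.ao.quantization.fx.match_utils import MatchAllNode

# Output-first pattern: operator.add is the output of the matched subgraph,
# its first argument is produced by an nn.Conv2d, and MatchAllNode accepts
# any node as the second argument -- together this matches conv(x) + y.
conv_add_pattern = (operator.add, torch.nn.Conv2d, MatchAllNode)
```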
1 parent 4082833 commit 59000cf

4 files changed, +56 -10 lines changed

test/fx2trt/test_quant_trt.py

Lines changed: 51 additions & 0 deletions
```
@@ -17,6 +17,9 @@
 from torch.ao.quantization._quantize_fx_do_not_use import (
     _convert_fx_do_not_use,
 )
+from torch.ao.quantization.fx.match_utils import (
+    MatchAllNode,
+)
 from torch.testing._internal.common_quantization import (
     QuantizationTestCase,
 )
@@ -27,6 +30,8 @@
 from torch.testing._internal.common_quantization import NodeSpec as ns
 import unittest
 import itertools
+import copy
+import operator
 
 def lower_to_trt(model, inputs, shape_ranges):
     """ Lower a quantized model to TensorRT
@@ -375,5 +380,51 @@ def forward(self, x):
         }
         self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
 
+    def test_conv_add(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x, y):
+                return self.conv(x) + y
+
+        from torch.ao.quantization.fx.backend_config_dict.observation_type import ObservationType
+        weighted_op_qint8_dtype_config = {
+            # optional, input activation dtype
+            "input_dtype": torch.qint8,
+            # optional, weight dtype
+            "weight_dtype": torch.qint8,
+            # optional, bias dtype
+            "bias_dtype": torch.float,
+            # optional, output activation dtype
+            "output_dtype": torch.qint8
+        }
+
+        conv_add_config = {
+            "pattern": (operator.add, torch.nn.Conv2d, MatchAllNode),
+            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+            "dtype_configs": [
+                weighted_op_qint8_dtype_config,
+            ],
+            "root_module": torch.nn.Conv2d,
+            "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
+        }
+
+        m = M().eval()
+        modified_backend_config_dict = copy.deepcopy(self.trt_backend_config_dict)
+        modified_backend_config_dict["configs"].insert(0, conv_add_config)
+        m = prepare_fx(m, {"": self.qconfig}, backend_config_dict=modified_backend_config_dict)
+        node_occurrence = {
+            ns.call_module(torch.ao.quantization.HistogramObserver): 3,
+        }
+        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
+        m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=modified_backend_config_dict)
+        node_occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 3,
+            ns.call_method("dequantize"): 3,
+        }
+        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
+
 if __name__ == "__main__":
     run_tests()
```
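The expected node occurrences follow directly from the diagram in the summary: the inputs x and y and the add output are each observed/quantized, hence the 3 HistogramObservers after prepare_fx and the 3 quantize_per_tensor / dequantize pairs after convert. Because conv - add is matched as a single fused pattern whose output gets its own observer (OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT), the intermediate conv output is not observed separately; roughly:

```
x - obs - conv ---\
y - obs ---------add - obs - output
```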

torch/ao/quantization/fx/_convert_do_not_use.py

Lines changed: 0 additions & 7 deletions
```
@@ -15,9 +15,6 @@
 )
 from .backend_config_dict.utils import get_quantized_reference_module_mapping
 
-from .match_utils import (
-    find_matches,
-)
 from .graph_module import (
     QuantizedGraphModule,
 )
@@ -103,10 +100,6 @@ def _convert_do_not_use(
     custom_module_classes = get_custom_module_class_keys(
         convert_custom_config_dict,
         "observed_to_quantized_custom_module_class")
-    matches = find_matches(
-        model.graph, modules, patterns,
-        qconfig_map,
-        custom_module_classes=custom_module_classes)
 
     if model._equalization_qconfig_map is not None:
         # If we want to do equalization then do the following:
```
Lines changed: 1 addition & 1 deletion
```
@@ -1,5 +1,5 @@
 from ..fusion_patterns import DefaultFuseHandler
 
-# TODO: move ModuleReLUFusion here
+# TODO: move DefaultFuseHandler
 def get_fuse_handler_cls():
     return DefaultFuseHandler
```

torch/ao/quantization/fx/prepare.py

Lines changed: 4 additions & 2 deletions
```
@@ -174,8 +174,10 @@ def is_pattern_dtype_config_supported_by_backend(
     pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config_dict)
     dtype_configs: List[Dict[str, torch.dtype]] = pattern_to_dtype_configs.get(pattern, [])
 
-    input_node = matched_nodes[0]
-    output_node = matched_nodes[-1]
+    # TODO: this only checks one input and one output, need to generalize to multiple
+    # inputs/output
+    input_node = matched_nodes[-1]
+    output_node = matched_nodes[0]
     for dtype_config in dtype_configs:
         # check if arg dtype are supported
         supported = True
```
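The swapped indices in prepare.py are consistent with the output-first pattern convention: the matcher appears to record matched_nodes with the pattern's output node first and the node nearest the graph inputs last, so for conv - add, matched_nodes[0] is the add and matched_nodes[-1] the conv. A small sketch of the traced test module's node order (my illustration, not code from this PR):

```
import torch
import torch.fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x, y):
        return self.conv(x) + y

# Trace the module to see the execution-ordered FX graph.
graph = torch.fx.symbolic_trace(M()).graph
print([node.name for node in graph.nodes])
# ['x', 'y', 'conv', 'add', 'output'] -- add executes last; the matcher
# stores the match output-first, so matched_nodes[0] is the output (add)
# and matched_nodes[-1] is the node adjacent to the inputs (conv).
```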
