
Commit 3decd5b

cccclai authored and facebook-github-bot committed
Patch the _is_conv_node function
Summary: Add the conv padding ops in torch/ao only; a separate PR will add the ones in pytorch.

Differential Revision: D75323215
1 parent efac465 commit 3decd5b
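Here, "conv padding ops" means the ATen overloads that take a string padding argument (padding="same" or "valid"), as opposed to the .default overloads that take integer padding. Below is a minimal sketch, not part of the commit, showing the two conv2d overloads side by side (argument order follows the ATen schema: input, weight, bias, stride, padding, dilation, groups):

    import torch

    x = torch.randn(1, 3, 5, 5)
    w = torch.randn(3, 3, 3, 3)

    # Integer padding dispatches to the .default overload.
    y_default = torch.ops.aten.conv2d.default(x, w, None, [1, 1], [1, 1], [1, 1], 1)

    # String padding ("same"/"valid") dispatches to the .padding overload,
    # which _is_conv_node did not previously recognize.
    y_padding = torch.ops.aten.conv2d.padding(x, w, None, [1, 1], "same", [1, 1], 1)

    assert y_default.shape == y_padding.shape  # both (1, 3, 5, 5) here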

File tree

2 files changed: +141 −20


test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 138 additions & 20 deletions
@@ -2478,6 +2478,27 @@ def forward(self, x):
             node_list,
         )

+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        node_occurrence = {
+            # two for input of the first conv, one for output for the first conv
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.padding,
+            torch.ops.aten.relu.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        self._test_quantizer(
+            TestHelperModules.ConvWithBNRelu(dim=2, relu=True, bn=True, padding="same"),
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )
+
     def test_conv_transpose3d_bn_relu(self):
         class BackendAQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -2549,27 +2570,124 @@ def __init__(self):
             def forward(self, x):
                 return torch.nn.functional.relu(self.bn(self.conv_t(x)))

-        example_inputs = (torch.randn(1, 2, 2, 5, 5),)
-        node_occurrence = {
-            # two for input of the first conv, one for output for the first conv
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
-        }
-        node_list = [
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            torch.ops.aten.conv_transpose3d.input,
-            torch.ops.aten.relu.default,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    def test_conv_padding_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
+
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
+
+                    # Check for any of the conv operations
+                    conv_ops = [
+                        torch.ops.aten.conv1d.padding,
+                        torch.ops.aten.conv2d.padding,
+                        torch.ops.aten.conv3d.padding,
+                    ]
+                    if n.op != "call_function" or n.target not in conv_ops:
+                        continue
+
+                    conv_node = n
+                    input_act = conv_node.args[0]
+                    weight = conv_node.args[1]
+                    bias = conv_node.args[2]
+                    conv_node.meta["quantization_annotation"] = (
+                        QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                                bias: bias_qspec,
+                            },
+                            _annotated=True,
+                        )
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        # Test cases for Conv1d, Conv2d, Conv3d
+        test_cases = [
+            {
+                "conv_type": torch.nn.Conv1d,
+                "bn_type": torch.nn.BatchNorm1d,
+                "example_input": (torch.randn(1, 3, 5),),
+                "conv_op": torch.ops.aten.conv1d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv2d,
+                "bn_type": torch.nn.BatchNorm2d,
+                "example_input": (torch.randn(1, 3, 5, 5),),
+                "conv_op": torch.ops.aten.conv2d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv3d,
+                "bn_type": torch.nn.BatchNorm3d,
+                "example_input": (torch.randn(1, 3, 5, 5, 5),),
+                "conv_op": torch.ops.aten.conv3d.padding,
+            },
         ]
-        model = M().eval()
-        self._test_quantizer(
-            model,
-            example_inputs,
-            BackendAQuantizer(),
-            node_occurrence,
-            node_list,
-        )
+
+        for test_case in test_cases:
+            with self.subTest(conv_type=test_case["conv_type"].__name__):
+                class M(torch.nn.Module):
+                    def __init__(self):
+                        super().__init__()
+                        self.conv = test_case["conv_type"](3, 3, 3, padding="same")
+                        self.bn = test_case["bn_type"](3)
+
+                    def forward(self, x):
+                        return torch.nn.functional.relu(self.bn(self.conv(x)))
+
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    test_case["conv_op"],
+                    torch.ops.aten.relu.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                ]
+
+                model = M().eval()
+                self._test_quantizer(
+                    model,
+                    test_case["example_input"],
+                    BackendAQuantizer(),
+                    node_occurrence,
+                    node_list,
+                )

     def test_multi_users_without_output_observer(self):
         """

torchao/quantization/pt2e/utils.py

Lines changed: 3 additions & 0 deletions
@@ -625,8 +625,11 @@ def _is_conv_node(n: Node):
     """
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv1d.padding,
         torch.ops.aten.conv2d.default,
+        torch.ops.aten.conv2d.padding,
         torch.ops.aten.conv3d.default,
+        torch.ops.aten.conv3d.padding,
     ]

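The new entries matter because capturing a module that uses string padding produces aten.convNd.padding call_function nodes rather than .default ones, so _is_conv_node would otherwise skip them. A minimal sketch of how such a node shows up in a captured graph; it assumes torch.export.export_for_training is available (the capture API name has varied across PyTorch releases):

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3, padding="same")

        def forward(self, x):
            return self.conv(x)

    ep = torch.export.export_for_training(M().eval(), (torch.randn(1, 3, 5, 5),))
    targets = [n.target for n in ep.graph_module.graph.nodes if n.op == "call_function"]

    # Expect torch.ops.aten.conv2d.padding here, which is why _is_conv_node
    # must list the .padding overloads alongside the .default ones.
    print(targets)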