@@ -2478,6 +2478,27 @@ def forward(self, x):
             node_list,
         )

+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        node_occurrence = {
+            # two for input of the first conv, one for output for the first conv
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.padding,
+            torch.ops.aten.relu.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        self._test_quantizer(
+            TestHelperModules.ConvWithBNRelu(dim=2, relu=True, bn=True, padding="same"),
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )
+
     def test_conv_transpose3d_bn_relu(self):
         class BackendAQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -2549,27 +2570,124 @@ def __init__(self):
             def forward(self, x):
                 return torch.nn.functional.relu(self.bn(self.conv_t(x)))

-        example_inputs = (torch.randn(1, 2, 2, 5, 5),)
-        node_occurrence = {
-            # two for input of the first conv, one for output for the first conv
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
-        }
-        node_list = [
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            torch.ops.aten.conv_transpose3d.input,
-            torch.ops.aten.relu.default,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    def test_conv_padding_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
+
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
+
+                    # Check for any of the conv operations
+                    conv_ops = [
+                        torch.ops.aten.conv1d.padding,
+                        torch.ops.aten.conv2d.padding,
+                        torch.ops.aten.conv3d.padding,
+                    ]
+                    if n.op != "call_function" or n.target not in conv_ops:
+                        continue
+
+                    conv_node = n
+                    input_act = conv_node.args[0]
+                    weight = conv_node.args[1]
+                    bias = conv_node.args[2]
+                    conv_node.meta["quantization_annotation"] = (
+                        QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                                bias: bias_qspec,
+                            },
+                            _annotated=True,
+                        )
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        # Test cases for Conv1d, Conv2d, Conv3d
+        test_cases = [
+            {
+                "conv_type": torch.nn.Conv1d,
+                "bn_type": torch.nn.BatchNorm1d,
+                "example_input": (torch.randn(1, 3, 5),),
+                "conv_op": torch.ops.aten.conv1d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv2d,
+                "bn_type": torch.nn.BatchNorm2d,
+                "example_input": (torch.randn(1, 3, 5, 5),),
+                "conv_op": torch.ops.aten.conv2d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv3d,
+                "bn_type": torch.nn.BatchNorm3d,
+                "example_input": (torch.randn(1, 3, 5, 5, 5),),
+                "conv_op": torch.ops.aten.conv3d.padding,
+            },
         ]
-        model = M().eval()
-        self._test_quantizer(
-            model,
-            example_inputs,
-            BackendAQuantizer(),
-            node_occurrence,
-            node_list,
-        )
+
+        for test_case in test_cases:
+            with self.subTest(conv_type=test_case["conv_type"].__name__):
+                class M(torch.nn.Module):
+                    def __init__(self):
+                        super().__init__()
+                        self.conv = test_case["conv_type"](3, 3, 3, padding="same")
+                        self.bn = test_case["bn_type"](3)
+
+                    def forward(self, x):
+                        return torch.nn.functional.relu(self.bn(self.conv(x)))
+
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    test_case["conv_op"],
+                    torch.ops.aten.relu.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                ]
+
+                model = M().eval()
+                self._test_quantizer(
+                    model,
+                    test_case["example_input"],
+                    BackendAQuantizer(),
+                    node_occurrence,
+                    node_list,
+                )

     def test_multi_users_without_output_observer(self):
         """