@@ -56,13 +56,25 @@ class ConvPackedParam(QNNParam):
56
56
"""
57
57
58
58
def __init__(
    self,
    weight_np,
    bias,
    scale,
    zero_point,
    param_name,
    stride,
    padding,
    dilation,
    groups,
    output_padding,
):
    """Store quantized-conv packed parameters.

    Extends the base quantized-parameter record (weight, bias, scale,
    zero point, name) with the convolution geometry attributes read back
    from the packed-params object.
    """
    super().__init__(weight_np, bias, scale, zero_point, param_name)
    # NOTE(review): output_padding appears to matter only for
    # conv_transpose2d conversion — confirm against the converter.
    self.output_padding = output_padding
    self.groups = groups
    self.dilation = dilation
    self.padding = padding
    self.stride = stride
66
78
67
79
68
80
def _get_quant_params (qweight ):
@@ -92,8 +104,18 @@ def make_conv_packed_param(param_name, qweight, bias, packed_params):
92
104
padding = packed_params .padding ()
93
105
dilation = packed_params .dilation ()
94
106
groups = packed_params .groups ()
107
+ output_padding = packed_params .output_padding ()
95
108
return ConvPackedParam (
96
- weight_np , bias , scale , zero_point , param_name , stride , padding , dilation , groups
109
+ weight_np ,
110
+ bias ,
111
+ scale ,
112
+ zero_point ,
113
+ param_name ,
114
+ stride ,
115
+ padding ,
116
+ dilation ,
117
+ groups ,
118
+ output_padding ,
97
119
)
98
120
99
121
@@ -154,7 +176,13 @@ def add_quant_params_to_outputs(outputs, packed_param_map, quant_params):
154
176
params = [qweight , qparam .scale , qparam .zero_point , qparam .bias_var ]
155
177
156
178
if isinstance (quant_params [packed_param_name ], ConvPackedParam ):
157
- params += [qparam .stride , qparam .padding , qparam .dilation , qparam .groups ]
179
+ params += [
180
+ qparam .stride ,
181
+ qparam .padding ,
182
+ qparam .dilation ,
183
+ qparam .groups ,
184
+ qparam .output_padding ,
185
+ ]
158
186
159
187
outputs [node_name ] = params
160
188
@@ -192,6 +220,7 @@ def _get_quant_param_for_input(input_value):
192
220
"quantized::mul_scalar" : (2 , 3 ),
193
221
"quantized::add_scalar" : (2 , 3 ),
194
222
"quantized::hardswish" : (1 , 2 ),
223
+ "quantized::conv_transpose2d" : qconv_indices ,
195
224
}
196
225
197
226
def dfs (current_node ):
@@ -362,6 +391,7 @@ def add_input_quant_params_to_op_inputs(graph):
362
391
"quantized::relu6" : 1 ,
363
392
"quantized::hardswish" : 1 ,
364
393
"aten::hardsigmoid" : 1 ,
394
+ "quantized::conv_transpose2d" : 1 ,
365
395
}
366
396
367
397
need_input_quant_param = set (num_quantized_inputs .keys ())
@@ -924,6 +954,65 @@ def _impl(inputs, _):
924
954
return _impl
925
955
926
956
957
def _quantized_conv_transpose2d(with_relu=False):
    """Build a converter for quantized::conv_transpose2d.

    The packed-parameter layout mirrors
    aten/src/ATen/native/quantized/cpu/qconv.cpp (supported in Torch 1.7
    or newer).
    """

    def _impl(inputs, _):
        # The two trailing input quant params are appended by
        # add_input_quant_params_to_op_inputs above; in torch they are
        # retrieved from the QTensor data structure at runtime.
        assert len(inputs) == 6, "Input quant params not found in op inputs"

        (
            weight,
            weight_scale,
            weight_zero_point,
            bias,
            strides,
            padding,
            dilation,
            groups,
            output_padding,
        ) = inputs[1][:9]

        output_scale = _expr.const(inputs[2])
        output_zero_point = _expr.const(inputs[3])
        input_scale = _expr.const(inputs[4])
        input_zero_point = _expr.const(inputs[5])

        # Relay expects OIHW for the transposed kernel, so swap the
        # first two (I and O) dims of the torch weight shape.
        kshape = list(infer_shape(weight))
        kshape[0], kshape[1] = kshape[1], kshape[0]

        conv_out = relay.qnn.op.conv2d_transpose(
            inputs[0],
            weight,
            input_zero_point,
            weight_zero_point,
            input_scale,
            weight_scale,
            kernel_size=(kshape[2], kshape[3]),
            dilation=dilation,
            strides=strides,
            padding=padding,
            groups=groups,
            channels=kshape[0],
            output_padding=output_padding,
            out_dtype="int32",
            kernel_layout="OIHW",
        )

        # Bias add + requantize to the output qparams (with optional relu).
        return _do_bias_and_requantize(
            conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu
        )

    return _impl
1014
+
1015
+
927
1016
convert_map = {
928
1017
"aten::quantize_per_tensor" : _quantize_per_tensor (),
929
1018
"quantized::conv2d_relu" : _quantized_conv2d (with_relu = True ),
@@ -941,4 +1030,5 @@ def _impl(inputs, _):
941
1030
"quantized::relu6" : _relu6 (),
942
1031
"quantized::linear_dynamic" : _linear_dynamic (),
943
1032
"quantized::hardswish" : _hswish (),
1033
+ "quantized::conv_transpose2d" : _quantized_conv_transpose2d (),
944
1034
}
0 commit comments