Commit ae0da6b
[Frontend][PyTorch] support for quantized conv_transpose2d op
PyTorch uses the same underlying function to pack and unpack the parameters for the conv2d and conv_transpose2d ops. This change adds support for the quantized conv_transpose2d op by reusing ConvPackedParam and adding an output_padding parameter to it; the parameter remains unused in the conv2d case. A test is also added, with a special case for torch v1.7.1 and below.
1 parent 4251103 commit ae0da6b
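As context for the diffs below, here is a minimal end-to-end sketch of the flow this commit enables: quantize an nn.ConvTranspose2d module in PyTorch, trace it, and import it through TVM's PyTorch frontend. This is illustrative only, not part of the commit; the input name "input" and the shapes are arbitrary, and the qnnpack engine is selected because fbgemm does not support quantized conv_transpose2d before torch v1.8.0.

import torch
import torch.nn as nn
from torch.quantization import QuantWrapper, get_default_qconfig, prepare, convert
from tvm import relay

# qnnpack is the only engine supporting quantized conv_transpose2d
# before torch v1.8.0 (see the test change below).
torch.backends.quantized.engine = "qnnpack"

model = QuantWrapper(nn.ConvTranspose2d(3, 32, 3, bias=True)).eval()
model.qconfig = get_default_qconfig("qnnpack")
prepare(model, inplace=True)
model(torch.rand(1, 3, 224, 224))  # calibrate the observers
convert(model, inplace=True)       # graph now contains quantized::conv_transpose2d

inp = torch.rand(1, 3, 224, 224)
script_module = torch.jit.trace(model, inp).eval()
# input_infos maps input names (arbitrary here) to shapes
mod, params = relay.frontend.from_pytorch(script_module, [("input", (1, 3, 224, 224))])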

File tree

2 files changed: +118 −4 lines changed

python/tvm/relay/frontend/qnn_torch.py

Lines changed: 93 additions & 3 deletions
@@ -56,13 +56,25 @@ class ConvPackedParam(QNNParam):
     """
 
     def __init__(
-        self, weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
+        self,
+        weight_np,
+        bias,
+        scale,
+        zero_point,
+        param_name,
+        stride,
+        padding,
+        dilation,
+        groups,
+        output_padding,
     ):
         super().__init__(weight_np, bias, scale, zero_point, param_name)
         self.stride = stride
         self.padding = padding
         self.dilation = dilation
         self.groups = groups
+        # Used only for conv_transpose2d
+        self.output_padding = output_padding
 
 
 def _get_quant_params(qweight):
@@ -92,8 +104,18 @@ def make_conv_packed_param(param_name, qweight, bias, packed_params):
     padding = packed_params.padding()
     dilation = packed_params.dilation()
     groups = packed_params.groups()
+    output_padding = packed_params.output_padding()
     return ConvPackedParam(
-        weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
+        weight_np,
+        bias,
+        scale,
+        zero_point,
+        param_name,
+        stride,
+        padding,
+        dilation,
+        groups,
+        output_padding,
     )
 
 
@@ -154,7 +176,13 @@ def add_quant_params_to_outputs(outputs, packed_param_map, quant_params):
         params = [qweight, qparam.scale, qparam.zero_point, qparam.bias_var]
 
         if isinstance(quant_params[packed_param_name], ConvPackedParam):
-            params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups]
+            params += [
+                qparam.stride,
+                qparam.padding,
+                qparam.dilation,
+                qparam.groups,
+                qparam.output_padding,
+            ]
 
         outputs[node_name] = params
 
@@ -192,6 +220,7 @@ def _get_quant_param_for_input(input_value):
         "quantized::mul_scalar": (2, 3),
         "quantized::add_scalar": (2, 3),
         "quantized::hardswish": (1, 2),
+        "quantized::conv_transpose2d": qconv_indices,
     }
 
     def dfs(current_node):
@@ -362,6 +391,7 @@ def add_input_quant_params_to_op_inputs(graph):
         "quantized::relu6": 1,
         "quantized::hardswish": 1,
         "aten::hardsigmoid": 1,
+        "quantized::conv_transpose2d": 1,
     }
 
     need_input_quant_param = set(num_quantized_inputs.keys())
@@ -924,6 +954,65 @@ def _impl(inputs, _):
     return _impl
 
 
+def _quantized_conv_transpose2d(with_relu=False):
+    def _impl(inputs, _):
+        # Refer to aten/src/ATen/native/quantized/cpu/qconv.cpp
+        # Supported in Torch 1.7 or newer
+        conv_params = inputs[1]
+        weight = conv_params[0]
+        weight_scale = conv_params[1]
+        weight_zero_point = conv_params[2]
+        bias = conv_params[3]
+
+        strides = conv_params[4]
+        padding = conv_params[5]
+        dilation = conv_params[6]
+        groups = conv_params[7]
+        output_padding = conv_params[8]
+
+        output_scale = _expr.const(inputs[2])
+        output_zero_point = _expr.const(inputs[3])
+
+        assert len(inputs) == 6, "Input quant params not found in op inputs"
+
+        # These are manually added by add_input_quant_params_to_op_inputs above
+        # In torch, they are retrieved from QTensor data structure at runtime
+        input_scale = _expr.const(inputs[4])
+        input_zero_point = _expr.const(inputs[5])
+
+        weight_shape = list(infer_shape(weight))
+
+        # Swap I and O dims to match shape relay expects for OIHW
+        weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
+
+        kernel_size = (weight_shape[2], weight_shape[3])
+        out_channels = weight_shape[0]
+
+        conv_out = relay.qnn.op.conv2d_transpose(
+            inputs[0],
+            weight,
+            input_zero_point,
+            weight_zero_point,
+            input_scale,
+            weight_scale,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            strides=strides,
+            padding=padding,
+            groups=groups,
+            channels=out_channels,
+            output_padding=output_padding,
+            out_dtype="int32",
+            kernel_layout="OIHW",
+        )
+
+        return _do_bias_and_requantize(
+            conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu
+        )
+
+    return _impl
+
+
 convert_map = {
     "aten::quantize_per_tensor": _quantize_per_tensor(),
     "quantized::conv2d_relu": _quantized_conv2d(with_relu=True),
@@ -941,4 +1030,5 @@ def _impl(inputs, _):
     "quantized::relu6": _relu6(),
     "quantized::linear_dynamic": _linear_dynamic(),
     "quantized::hardswish": _hswish(),
+    "quantized::conv_transpose2d": _quantized_conv_transpose2d(),
 }
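A note on the I/O swap in _quantized_conv_transpose2d above: PyTorch stores ConvTranspose2d weights as (in_channels, out_channels // groups, kH, kW), so the converter swaps dims 0 and 1 before reading out_channels and kernel_size. A small illustrative check, not part of the commit:

import torch.nn as nn

w = nn.ConvTranspose2d(3, 32, 3).weight            # shape (3, 32, 3, 3): (I, O, kH, kW)
weight_shape = list(w.shape)
weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
out_channels = weight_shape[0]                     # 32
kernel_size = (weight_shape[2], weight_shape[3])   # (3, 3)

The transposed convolution itself accumulates in int32 (out_dtype="int32"); roughly, _do_bias_and_requantize then adds the bias and requantizes from the input_scale * weight_scale scale to the op's output scale and zero point, applying relu when with_relu is set.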

tests/python/frontend/pytorch/qnn_test.py

Lines changed: 25 additions & 1 deletion
@@ -92,6 +92,20 @@ def fuse_model(self):
         fuse_modules(self.conv, indices, inplace=True)
 
 
+class ConvTranspose(nn.Module):
+    def __init__(self):
+        super().__init__()
+        layers = [nn.ConvTranspose2d(3, 32, 3, bias=True)]
+        self.conv = nn.Sequential(*layers)
+        self.quant_wrap = QuantWrapper(self.conv)
+
+    def forward(self, x):
+        return self.quant_wrap(x)
+
+    def fuse_model(self):
+        pass
+
+
 class Linear(nn.Module):
     def __init__(self, with_relu=False):
         super().__init__()
@@ -270,6 +284,7 @@ def test_quantized_modules():
         ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
         ("linear" + postfix, (16, 16), Linear(), per_channel),
         ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
+        ("conv_transpose", imagenet_ishape, ConvTranspose(), False),
         ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
         ("hswish", imagenet_ishape, Hswish(add_stub=True), False),
         ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
@@ -281,7 +296,15 @@ def test_quantized_modules():
         raw_module.eval()
         inp = torch.rand(ishape)
 
-        quantize_model(raw_module, inp, per_channel=per_channel)
+        # quantized conv_transpose2d is supported only with qnnpack engine before torch v1.8.0.
+        if module_name == "conv_transpose" and not is_version_greater_than("1.7.1"):
+            prev_engine = torch.backends.quantized.engine
+            torch.backends.quantized.engine = "qnnpack"
+            quantize_model(raw_module, inp, per_channel=per_channel)
+            torch.backends.quantized.engine = prev_engine
+        else:
+            quantize_model(raw_module, inp, per_channel=per_channel)
 
         script_module = torch.jit.trace(raw_module, inp).eval()
 
         with torch.no_grad():
@@ -308,6 +331,7 @@ def test_quantized_modules():
         conv_bn_relu 0.3700896 0.010921672 0.7489366477964451
         linear 0.15987062 0.009231662 0.794921875
         linear_relu 0.14180502 0.0053220326 0.8828125
+        conv_transpose 0.0033792555 4.4658788e-07 0.9998678439971806
         conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019
         conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732
         linear, per_channel 0.0 0.0 1.0
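The engine save/restore in the test above could equivalently be wrapped in a small context manager, so the previous engine is restored even if quantization raises. A sketch; quantized_engine is a hypothetical helper, not part of this commit:

from contextlib import contextmanager
import torch

@contextmanager
def quantized_engine(engine):
    # Temporarily switch the quantized backend, restoring it on exit.
    prev = torch.backends.quantized.engine
    torch.backends.quantized.engine = engine
    try:
        yield
    finally:
        torch.backends.quantized.engine = prev

# Usage in the test would then be:
# with quantized_engine("qnnpack"):
#     quantize_model(raw_module, inp, per_channel=per_channel)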
