Commit fd5915a

[Relay] Add conv2d_backward_weight op (without topi) (#9954)

Squashed commit messages:

* python plumbing
* add cpp def
* legalize worked
* clean up
* layout conversion doesn't work
* extract wgrad body
* fix convert layout
* black
* fix kernel size
* revert irrelevant change
* add doc, clarify the meanings of parameters
* update layout convert
* test passed
* fixed layout conversion
* update convert layout
* remove print
* remove layout convert for now
* minor fix
* removed unused import
* add wgrad python reference
* add test stub
* add doc
* test other stride and pad
* tweak
* more pylint filter
* fix typo in doc
* swap arg order (data, grad) to be consistent with conv2d_transpose(dgrad)

1 parent a9b1d5b · commit fd5915a

File tree: 8 files changed, +344 −49 lines changed

python/tvm/relay/op/_tensor_grad.py

Lines changed: 13 additions & 40 deletions
```diff
@@ -52,7 +52,6 @@
     reshape_like,
     strided_slice,
     take,
-    tile,
     transpose,
     where,
     repeat,
@@ -399,15 +398,14 @@ def conv2d_grad(orig, grad):
     data_shape = get_const_tuple(data.checked_type.shape)
     weight_shape = get_const_tuple(weight.checked_type.shape)
     _, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape)
-    batch, in_channel, in_h, in_w = data_shape
-    out_channel, _, filter_h, filter_w = weight_shape
+    _, _, in_h, in_w = data_shape
+    _, _, filter_h, filter_w = weight_shape
 
     # infer output_padding
     fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
         get_const_tuple(attrs.padding), (filter_h, filter_w)
     )
     stride_h, stride_w = get_const_tuple(attrs.strides)
-    dilation_h, dilation_w = get_const_tuple(attrs.dilation)
     out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
     out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w
     output_padding = (in_h - out_h, in_w - out_w)
@@ -425,46 +423,21 @@ def conv2d_grad(orig, grad):
         groups=attrs.groups,
         output_padding=output_padding,
     )
-    grad = tile(grad, [1, in_channel // attrs.groups, 1, 1])
-    grad = reshape(grad, [-1, 1, 0, 0])  # batch * oc * ic // groups, 1, oh, ow
-    data = reshape(data, [1, -1, 0, 0])  # 1, batch * ic, ih, iw
 
-    backward_weight = _nn.conv2d(
-        data,
+    backward_weight = _nn.conv2d_backward_weight(
         grad,
-        strides=attrs.dilation,
+        data,
+        strides=attrs.strides,
         padding=attrs.padding,
-        dilation=attrs.strides,
-        groups=in_channel * batch,
-    )
-    # infer shape of backward_weight
-    padded_weight_grad_h = (
-        in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom
-    ) // dilation_h + 1
-    padded_weight_grad_w = (
-        in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right
-    ) // dilation_w + 1
-    backward_weight = reshape(
-        backward_weight,
-        [
-            batch,
-            in_channel // attrs.groups,
-            out_channel,
-            padded_weight_grad_h,
-            padded_weight_grad_w,
-        ],
+        dilation=attrs.dilation,
+        groups=attrs.groups,
+        channels=attrs.channels,
+        kernel_size=(filter_h, filter_w),
+        grad_layout=attrs.out_layout if attrs.out_layout else attrs.data_layout,
+        data_layout=attrs.data_layout,
+        kernel_layout=attrs.kernel_layout,
+        out_dtype=attrs.out_dtype,
     )
-    backward_weight = _sum(backward_weight, axis=0)
-    backward_weight = transpose(backward_weight, [1, 0, 2, 3])
-
-    assert padded_weight_grad_h >= filter_h
-    assert padded_weight_grad_w >= filter_w
-    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
-        backward_weight = strided_slice(
-            backward_weight,
-            begin=[0, 0, 0, 0],
-            end=[out_channel, in_channel // attrs.groups, filter_h, filter_w],
-        )
 
     return [backward_data, backward_weight]
 
```
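
For concreteness, here is the output_padding inference above worked through on hypothetical numbers (chosen for illustration; they are not from the commit):

```python
# Forward conv: data 1x1x8x8, 3x3 kernel, stride 2, padding 1
# -> output (8 + 2*1 - 3) // 2 + 1 = 4, so grad is 1x1x4x4.
in_h = in_w = 8
filter_h = filter_w = 3
stride_h = stride_w = 2
fpad_top = fpad_bottom = fpad_left = fpad_right = 1
grad_h = grad_w = 4

# conv2d_transpose of grad recovers only a 7x7 output, so one extra
# row/column of output_padding is needed to match the 8x8 input.
out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h  # 7
out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w  # 7
output_padding = (in_h - out_h, in_w - out_w)  # (1, 1)
```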

python/tvm/relay/op/nn/_nn.py

Lines changed: 78 additions & 0 deletions
```diff
@@ -23,6 +23,7 @@
 from tvm.runtime import convert
 from tvm.te.hybrid import script
 from tvm.topi.utils import get_const_tuple
+from tvm.topi.nn.utils import get_pad_tuple
 
 from ....ir import container
 from ....tir import expr
@@ -1061,6 +1062,83 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
 reg.register_injective_schedule("nn.batch_to_space_nd")
 
 
+@reg.register_legalize("nn.conv2d_backward_weight")
+def legalize_conv2d_backward_weight(attrs, inputs, types):
+    """Legalize conv2d_backward_weight op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current op
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    grad, data = inputs
+    data_shape = get_const_tuple(data.checked_type.shape)
+    weight_shape = get_const_tuple(types[2].shape)
+    _, out_channel, grad_h, grad_w = get_const_tuple(grad.checked_type.shape)
+    batch, in_channel, in_h, in_w = data_shape
+    _, _, filter_h, filter_w = weight_shape
+    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
+        get_const_tuple(attrs.padding), (filter_h, filter_w)
+    )
+    stride_h, stride_w = get_const_tuple(attrs.strides)
+    dilation_h, dilation_w = get_const_tuple(attrs.dilation)
+
+    grad = relay.tile(grad, [1, in_channel // attrs.groups, 1, 1])
+    grad = relay.reshape(grad, [-1, 1, 0, 0])  # batch * oc * ic // groups, 1, oh, ow
+    data = relay.reshape(data, [1, -1, 0, 0])  # 1, batch * ic, ih, iw
+
+    backward_weight = relay.nn.conv2d(
+        data,
+        grad,
+        strides=attrs.dilation,
+        padding=attrs.padding,
+        dilation=attrs.strides,
+        groups=in_channel * batch,
+    )
+
+    # infer shape of backward_weight
+    padded_weight_grad_h = (
+        in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom
+    ) // dilation_h + 1
+    padded_weight_grad_w = (
+        in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right
+    ) // dilation_w + 1
+
+    backward_weight = relay.reshape(
+        backward_weight,
+        [
+            batch,
+            in_channel // attrs.groups,
+            out_channel,
+            padded_weight_grad_h,
+            padded_weight_grad_w,
+        ],
+    )
+    backward_weight = relay.sum(backward_weight, axis=0)
+    backward_weight = relay.transpose(backward_weight, [1, 0, 2, 3])
+
+    assert padded_weight_grad_h >= filter_h
+    assert padded_weight_grad_w >= filter_w
+
+    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
+        backward_weight = relay.strided_slice(
+            backward_weight,
+            begin=[0, 0, 0, 0],
+            end=[out_channel, in_channel // attrs.groups, filter_h, filter_w],
+        )
+
+    return backward_weight
+
+
 #####################
 # Shape functions #
 #####################
```
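
The legalization above re-expresses wgrad as a grouped conv2d over `data`, with the tiled and reshaped `grad` acting as the filter (note that strides and dilation trade places), then sums over the batch axis and slices off any padding excess. A minimal sketch of triggering it, with illustrative shapes that are not from the commit:

```python
import tvm
from tvm import relay

# dy is the output gradient of a stride-1, pad-1, 3x3 conv taking a
# (2, 4, 16, 16) input to 8 output channels.
dy = relay.var("dy", shape=(2, 8, 16, 16), dtype="float32")
x = relay.var("x", shape=(2, 4, 16, 16), dtype="float32")
wgrad = relay.nn.conv2d_backward_weight(
    dy, x, strides=(1, 1), padding=(1, 1), kernel_size=(3, 3), channels=8
)
mod = tvm.IRModule.from_expr(relay.Function([dy, x], wgrad))
mod = relay.transform.InferType()(mod)
mod = relay.transform.Legalize()(mod)
print(mod)  # shows the tile/reshape/grouped-conv2d expansion
```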

python/tvm/relay/op/nn/nn.py

Lines changed: 51 additions & 0 deletions
```diff
@@ -3770,3 +3770,54 @@ def batch_to_space_nd(data, block_shape, crops):
     """
 
     return _make.batch_to_space_nd(data, block_shape, crops)
+
+
+def conv2d_backward_weight(
+    grad,
+    data,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilation=(1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    grad_layout="NCHW",
+    data_layout="NCHW",
+    kernel_layout="OIHW",
+    out_dtype="",
+):
+    r"""The gradient of conv2d with respect to weight.
+
+    This operator takes the output gradient `grad` and convolves it with `data` as
+    the convolution kernel, to produce the gradient with respect to weight.
+
+    Note that the parameter `kernel_size` is the spatial size of the corresponding
+    forward convolution kernel, not that of `data`. `grad_layout` and
+    `kernel_layout` are the layouts of `grad` and the weight gradient respectively.
+
+    Other parameters are the same as the conv2d op. See its documentation for more
+    details.
+
+    """
+    if isinstance(kernel_size, int):
+        kernel_size = (kernel_size, kernel_size)
+    if isinstance(strides, int):
+        strides = (strides, strides)
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation)
+    padding = get_pad_tuple2d(padding)
+
+    return _make.conv2d_backward_weight(
+        grad,
+        data,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        grad_layout,
+        data_layout,
+        kernel_layout,
+        out_dtype,
+    )
```
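
Because the op ships without a TOPI compute or schedule, execution goes through the legalization registered above (relay.build applies Legalize as part of its standard pass pipeline). A hedged end-to-end sketch, with illustrative shapes, comparing the op against the numpy reference added later in this commit:

```python
import numpy as np
import tvm
from tvm import relay
from tvm.topi.testing import conv2d_backward_weight_nchw_python

dy_np = np.random.randn(2, 8, 16, 16).astype("float32")
x_np = np.random.randn(2, 4, 16, 16).astype("float32")

dy = relay.var("dy", shape=dy_np.shape, dtype="float32")
x = relay.var("x", shape=x_np.shape, dtype="float32")
wgrad = relay.nn.conv2d_backward_weight(
    dy, x, strides=(1, 1), padding=(1, 1), kernel_size=(3, 3), channels=8
)
mod = tvm.IRModule.from_expr(relay.Function([dy, x], wgrad))

# The graph executor build legalizes nn.conv2d_backward_weight into
# plain conv2d ops before compiling.
wgrad_np = relay.create_executor("graph", mod=mod).evaluate()(dy_np, x_np).numpy()
ref = conv2d_backward_weight_nchw_python(dy_np, x_np, (3, 3), (1, 1), (1, 1))
np.testing.assert_allclose(wgrad_np, ref, rtol=1e-4, atol=1e-4)
```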

python/tvm/relay/testing/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -127,6 +127,7 @@ def check_grad(
 
     fwd_func = run_infer_type(func)
     bwd_func = run_infer_type(gradient(fwd_func, mode=mode))
+    bwd_func = run_opt_pass(bwd_func, relay.transform.Legalize())
 
     if scale is None:
         scale = 10 * eps
```
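
With the backward function legalized, check_grad can compile gradients that contain the new op. A minimal sketch of exercising the path through a conv2d gradient check (illustrative shapes; float64 is conventional for numerical gradient checking):

```python
from tvm import relay
from tvm.relay.testing import check_grad

data = relay.var("data", shape=(1, 4, 8, 8), dtype="float64")
weight = relay.var("weight", shape=(4, 4, 3, 3), dtype="float64")
out = relay.nn.conv2d(data, weight, padding=(1, 1))

# conv2d_grad now emits nn.conv2d_backward_weight for the weight gradient;
# check_grad legalizes it away and compares against numerical gradients.
check_grad(relay.Function([data, weight], out))
```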

python/tvm/topi/testing/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -75,3 +75,4 @@
 from .nll_loss import nll_loss
 from .dense import dense
 from .searchsorted import searchsorted_ref
+from .conv2d_backcward_weight_python import conv2d_backward_weight_nchw_python
```
python/tvm/topi/testing/conv2d_backcward_weight_python.py

Lines changed: 76 additions & 0 deletions (new file)

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-nested-blocks
"""Gradient of conv2d with respect to weight in python"""
import numpy as np


# Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding):
    """Gradient of the conv2d op with respect to weight, in NCHW layout.

    Parameters
    ----------
    dy_np : numpy.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]

    x_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]

    kernel_size : tuple of two ints
        Height and width of the weight

    stride : tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : tuple of two ints
        Spatial padding, or [pad_h, pad_w]

    Returns
    -------
    dw : np.ndarray
        4-D with shape [num_filter, in_channel, filter_height, filter_width]
    """
    N, C, H, W = x_np.shape
    _, K, P, Q = dy_np.shape
    R, S = kernel_size
    pad_h, pad_w = padding
    stride_h, stride_w = stride
    dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)

    for k in range(K):
        for r in range(R):
            for s in range(S):
                for c in range(C):
                    acc = 0
                    for n in range(N):
                        for p in range(P):
                            for q in range(Q):
                                coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s)

                                if 0 <= coord[2] < H and 0 <= coord[3] < W:
                                    acc += dy_np[n, k, p, q] * x_np[coord]

                    dw[k, c, r, s] = acc

    return dw
```
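
As a quick sanity check on the reference (illustrative, not part of the commit): for a 1x1 kernel with unit stride and no padding, the weight gradient collapses to a tensordot over the batch and spatial axes, which numpy can compute directly:

```python
import numpy as np
from tvm.topi.testing import conv2d_backward_weight_nchw_python

dy = np.random.randn(2, 8, 5, 5).astype("float32")
x = np.random.randn(2, 3, 5, 5).astype("float32")

dw = conv2d_backward_weight_nchw_python(dy, x, (1, 1), (1, 1), (0, 0))
ref = np.einsum("nkpq,ncpq->kc", dy, x)[:, :, None, None]
np.testing.assert_allclose(dw, ref, rtol=1e-4)
```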
