fixed conv2d_backward_weight type rel for depthwise conv2d
commit 16fe531
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu Feb 3 12:59:22 2022 +0900

    wip

commit 2167c25
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu Feb 3 04:22:19 2022 +0900

    fix conv2d type rel for depthwise and grouped conv2d
masahi committed Feb 4, 2022
1 parent 14b12e5 commit ae09b0f
Showing 4 changed files with 69 additions and 13 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -1108,6 +1108,7 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
         dilation=attrs.strides,
         groups=in_channel * batch,
         out_dtype=attrs.out_dtype,
+        channels=attrs.channels,
     )
 
     # infer shape of backward_weight
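The one-line change above matters because `legalize_conv2d_backward_weight` rewrites the weight gradient as a grouped conv2d over transposed tensors; without forwarding `channels`, type inference has no way to size the weight for the depthwise case. A minimal sketch of exercising that path (shapes are illustrative; assumes a TVM build that includes this patch):

```python
# Sketch: infer the weight-gradient type for a depthwise conv2d.
import tvm
from tvm import relay

dy = relay.var("dy", shape=(1, 16, 32, 32), dtype="float32")
x = relay.var("x", shape=(1, 16, 32, 32), dtype="float32")
dw = relay.nn.conv2d_backward_weight(
    dy, x, strides=(1, 1), padding=(1, 1), kernel_size=(3, 3),
    groups=16, channels=16,  # depthwise: groups == in channels == out channels
)
mod = tvm.IRModule.from_expr(relay.Function([dy, x], dw))
mod = relay.transform.InferType()(mod)
print(mod)  # the weight gradient should be typed Tensor[(16, 1, 3, 3), float32]
```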
43 changes: 37 additions & 6 deletions python/tvm/topi/testing/conv2d_backcward_weight_python.py
@@ -20,7 +20,9 @@
 
 
 # Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
-def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding):
+def conv2d_backward_weight_nchw_python(
+    dy_np, x_np, kernel_size, stride, padding, groups=1, channels=None
+):
     """Gradient of the conv2d op with respect to weight, in NCHW layout.
 
     Parameters
@@ -51,17 +53,34 @@ def conv2d_backward_weight_nchw_python(
     R, S = kernel_size
     pad_h, pad_w = padding
     stride_h, stride_w = stride
-    dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)
+    is_depth_wise = C == K and C == groups
+
+    if is_depth_wise:
+        assert channels == groups, "Only channel_mult == 1 supported for now."
+        dw = np.zeros((K, 1, R, S)).astype(dy_np.dtype)
+    else:
+        assert groups == 1, "General grouped conv2d not supported for now."
+        dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)
 
     for k in range(K):
         for r in range(R):
             for s in range(S):
-                for c in range(C):
+                for c in range(dw.shape[1]):
                     acc = 0
                     for n in range(N):
                         for p in range(P):
                             for q in range(Q):
-                                coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s)
+                                if not is_depth_wise:
+                                    in_c = c
+                                else:
+                                    in_c = k
+
+                                coord = (
+                                    n,
+                                    in_c,
+                                    p * stride_h - pad_h + r,
+                                    q * stride_w - pad_w + s,
+                                )
+
                                 if (
                                     coord[2] < H
@@ -76,7 +95,9 @@
     return dw
 
 
-def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"):
+def conv2d_backward_weight_python(
+    dy_np, x_np, kernel_size, stride, padding, layout="NCHW", groups=1, channels=None
+):
     """Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout.
 
     Parameters
@@ -99,20 +120,30 @@ def conv2d_backward_weight_python(
     layout: string
         Layout of dy_np and x_np
 
+    groups: int
+        Number of groups for grouped convolution.
+
+    channels : int
+        Number of output channels of this convolution.
+
     Returns
     -------
     dw_np : np.ndarray
         Tensor of shape [num_filter, in_channel, filter_height, filter_width] for NCHW layout,
         [num_filter, filter_height, filter_width, in_channel] for NHWC layout.
     """
     if layout == "NCHW":
-        return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding)
+        return conv2d_backward_weight_nchw_python(
+            dy_np, x_np, kernel_size, stride, padding, groups, channels
+        )
 
     dw_np_oihw = conv2d_backward_weight_nchw_python(
         np.transpose(dy_np, [0, 3, 1, 2]),
         np.transpose(x_np, [0, 3, 1, 2]),
         kernel_size,
         stride,
         padding,
+        groups,
+        channels,
     )
     return np.transpose(dw_np_oihw, [0, 2, 3, 1])
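As a usage note, the updated reference implementation can be called directly for a depthwise case; a small sketch using the same shapes as the new test below (the gradient comes back in OIHW layout with one input channel per filter):

```python
# Sketch: depthwise weight gradient via the NumPy reference implementation.
import numpy as np
import tvm.topi.testing

dy_np = np.random.randn(1, 16, 32, 32).astype("float32")
x_np = np.random.randn(1, 16, 32, 32).astype("float32")
dw_np = tvm.topi.testing.conv2d_backward_weight_python(
    dy_np, x_np, (3, 3), (1, 1), (1, 1), layout="NCHW", groups=16, channels=16
)
assert dw_np.shape == (16, 1, 3, 3)  # (K, C // groups, R, S) for depthwise
```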
21 changes: 17 additions & 4 deletions src/relay/op/nn/convolution.cc
@@ -638,10 +638,23 @@ bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Att
 
   auto in_channels = dshape_nchw[1];
   auto out_channels = grad_shape_nchw[1];
-
-  Array<IndexExpr> wshape_oihw(
-      {out_channels, in_channels, param->kernel_size[0], param->kernel_size[1]});
-
+  auto in_channels_intimm = in_channels.as<IntImmNode>();
+  auto out_channels_intimm = out_channels.as<IntImmNode>();
+  ICHECK(in_channels_intimm);
+  ICHECK(out_channels_intimm);
+
+  IndexExpr weight_dim_i;
+  if (in_channels_intimm->value == out_channels_intimm->value &&
+      in_channels_intimm->value == param->groups) {
+    // depthwise
+    ICHECK(param->channels.defined()) << "out_channels attribute not specified for depth wise conv2d.";
+    weight_dim_i = indexdiv(param->channels, param->groups);
+  } else {
+    weight_dim_i = indexdiv(in_channels, param->groups);
+  }
+
+  Array<IndexExpr> wshape_oihw{out_channels, weight_dim_i, param->kernel_size[0],
+                               param->kernel_size[1]};
   auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw);
   reporter->Assign(types[2], TensorType(wshape, param->out_dtype));
   return true;
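The shape rule above boils down to a small piece of arithmetic; here is a Python restatement (hypothetical helper, for illustration only) of what `Conv2DBackwardWeightRel` now computes for the OIHW weight shape:

```python
# Mirror of the C++ type relation: choose the weight's input-channel dimension.
def backward_weight_oihw_shape(in_channels, out_channels, groups, channels, kernel_size):
    r, s = kernel_size
    if in_channels == out_channels == groups:  # depthwise conv2d
        assert channels is not None, "channels attribute required for depthwise conv2d"
        return (out_channels, channels // groups, r, s)
    return (out_channels, in_channels // groups, r, s)

assert backward_weight_oihw_shape(16, 16, 16, 16, (3, 3)) == (16, 1, 3, 3)  # depthwise
assert backward_weight_oihw_shape(4, 8, 1, 8, (3, 3)) == (8, 4, 3, 3)  # ordinary conv2d
```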
17 changes: 14 additions & 3 deletions tests/python/relay/test_op_grad_level2.py
@@ -229,14 +229,22 @@ def test_batch_flatten_grad():
     verify_batch_flatten_grad((1, 8))
 
 
-def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, padding):
+def verify_conv2d_backward_weight(
+    dy_shape, x_shape, kernel_size, stride, padding, groups=1, out_channels=None
+):
     dtype = "float32"
     dy = relay.var("dy", shape=dy_shape, dtype=dtype)
     x = relay.var("x", shape=x_shape, dtype=dtype)
     dw_func = relay.Function(
         [dy, x],
         relay.nn.conv2d_backward_weight(
-            dy, x, strides=stride, padding=padding, kernel_size=kernel_size
+            dy,
+            x,
+            strides=stride,
+            padding=padding,
+            kernel_size=kernel_size,
+            groups=groups,
+            channels=out_channels,
         ),
     )
     dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize())
Expand All @@ -251,7 +259,7 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin

dw_np = relay.create_executor(device=dev, target=target).evaluate(dw)(dy_np, x_np).numpy()
ref_dw_np = tvm.topi.testing.conv2d_backward_weight_python(
dy_np, x_np, kernel_size, stride, padding
dy_np, x_np, kernel_size, stride, padding, groups=groups, channels=out_channels
)

np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4)
@@ -260,6 +268,9 @@ def verify_conv2d_backward_weight(
 
 def test_conv2d_backward_weight():
     verify_conv2d_backward_weight((2, 8, 32, 32), (2, 4, 32, 32), (3, 3), (1, 1), (1, 1))
     verify_conv2d_backward_weight((2, 16, 15, 15), (2, 3, 32, 32), (3, 3), (2, 2), (0, 0))
+    verify_conv2d_backward_weight(
+        (1, 16, 32, 32), (1, 16, 32, 32), (3, 3), (1, 1), (1, 1), groups=16, out_channels=16
+    )
 
 
 if __name__ == "__main__":
