fix conv2d type rel for depth wise and grouped conv2d
masahi committed Feb 4, 2022
1 parent 96416c4 commit 2167c25
Showing 4 changed files with 76 additions and 16 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -1107,6 +1107,7 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
             padding=attrs.padding,
             dilation=attrs.strides,
             groups=in_channel * batch,
+            channels=attrs.channels,
         )
 
     # infer shape of backward_weight
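For context, this one-line hunk matters because `conv2d_backward_weight` is legalized into an ordinary grouped `conv2d` (with `dilation` set to the original strides and `groups = in_channel * batch`), and the fixed type relation in `convolution.cc` below needs the `channels` attribute to size the weight gradient in the depthwise case. A minimal sketch of exercising the fixed shape inference, assuming a TVM build that includes this patch (shapes borrowed from the depthwise test case at the bottom of this commit):

import tvm
from tvm import relay

# Depthwise wgrad: in_channels == out_channels == groups == 16.
dy = relay.var("dy", shape=(1, 16, 32, 32), dtype="float32")
x = relay.var("x", shape=(1, 16, 32, 32), dtype="float32")
dw = relay.nn.conv2d_backward_weight(
    dy, x, strides=(1, 1), padding=(1, 1), kernel_size=(3, 3), groups=16, channels=16
)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([dy, x], dw)))
# With this fix the weight gradient is typed (16, 1, 3, 3); the old relation
# always used in_channels and would have produced (16, 16, 3, 3) here.
print(mod["main"].body.checked_type)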
43 changes: 37 additions & 6 deletions python/tvm/topi/testing/conv2d_backcward_weight_python.py
@@ -20,7 +20,9 @@


 # Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
-def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding):
+def conv2d_backward_weight_nchw_python(
+    dy_np, x_np, kernel_size, stride, padding, groups=1, channels=None
+):
     """Gradient of the conv2d op with respect to weight, in NCHW layout.
 
     Parameters
@@ -51,17 +53,34 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
     R, S = kernel_size
     pad_h, pad_w = padding
     stride_h, stride_w = stride
-    dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)
+    is_depth_wise = C == K and C == groups
+
+    if is_depth_wise:
+        channel_mult = channels // groups
+        dw = np.zeros((K, channel_mult, R, S)).astype(dy_np.dtype)
+    else:
+        dw = np.zeros((K, C // groups, R, S)).astype(dy_np.dtype)
+        channel_mult = 1
+
     for k in range(K):
         for r in range(R):
             for s in range(S):
-                for c in range(C):
+                for c in range(dw.shape[1]):
                     acc = 0
                     for n in range(N):
                         for p in range(P):
                             for q in range(Q):
-                                coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s)
+                                if not is_depth_wise:
+                                    in_c = c
+                                else:
+                                    in_c = k // channel_mult
+
+                                coord = (
+                                    n,
+                                    in_c,
+                                    p * stride_h - pad_h + r,
+                                    q * stride_w - pad_w + s,
+                                )
+
                                 if (
                                     coord[2] < H
@@ -76,7 +95,9 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
     return dw
 
 
-def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"):
+def conv2d_backward_weight_python(
+    dy_np, x_np, kernel_size, stride, padding, layout="NCHW", groups=1, channels=None
+):
     """Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout.
 
     Parameters
@@ -99,20 +120,30 @@ def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, lay
     layout: string
         Layout of dy_np and x_np
+    groups: int
+        Number of groups for grouped convolution.
+    channels: int
+        Number of output channels of this convolution.
 
     Returns
     -------
     dw_np : np.ndarray
         Tensor of shape [num_filter, in_channel // groups, filter_height, filter_width] for
         NCHW layout, [num_filter, filter_height, filter_width, in_channel // groups] for
         NHWC layout.
     """
     if layout == "NCHW":
-        return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding)
+        return conv2d_backward_weight_nchw_python(
+            dy_np, x_np, kernel_size, stride, padding, groups, channels
+        )
 
     dw_np_oihw = conv2d_backward_weight_nchw_python(
         np.transpose(dy_np, [0, 3, 1, 2]),
         np.transpose(x_np, [0, 3, 1, 2]),
         kernel_size,
         stride,
         padding,
+        groups,
+        channels,
     )
     return np.transpose(dw_np_oihw, [0, 2, 3, 1])
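In equation form, the loop nests above compute (a direct transcription of the reference code, with out-of-bounds taps contributing zero):

$$dw[k, c, r, s] = \sum_{n=0}^{N-1} \sum_{p=0}^{P-1} \sum_{q=0}^{Q-1} dy[n, k, p, q] \cdot x[n,\ c_{\mathrm{in}},\ p \cdot \mathrm{stride}_h - \mathrm{pad}_h + r,\ q \cdot \mathrm{stride}_w - \mathrm{pad}_w + s]$$

where $c_{\mathrm{in}} = c$ on the non-depthwise path, and $c_{\mathrm{in}} = \lfloor k / \mathrm{channel\_mult} \rfloor$ with $\mathrm{channel\_mult} = \mathrm{channels} / \mathrm{groups}$ on the depthwise path.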
21 changes: 17 additions & 4 deletions src/relay/op/nn/convolution.cc
@@ -638,10 +638,23 @@ bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Att
 
   auto in_channels = dshape_nchw[1];
   auto out_channels = grad_shape_nchw[1];
 
-  Array<IndexExpr> wshape_oihw(
-      {out_channels, in_channels, param->kernel_size[0], param->kernel_size[1]});
+  auto in_channels_intimm = in_channels.as<IntImmNode>();
+  auto out_channels_intimm = out_channels.as<IntImmNode>();
+  ICHECK(in_channels_intimm);
+  ICHECK(out_channels_intimm);
+
+  IndexExpr weight_dim_i;
+  if (in_channels_intimm->value == out_channels_intimm->value &&
+      in_channels_intimm->value == param->groups) {
+    // depthwise
+    ICHECK(param->channels.defined())
+        << "out_channels attribute not specified for depthwise conv2d.";
+    weight_dim_i = indexdiv(param->channels, param->groups);
+  } else {
+    weight_dim_i = indexdiv(in_channels, param->groups);
+  }
+
+  Array<IndexExpr> wshape_oihw{out_channels, weight_dim_i, param->kernel_size[0],
+                               param->kernel_size[1]};
 
   auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw);
   reporter->Assign(types[2], TensorType(wshape, data->dtype));
   return true;
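Restated as a small Python sketch (my paraphrase of the relation above, not code from the patch), the OIHW weight shape the relation now assigns is:

def wgrad_oihw_shape(in_channels, out_channels, groups, kernel_size, channels=None):
    """Weight-gradient shape implied by Conv2DBackwardWeightRel after this fix."""
    if in_channels == out_channels == groups:  # depthwise
        assert channels is not None, "channels attr required for depthwise conv2d"
        weight_dim_i = channels // groups  # channel multiplier
    else:
        weight_dim_i = in_channels // groups  # covers normal (groups=1) and grouped
    return (out_channels, weight_dim_i, kernel_size[0], kernel_size[1])

assert wgrad_oihw_shape(16, 16, 16, (3, 3), channels=16) == (16, 1, 3, 3)  # depthwise
assert wgrad_oihw_shape(4, 8, 1, (3, 3)) == (8, 4, 3, 3)  # ordinary conv2d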
27 changes: 21 additions & 6 deletions tests/python/relay/test_op_grad_level2.py
@@ -229,17 +229,26 @@ def test_batch_flatten_grad():
     verify_batch_flatten_grad((1, 8))
 
 
-def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, padding):
+def verify_conv2d_backward_weight(
+    dy_shape, x_shape, kernel_size, stride, padding, groups=1, out_channels=None
+):
     dtype = "float32"
     dy = relay.var("dy", shape=dy_shape, dtype=dtype)
     x = relay.var("x", shape=x_shape, dtype=dtype)
     dw_func = relay.Function(
         [dy, x],
         relay.nn.conv2d_backward_weight(
-            dy, x, strides=stride, padding=padding, kernel_size=kernel_size
+            dy,
+            x,
+            strides=stride,
+            padding=padding,
+            kernel_size=kernel_size,
+            groups=groups,
+            channels=out_channels,
         ),
     )
     dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize())

     for dw, target in [(dw_func_legalized, "llvm"), (dw_func, "cuda -libs=cudnn")]:
         if "cudnn" in target and not tvm.contrib.cudnn.exists():
@@ -251,16 +260,22 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin
 
         dw_np = relay.create_executor(device=dev, target=target).evaluate(dw)(dy_np, x_np).numpy()
         ref_dw_np = tvm.topi.testing.conv2d_backward_weight_python(
-            dy_np, x_np, kernel_size, stride, padding
+            dy_np, x_np, kernel_size, stride, padding, groups=groups, channels=out_channels
         )
 
         np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4)
 
 
 def test_conv2d_backward_weight():
     verify_conv2d_backward_weight((2, 8, 32, 32), (2, 4, 32, 32), (3, 3), (1, 1), (1, 1))
     verify_conv2d_backward_weight((2, 16, 15, 15), (2, 3, 32, 32), (3, 3), (2, 2), (0, 0))
+    verify_conv2d_backward_weight(
+        (1, 16, 32, 32), (1, 16, 32, 32), (3, 3), (1, 1), (1, 1), groups=16, out_channels=16
+    )
+    # verify_conv2d_backward_weight(
+    #     (1, 32, 32, 32), (1, 16, 32, 32), (3, 3), (1, 1), (1, 1), groups=8
+    # )
+    # verify_conv2d_backward_weight(
+    #     (1, 32, 32, 32), (1, 16, 32, 32), (3, 3), (1, 1), (1, 1), groups=16, out_channels=32
+    # )
 
 
 if __name__ == "__main__":
     pytest.main([__file__])
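As a quick cross-check outside pytest, the updated reference implementation can be exercised directly for the depthwise case above (a sketch assuming NumPy and this TVM checkout on PYTHONPATH):

import numpy as np
import tvm.topi.testing

dy_np = np.random.randn(1, 16, 32, 32).astype("float32")
x_np = np.random.randn(1, 16, 32, 32).astype("float32")

# Depthwise: C == K == groups == 16, so the channel multiplier is 16 // 16 = 1.
dw_np = tvm.topi.testing.conv2d_backward_weight_python(
    dy_np, x_np, (3, 3), (1, 1), (1, 1), layout="NCHW", groups=16, channels=16
)
assert dw_np.shape == (16, 1, 3, 3)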
