Commit
Depthwise conv2d properly supported for wgrad
masahi committed Feb 6, 2022
1 parent 2191918 commit 041c094
Showing 5 changed files with 90 additions and 19 deletions.
14 changes: 12 additions & 2 deletions python/tvm/contrib/cudnn.py
@@ -826,10 +826,20 @@ def conv_backward_filter(
x.shape[0], tvm.tir.expr.IntImm
), "Dynamic batch is not supported for cudnn conv2d backward filter yet."

ic_ind = 1 if tensor_format == 0 else 3

if groups > 1:
assert (
x_shape[ic_ind] == dy.shape[ic_ind] and x_shape[ic_ind] == groups
), "Only depthwise wgrad supported for groups > 1."
ic = 1
else:
ic = x_shape[ic_ind]

if tensor_format == 0:
- dw_shape = [dy.shape[1], x_shape[1], filter_h, filter_w]
dw_shape = [dy.shape[1], ic, filter_h, filter_w]
else:
- dw_shape = [dy.shape[3], filter_h, filter_w, x_shape[3]]
dw_shape = [dy.shape[3], filter_h, filter_w, ic]

algo = conv_backward_filter_find_algo(
tensor_format,
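To make the new shape logic concrete, here is a minimal standalone sketch of how the filter-gradient shape comes out for a hypothetical depthwise call (NCHW, tensor_format == 0); the shapes are illustrative and not taken from the commit:

# Hypothetical depthwise wgrad shapes, mirroring the branch added above.
x_shape = [2, 16, 32, 32]   # data:     N, C, H, W
dy_shape = [2, 16, 30, 30]  # gradient: N, K, P, Q
groups = 16
filter_h, filter_w = 3, 3
tensor_format = 0  # NCHW

ic_ind = 1 if tensor_format == 0 else 3
if groups > 1:
    # Depthwise: in_channels == out_channels == groups, so each filter sees one channel.
    assert x_shape[ic_ind] == dy_shape[ic_ind] == groups
    ic = 1
else:
    ic = x_shape[ic_ind]

dw_shape = [dy_shape[1], ic, filter_h, filter_w]
print(dw_shape)  # [16, 1, 3, 3]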
9 changes: 6 additions & 3 deletions python/tvm/topi/cuda/conv2d.py
@@ -130,9 +130,12 @@ def conv2d_backward_weight_cudnn(
):
"""Compute conv2d wgrad using CuDNN library"""
assert layout in ["NCHW", "NHWC"]
- # cuDNN does not seem to support other combination.
- assert output_dtype == "float16", "Only supports fp16 output for cuDNN wgrad."
- conv_dtype = "float32"

if dy.dtype == "float16":
# cuDNN does not seem to support other combination.
assert output_dtype == "float16", "Only supports fp16 output for cuDNN fp16 wgrad."

conv_dtype = "float32" # Accumulation is always fp32
return cudnn.conv_backward_filter(
dy,
x,
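The conv_dtype comment above is the key constraint: even when dy and x are fp16, the weight gradient is accumulated in fp32 and only the stored result is fp16. A rough NumPy illustration of that storage-versus-accumulation split (a dense-layer wgrad stand-in, not the cuDNN call itself):

import numpy as np

# fp16 storage, fp32 accumulation -- the same pattern the cuDNN wgrad path relies on.
dy = np.random.randn(8, 128).astype("float16")
x = np.random.randn(8, 128).astype("float16")

acc = np.einsum("ni,nj->ij", dy.astype("float32"), x.astype("float32"))  # accumulate in fp32
dw = acc.astype("float16")  # cast back to fp16 on output, matching the assert above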
43 changes: 37 additions & 6 deletions python/tvm/topi/testing/conv2d_backcward_weight_python.py
@@ -20,7 +20,9 @@


# Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
- def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding):
def conv2d_backward_weight_nchw_python(
dy_np, x_np, kernel_size, stride, padding, groups=1, channels=None
):
"""Gradient of the conv2d op with respect to weight, in NCHW layout.
Parameters
@@ -51,17 +53,34 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
R, S = kernel_size
pad_h, pad_w = padding
stride_h, stride_w = stride
- dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)
is_depth_wise = C == K and C == groups

if is_depth_wise:
assert channels == groups, "Only channel_mult == 1 supported for now."
dw = np.zeros((K, 1, R, S)).astype(dy_np.dtype)
else:
assert groups == 1, "General grouped conv2d not supported for now."
dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)

for k in range(K):
for r in range(R):
for s in range(S):
- for c in range(C):
for c in range(dw.shape[1]):
acc = 0
for n in range(N):
for p in range(P):
for q in range(Q):
- coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s)
if not is_depth_wise:
in_c = c
else:
in_c = k

coord = (
n,
in_c,
p * stride_h - pad_h + r,
q * stride_w - pad_w + s,
)

if (
coord[2] < H
@@ -76,7 +95,9 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
return dw


- def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"):
def conv2d_backward_weight_python(
dy_np, x_np, kernel_size, stride, padding, layout="NCHW", groups=1, channels=None
):
"""Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout.
Parameters
@@ -99,20 +120,30 @@ def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, lay
layout: string
Layout of dy_np and x_np
groups: int
Number of groups for grouped convolution.
channels : int
Number of output channels of this convolution.
Returns
-------
dw_np : np.ndarray
Tensor of shape [num_filter, in_channel, filter_height, filter_width] for NCHW layout,
[num_filter, filter_height, filter_width, in_channel] for NHWC layout.
"""
if layout == "NCHW":
- return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding)
return conv2d_backward_weight_nchw_python(
dy_np, x_np, kernel_size, stride, padding, groups, channels
)

dw_np_oihw = conv2d_backward_weight_nchw_python(
np.transpose(dy_np, [0, 3, 1, 2]),
np.transpose(x_np, [0, 3, 1, 2]),
kernel_size,
stride,
padding,
groups,
channels,
)
return np.transpose(dw_np_oihw, [0, 2, 3, 1])
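
A short usage sketch of the updated reference helper for a depthwise case (shapes chosen to match the new test case below); it assumes a TVM build that includes this change and uses only the signature shown in this diff:

import numpy as np
import tvm.topi.testing

dy_np = np.random.randn(1, 16, 32, 32).astype("float32")
x_np = np.random.randn(1, 16, 32, 32).astype("float32")

# groups == in_channels == out_channels -> depthwise, so dw has a singleton input-channel dim.
dw_np = tvm.topi.testing.conv2d_backward_weight_python(
    dy_np, x_np, (3, 3), (1, 1), (1, 1), layout="NCHW", groups=16, channels=16
)
assert dw_np.shape == (16, 1, 3, 3)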
23 changes: 19 additions & 4 deletions src/relay/op/nn/convolution.cc
@@ -639,11 +639,26 @@ bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Att
auto in_channels = dshape_nchw[1];
auto out_channels = grad_shape_nchw[1];

- Array<IndexExpr> wshape_oihw(
-     {out_channels, in_channels, param->kernel_size[0], param->kernel_size[1]});

auto in_channels_intimm = in_channels.as<IntImmNode>();
auto out_channels_intimm = out_channels.as<IntImmNode>();
ICHECK(in_channels_intimm);
ICHECK(out_channels_intimm);

IndexExpr weight_dim_i;
if (in_channels_intimm->value == out_channels_intimm->value &&
in_channels_intimm->value == param->groups) {
// depthwise
ICHECK(param->channels.defined()) << "out_channels attribute not specified for depthwise conv2d.";
weight_dim_i = indexdiv(param->channels, param->groups);
} else {
weight_dim_i = indexdiv(in_channels, param->groups);
}

Array<IndexExpr> wshape_oihw{out_channels, weight_dim_i, param->kernel_size[0], param->kernel_size[1]};
auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw);
- reporter->Assign(types[2], TensorType(wshape, data->dtype));

const auto dw_dtype = param->out_dtype == DataType() ? grad->dtype : param->out_dtype;
reporter->Assign(types[2], TensorType(wshape, dw_dtype));
return true;
}

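Traced by hand for the depthwise case exercised in the test below (data and gradient both 1x16x32x32, groups=16, channels=16): in_channels == out_channels == groups, so weight_dim_i = channels / groups = 1 and wshape_oihw = [16, 1, 3, 3]; a regular convolution keeps weight_dim_i = in_channels / groups. A small Python sketch of that arithmetic (hypothetical helper, not TVM API):

# Hand-traced version of the new Conv2DBackwardWeightRel shape branch (illustrative only).
def wgrad_weight_shape_oihw(in_channels, out_channels, groups, channels, kernel_size):
    if in_channels == out_channels == groups:
        # Depthwise: the weight's input-channel dim comes from channels / groups.
        assert channels is not None, "out_channels attribute required for depthwise conv2d"
        weight_dim_i = channels // groups
    else:
        weight_dim_i = in_channels // groups
    return [out_channels, weight_dim_i, kernel_size[0], kernel_size[1]]

print(wgrad_weight_shape_oihw(16, 16, 16, 16, (3, 3)))  # [16, 1, 3, 3] (depthwise)
print(wgrad_weight_shape_oihw(4, 8, 1, 8, (3, 3)))      # [8, 4, 3, 3]  (regular conv)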
20 changes: 16 additions & 4 deletions tests/python/relay/test_op_grad_level2.py
@@ -229,16 +229,24 @@ def test_batch_flatten_grad():
verify_batch_flatten_grad((1, 8))


- def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, padding):
def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, padding, groups=1, out_channels=None):
dtype = "float32"
dy = relay.var("dy", shape=dy_shape, dtype=dtype)
x = relay.var("x", shape=x_shape, dtype=dtype)
dw_func = relay.Function(
[dy, x],
relay.nn.conv2d_backward_weight(
- dy, x, strides=stride, padding=padding, kernel_size=kernel_size
dy,
x,
strides=stride,
padding=padding,
kernel_size=kernel_size,
groups=groups,
channels=out_channels,
out_dtype=dtype,
),
)

dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize())

for dw, target in [(dw_func_legalized, "llvm"), (dw_func, "cuda -libs=cudnn")]:
@@ -251,7 +259,7 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin

dw_np = relay.create_executor(device=dev, target=target).evaluate(dw)(dy_np, x_np).numpy()
ref_dw_np = tvm.topi.testing.conv2d_backward_weight_python(
- dy_np, x_np, kernel_size, stride, padding
dy_np, x_np, kernel_size, stride, padding, groups=groups, channels=out_channels
)

np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4)
@@ -260,7 +268,11 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin
def test_conv2d_backward_weight():
verify_conv2d_backward_weight((2, 8, 32, 32), (2, 4, 32, 32), (3, 3), (1, 1), (1, 1))
verify_conv2d_backward_weight((2, 16, 15, 15), (2, 3, 32, 32), (3, 3), (2, 2), (0, 0))
verify_conv2d_backward_weight(
(1, 16, 32, 32), (1, 16, 32, 32), (3, 3), (1, 1), (1, 1), groups=16, out_channels=16
)


if __name__ == "__main__":
- pytest.main([__file__])
# pytest.main([__file__])
test_conv2d_backward_weight()
