diff --git a/python/tvm/topi/cuda/conv2d_transpose.py b/python/tvm/topi/cuda/conv2d_transpose.py
index 3d308474bc50b..6a614912a969b 100644
--- a/python/tvm/topi/cuda/conv2d_transpose.py
+++ b/python/tvm/topi/cuda/conv2d_transpose.py
@@ -59,7 +59,10 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_
     stride_height, stride_width = stride
     outpad_height, outpad_width = output_padding
     assert outpad_height < stride_height and outpad_width < stride_width
-    assert inp_channels % groups == 0, f"input channels {inp_channels} must divide group size {groups}"
+
+    if inp_channels % groups != 0:
+        raise ValueError(f"input channels {inp_channels} must be divisible by group size {groups}")
+
     cfg.stride = stride
     pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
         padding, (kernel_height, kernel_width)
@@ -112,14 +115,14 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_
     data_out = te.compute(
         (batch, out_channels * groups, out_height, out_width),
         lambda b, c, h, w: te.sum(
-            data[
-                b, c // out_channels * (inp_channels // groups) + dc, h + dh, w + dw
-            ].astype(out_dtype)
+            data[b, c // out_channels * (inp_channels // groups) + dc, h + dh, w + dw].astype(
+                out_dtype
+            )
             * kernel[
                 c // out_channels * (inp_channels // groups) + dc,
                 c % out_channels,
                 kernel_height - 1 - dh,
-                kernel_width - 1 - dw
+                kernel_width - 1 - dw,
             ].astype(out_dtype),
             axis=[dc, dh, dw],
         ),
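
Side note on the change in the first hunk: an `assert` is stripped when Python runs with the `-O` flag, so the old check would silently disappear in optimized builds, whereas `raise ValueError` always fires. Below is a minimal standalone sketch of the new validation; the helper name `check_group_divisibility` and the example values are illustrative, not part of the diff:

```python
def check_group_divisibility(inp_channels: int, groups: int) -> None:
    # Mirrors the validation added in the diff: the group count must
    # evenly divide the number of input channels.
    if inp_channels % groups != 0:
        raise ValueError(
            f"input channels {inp_channels} must be divisible by group size {groups}"
        )


check_group_divisibility(64, 8)  # OK: 64 % 8 == 0
# check_group_divisibility(64, 6)  # raises ValueError, even under `python -O`
```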