Skip to content

Commit

Permalink
[LINT] Hotfix pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Mar 2, 2022
1 parent 4ffb2e8 commit 14f29df
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions python/tvm/topi/cuda/conv2d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_
stride_height, stride_width = stride
outpad_height, outpad_width = output_padding
assert outpad_height < stride_height and outpad_width < stride_width
assert inp_channels % groups == 0, f"input channels {inp_channels} must divide group size {groups}"

if inp_channels % groups != 0:
raise ValueError(f"input channels {inp_channels} must divide group size {groups}")

cfg.stride = stride
pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
padding, (kernel_height, kernel_width)
Expand Down Expand Up @@ -112,14 +115,14 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_
data_out = te.compute(
(batch, out_channels * groups, out_height, out_width),
lambda b, c, h, w: te.sum(
data[
b, c // out_channels * (inp_channels // groups) + dc, h + dh, w + dw
].astype(out_dtype)
data[b, c // out_channels * (inp_channels // groups) + dc, h + dh, w + dw].astype(
out_dtype
)
* kernel[
c // out_channels * (inp_channels // groups) + dc,
c % out_channels,
kernel_height - 1 - dh,
kernel_width - 1 - dw
kernel_width - 1 - dw,
].astype(out_dtype),
axis=[dc, dh, dw],
),
Expand Down

0 comments on commit 14f29df

Please sign in to comment.