Merged
75 changes: 36 additions & 39 deletions python/tvm/topi/cuda/conv1d_transpose_ncw.py
@@ -65,29 +65,46 @@ def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype, output_padding):
     out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right + output_padding
     pad_left = kernel_size - 1 - pad_left
     pad_right = kernel_size - 1 - pad_right + output_padding
-    dilated_width = stride * (inp_width - 1) + 1
-    data = te.compute(
-        (batch, inp_channels, pad_left + dilated_width + pad_right),
+    padded_width = pad_left + inp_width + pad_right
+
+    padded_data = te.compute(
+        (batch, inp_channels, padded_width),
         lambda n, c, x: tvm.tir.if_then_else(
-            tvm.tir.all(
-                x >= pad_left,
-                x < pad_left + dilated_width,
-                tvm.tir.indexmod(x - pad_left, stride).equal(0),
-            ),
-            data[n, c, tvm.tir.indexdiv(x - pad_left, stride)],
+            tvm.tir.all(x >= pad_left, x < pad_left + inp_width),
+            data[n, c, x - pad_left],
             tvm.tir.const(0.0, "float32"),
         ),
         name="data_pad",
     )

-    dc = te.reduce_axis((0, inp_channels), name="dc")
-    dw = te.reduce_axis((0, kernel_size), name="dw")
+    padded_kernel = te.compute(
+        (inp_channels, out_channels, kernel_size + stride - 1),
+        lambda ci, co, k: tvm.tir.if_then_else(
+            tvm.tir.all(k < kernel_size),
+            kernel[ci, co, kernel_size - k - 1],
+            tvm.tir.const(0.0, "float32"),
+        ),
+        name="kernel_pad",
+    )
+
+    ci = te.reduce_axis((0, inp_channels), name="ci")
+    k = te.reduce_axis((0, tvm.tir.indexdiv(kernel_size + stride - 1, stride)), name="k")
+    border = pad_left * (stride - 1)
+
+    # Skip multiplications by the zero values inserted into the input data when stride is
+    # greater than 1. When the kernel is multiplied by the padded data, the kernel indices
+    # used are 0, 1 * stride, 2 * stride, ..., (ceil(kernel_size / stride) - 1) * stride,
+    # each shifted by the data offset mod stride.
     data_out = te.compute(
         (batch, out_channels, out_width),
-        lambda b, c, w: te.sum(
-            data[b, dc, w + dw].astype(out_dtype)
-            * kernel[dc, c, kernel_size - 1 - dw].astype(out_dtype),
-            axis=[dc, dw],
+        lambda b, co, w: te.sum(
+            padded_data[b, ci, tvm.tir.indexdiv(border + w + stride - 1, stride) + k].astype(
+                out_dtype
+            )
+            * padded_kernel[
+                ci, co, k * stride + tvm.tir.indexmod(stride - w - border, stride)
+            ].astype(out_dtype),
+            axis=[ci, k],
         ),
         tag="conv1d_transpose_ncw",
     )
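
To make the zero-skipping arithmetic above concrete, here is a small NumPy sketch (not part of this PR; the single-channel shapes and function names are assumptions for illustration). It mirrors the new padded-data/padded-kernel indexing and checks it against the naive dilate-and-correlate formulation that the old compute implemented:

import numpy as np

def naive_transpose_conv1d(data, kernel, stride, pad):
    # Old approach: dilate the input with stride-1 zeros, pad, then correlate
    # with the flipped kernel. Many products are multiplications by zero.
    k_size = len(kernel)
    dilated = np.zeros(stride * (len(data) - 1) + 1)
    dilated[::stride] = data
    edge = k_size - 1 - pad
    padded = np.concatenate([np.zeros(edge), dilated, np.zeros(edge)])
    out_w = (len(data) - 1) * stride + k_size - 2 * pad
    flipped = kernel[::-1]
    return np.array([padded[w:w + k_size] @ flipped for w in range(out_w)])

def zero_skipping_transpose_conv1d(data, kernel, stride, pad):
    # New approach: pad the undilated data and the reversed kernel, then step
    # through the kernel in strides so the implicit zeros are never read.
    k_size = len(kernel)
    edge = k_size - 1 - pad                    # "pad_left" after the rewrite
    padded_data = np.concatenate([np.zeros(edge), data, np.zeros(edge)])
    padded_kernel = np.zeros(k_size + stride - 1)
    padded_kernel[:k_size] = kernel[::-1]
    out_w = (len(data) - 1) * stride + k_size - 2 * pad
    k_steps = (k_size + stride - 1) // stride  # ceil(k_size / stride)
    border = edge * (stride - 1)
    out = np.zeros(out_w)
    for w in range(out_w):
        base = (border + w + stride - 1) // stride  # indexdiv(...) in the PR
        offset = (stride - w - border) % stride     # indexmod(...) in the PR
        for k in range(k_steps):
            out[w] += padded_data[base + k] * padded_kernel[k * stride + offset]
    return out

rng = np.random.default_rng(0)
data, kernel = rng.normal(size=10), rng.normal(size=5)
for stride in (1, 2, 3):
    for pad in (0, 1):
        assert np.allclose(
            naive_transpose_conv1d(data, kernel, stride, pad),
            zero_skipping_transpose_conv1d(data, kernel, stride, pad),
        )
print("zero-skipping indexing matches the naive reference")

The inner loop runs ceil(kernel_size / stride) times instead of kernel_size times, and it never reads the zeros that dilation would have inserted; that is the entire saving.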
@@ -118,8 +135,8 @@ def schedule_conv1d_transpose_ncw(cfg, outs):

     def _callback(op):
         if op.tag == "conv1d_transpose_ncw":
-            pad_data = op.input_tensors[0]
-            kernel = op.input_tensors[1]
+            padded_data = op.input_tensors[0]
+            padded_kernel = op.input_tensors[1]
             conv = op.output(0)

             ##### space definition begin #####
@@ -139,9 +156,6 @@ def _callback(op):

             ##### space definition end #####

-            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
-                s[kernel].compute_inline()
-
             if conv.op in s.outputs:
                 output = conv
                 OL = s.cache_write(conv, "local")
@@ -150,10 +164,8 @@ def _callback(op):
                 s[conv].set_scope("local")
                 OL = conv

-            # create cache stage
-            s[pad_data].set_scope("shared")
-            AA = pad_data
-            WW = s.cache_read(kernel, "shared", [OL])
+            s[padded_kernel].compute_inline()
+            s[padded_data].compute_inline()

             # tile and bind spatial axes
             n, f, x = s[output].op.axis
@@ -172,28 +184,13 @@ def _callback(op):

             s[output].bind(tx, te.thread_axis("threadIdx.x"))
             s[OL].compute_at(s[output], tx)
-            # number of threads
-            n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
-            n_tx = cfg["tile_x"].size[2]

             # tile reduction axes
             n, f, x = s[OL].op.axis
             rc, rx = s[OL].op.reduce_axis
             rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
             s[OL].reorder(rco, rcm, rx, rci, n, f, x)

-            s[AA].compute_at(s[OL], rx)
-            s[WW].compute_at(s[OL], rx)
-
-            # cooperative fetching
-            for load in [AA, WW]:
-                n, f, x = s[load].op.axis
-                fused = s[load].fuse(f, x)
-                tz, fused = s[load].split(fused, nparts=n_tz)
-                tx, fused = s[load].split(fused, nparts=n_tx)
-                s[load].bind(tz, te.thread_axis("threadIdx.y"))
-                s[load].bind(tx, te.thread_axis("threadIdx.x"))
-
             s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
             s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)

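Because the padded data and kernel are now inlined into the convolution rather than staged in shared memory, the cooperative-fetching loop and its thread counts are no longer needed. Below is a minimal end-to-end sketch of driving the rewritten op, assuming a CUDA-enabled TVM build; the shapes are arbitrary and the call pattern follows the test file (tvm.device is the current spelling; older releases used tvm.context):

import numpy as np
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import te, topi

target = tvm.target.Target("cuda")
dev = tvm.device("cuda", 0)

A = te.placeholder((1, 3, 224), name="A", dtype="float32")
W = te.placeholder((3, 8, 5), name="W", dtype="float32")
with target:
    # positional args: data, kernel, stride, padding, out_dtype, output_padding
    B = topi.cuda.conv1d_transpose_ncw(A, W, 2, 2, "float32", (0,))
    s = topi.cuda.schedule_conv1d_transpose_ncw([B])
func = tvm.build(s, [A, W, B], target)

# compare against the NumPy reference shipped with TVM's test utilities
a_np = np.random.uniform(size=(1, 3, 224)).astype("float32")
w_np = np.random.uniform(size=(3, 8, 5)).astype("float32")
b_np = tvm.topi.testing.conv1d_transpose_ncw_python(a_np, w_np, 2, 2, (0,))
a, w = tvm.nd.array(a_np, dev), tvm.nd.array(w_np, dev)
b = tvm.nd.array(np.zeros(b_np.shape, dtype="float32"), dev)
func(a, w, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)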
4 changes: 4 additions & 0 deletions tests/python/topi/python/test_topi_conv1d_transpose_ncw.py
@@ -91,9 +91,13 @@ def test_conv1d_transpose_ncw():
     verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 2, 256, (0,))
     verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (0,))
     verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (3,))
+    verify_conv1d_transpose_ncw(1, 2, 1024, 1, 128, 128, 0, (0,))
+    verify_conv1d_transpose_ncw(1, 1, 1024, 2, 128, 128, 0, (0,))
+    verify_conv1d_transpose_ncw(1, 1, 1024, 2, 2, 2, 0, (0,))
     verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0, 3), (0,))
     verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1, 3), (0,))
     verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2, 3), (0,))
+    verify_conv1d_transpose_ncw(1, 257, 128, 1, 512, 128, 256, (0,))


 if __name__ == "__main__":
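The added cases cover stride equal to the kernel width, a tiny kernel with stride 2, and a prime input-channel count (257). Assuming verify_conv1d_transpose_ncw takes (batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding), the expected output widths follow from the out_width formula at the top of the compute; a quick check:

# hypothetical helper for illustration; mirrors the out_width formula in
# conv1d_transpose_ncw with symmetric integer padding
def out_width(in_size, kernel, stride, padding, output_padding=0):
    return (in_size - 1) * stride + kernel - 2 * padding + output_padding

print(out_width(1024, 128, 128, 0))   # 131072: stride equals kernel width
print(out_width(1024, 2, 2, 0))       # 2048: tiny kernel, stride 2
print(out_width(128, 512, 128, 256))  # 16256: large kernel, padding 256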