Commit f0979e4

Alex Gladkov and Ubuntu authored

conv1d_transpose speedup. (#6840)

Improve performance of transposed convolution by avoiding redundant multiplication by zero values from dilated data.

Co-authored-by: Ubuntu <ubuntu@ip-172-31-74-104.ec2.internal>

1 parent ff9c480 commit f0979e4
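In other words, the old kernel computed the transposed convolution over an explicitly dilated input, so for stride > 1 most multiply-accumulates hit inserted zeros. A minimal NumPy sketch of the before/after idea (illustrative names and scalar loops for clarity; this is not the TVM code):

import numpy as np

def transpose1d_dilated(x, w, stride):
    # Baseline: insert stride-1 zeros between samples, then run a full
    # correlation with the flipped kernel. When stride > 1, most of the
    # products hit the inserted zeros.
    k = len(w)
    dilated = np.zeros((len(x) - 1) * stride + 1)
    dilated[::stride] = x
    padded = np.pad(dilated, (k - 1, k - 1))
    return np.array([np.dot(padded[i : i + k], w[::-1])
                     for i in range(len(dilated) + k - 1)])

def transpose1d_skip_zeros(x, w, stride):
    # Same math, but each output position only visits kernel taps that
    # land on a real input sample; those taps step through w by `stride`,
    # so roughly ceil(k / stride) products per output instead of k.
    k = len(w)
    out = np.zeros((len(x) - 1) * stride + k)
    for i in range(len(out)):
        t_lo = max(0, -((k - 1 - i) // stride))   # ceil((i - k + 1) / stride)
        t_hi = min(len(x) - 1, i // stride)
        out[i] = sum(x[t] * w[i - t * stride] for t in range(t_lo, t_hi + 1))
    return out

x, w = np.random.rand(7), np.random.rand(5)
assert np.allclose(transpose1d_dilated(x, w, 3), transpose1d_skip_zeros(x, w, 3))

The diff below realizes the second form with te.compute expressions rather than loops.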

File tree: 2 files changed, +40 -39 lines

python/tvm/topi/cuda/conv1d_transpose_ncw.py
tests/python/topi/python/test_topi_conv1d_transpose_ncw.py

python/tvm/topi/cuda/conv1d_transpose_ncw.py (36 additions, 39 deletions)
@@ -65,29 +65,46 @@ def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype, output_p
     out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right + output_padding
     pad_left = kernel_size - 1 - pad_left
     pad_right = kernel_size - 1 - pad_right + output_padding
-    dilated_width = stride * (inp_width - 1) + 1
-    data = te.compute(
-        (batch, inp_channels, pad_left + dilated_width + pad_right),
+    padded_width = pad_left + inp_width + pad_right
+
+    padded_data = te.compute(
+        (batch, inp_channels, padded_width),
         lambda n, c, x: tvm.tir.if_then_else(
-            tvm.tir.all(
-                x >= pad_left,
-                x < pad_left + dilated_width,
-                tvm.tir.indexmod(x - pad_left, stride).equal(0),
-            ),
-            data[n, c, tvm.tir.indexdiv(x - pad_left, stride)],
+            tvm.tir.all(x >= pad_left, x < pad_left + inp_width),
+            data[n, c, x - pad_left],
             tvm.tir.const(0.0, "float32"),
         ),
         name="data_pad",
     )

-    dc = te.reduce_axis((0, inp_channels), name="dc")
-    dw = te.reduce_axis((0, kernel_size), name="dw")
+    padded_kernel = te.compute(
+        (inp_channels, out_channels, kernel_size + stride - 1),
+        lambda ci, co, k: tvm.tir.if_then_else(
+            tvm.tir.all(k < kernel_size),
+            kernel[ci, co, kernel_size - k - 1],
+            tvm.tir.const(0.0, "float32"),
+        ),
+        name="kernel_pad",
+    )
+
+    ci = te.reduce_axis((0, inp_channels), name="ci")
+    k = te.reduce_axis((0, tvm.tir.indexdiv(kernel_size + stride - 1, stride)), name="k")
+    border = pad_left * (stride - 1)
+
+    # Skip multiplication by 0 values in the input data inserted when stride is greater than 1.
+    # During multiplication of kernel by padded data:
+    # Kernel indices are: 0, 1 * stride, 2 * stride, ..., ceil(kernel_size / stride) plus
+    # data offset mod stride
     data_out = te.compute(
         (batch, out_channels, out_width),
-        lambda b, c, w: te.sum(
-            data[b, dc, w + dw].astype(out_dtype)
-            * kernel[dc, c, kernel_size - 1 - dw].astype(out_dtype),
-            axis=[dc, dw],
+        lambda b, co, w: te.sum(
+            padded_data[b, ci, tvm.tir.indexdiv(border + w + stride - 1, stride) + k].astype(
+                out_dtype
+            )
+            * padded_kernel[
+                ci, co, k * stride + tvm.tir.indexmod(stride - w - border, stride)
+            ].astype(out_dtype),
+            axis=[ci, k],
         ),
         tag="conv1d_transpose_ncw",
     )
@@ -118,8 +135,8 @@ def schedule_conv1d_transpose_ncw(cfg, outs):

     def _callback(op):
         if op.tag == "conv1d_transpose_ncw":
-            pad_data = op.input_tensors[0]
-            kernel = op.input_tensors[1]
+            padded_data = op.input_tensors[0]
+            padded_kernel = op.input_tensors[1]
             conv = op.output(0)

             ##### space definition begin #####
@@ -139,9 +156,6 @@ def _callback(op):

             ##### space definition end #####

-            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
-                s[kernel].compute_inline()
-
             if conv.op in s.outputs:
                 output = conv
                 OL = s.cache_write(conv, "local")
@@ -150,10 +164,8 @@ def _callback(op):
                 s[conv].set_scope("local")
                 OL = conv

-            # create cache stage
-            s[pad_data].set_scope("shared")
-            AA = pad_data
-            WW = s.cache_read(kernel, "shared", [OL])
+            s[padded_kernel].compute_inline()
+            s[padded_data].compute_inline()

             # tile and bind spatial axes
             n, f, x = s[output].op.axis
@@ -172,28 +184,13 @@ def _callback(op):

             s[output].bind(tx, te.thread_axis("threadIdx.x"))
             s[OL].compute_at(s[output], tx)
-            # number of threads
-            n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
-            n_tx = cfg["tile_x"].size[2]

             # tile reduction axes
             n, f, x = s[OL].op.axis
             rc, rx = s[OL].op.reduce_axis
             rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
             s[OL].reorder(rco, rcm, rx, rci, n, f, x)

-            s[AA].compute_at(s[OL], rx)
-            s[WW].compute_at(s[OL], rx)
-
-            # cooperative fetching
-            for load in [AA, WW]:
-                n, f, x = s[load].op.axis
-                fused = s[load].fuse(f, x)
-                tz, fused = s[load].split(fused, nparts=n_tz)
-                tx, fused = s[load].split(fused, nparts=n_tx)
-                s[load].bind(tz, te.thread_axis("threadIdx.y"))
-                s[load].bind(tx, te.thread_axis("threadIdx.x"))
-
             s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
             s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)

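As a sanity check on the index arithmetic in the new data_out stage, the following standalone NumPy re-statement reproduces the formula literally and compares it with a simple scatter reference. The names are mine, and it assumes one batch, one input and one output channel, padding 0, and output_padding 0:

import numpy as np

def reference(x, w, stride):
    # Scatter-style transposed convolution: every input sample adds a
    # scaled copy of the kernel at its strided output offset.
    out = np.zeros((len(x) - 1) * stride + len(w))
    for t, v in enumerate(x):
        out[t * stride : t * stride + len(w)] += v * w
    return out

def commit_style(x, w, stride):
    # Literal re-statement of the new data_out compute for one
    # batch/channel, padding=0, output_padding=0.
    ksize = len(w)
    pad_left = ksize - 1                      # kernel_size - 1 - pad_left
    padded = np.zeros(pad_left + len(x) + (ksize - 1))
    padded[pad_left : pad_left + len(x)] = x  # data_pad
    pk = np.zeros(ksize + stride - 1)
    pk[:ksize] = w[::-1]                      # kernel_pad: flipped, zero-extended
    border = pad_left * (stride - 1)
    ksteps = (ksize + stride - 1) // stride   # ceil(kernel_size / stride)
    out = np.zeros((len(x) - 1) * stride + ksize)
    for ow in range(len(out)):
        base = (border + ow + stride - 1) // stride
        off = (stride - ow - border) % stride
        out[ow] = sum(padded[base + k] * pk[k * stride + off] for k in range(ksteps))
    return out

x, w = np.random.rand(9), np.random.rand(4)
for s in (1, 2, 3):
    assert np.allclose(reference(x, w, s), commit_style(x, w, s))

On the schedule side, the gather now reads far fewer values per output, leaving little reuse to stage, which is consistent with the diff dropping the shared-memory cache and cooperative fetching in favor of inlining both pad stages.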
tests/python/topi/python/test_topi_conv1d_transpose_ncw.py (4 additions, 0 deletions)
@@ -91,9 +91,13 @@ def test_conv1d_transpose_ncw():
     verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 2, 256, (0,))
     verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (0,))
     verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (3,))
+    verify_conv1d_transpose_ncw(1, 2, 1024, 1, 128, 128, 0, (0,))
+    verify_conv1d_transpose_ncw(1, 1, 1024, 2, 128, 128, 0, (0,))
+    verify_conv1d_transpose_ncw(1, 1, 1024, 2, 2, 2, 0, (0,))
     verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0, 3), (0,))
     verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1, 3), (0,))
     verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2, 3), (0,))
+    verify_conv1d_transpose_ncw(1, 257, 128, 1, 512, 128, 256, (0,))


 if __name__ == "__main__":
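For reading the new cases, the verify helper's positional arguments can be inferred from the call sites (an assumption; not re-checked against the helper's definition):

# Assumed argument order, inferred from the calls above:
# verify_conv1d_transpose_ncw(batch, in_channels, in_width, num_filters,
#                             kernel_size, stride, padding, output_padding)
# The added cases appear to target the new code paths: stride equal to
# kernel_size (128/128 and 2/2), where ceil(kernel_size / stride) collapses
# to a single tap per output, and a 257-channel input to stress the
# channel-reduction tiling.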
