@@ -65,29 +65,46 @@ def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype, output_p
65
65
out_width = (inp_width - 1 ) * stride + kernel_size - pad_left - pad_right + output_padding
66
66
pad_left = kernel_size - 1 - pad_left
67
67
pad_right = kernel_size - 1 - pad_right + output_padding
68
- dilated_width = stride * (inp_width - 1 ) + 1
69
- data = te .compute (
70
- (batch , inp_channels , pad_left + dilated_width + pad_right ),
68
+ padded_width = pad_left + inp_width + pad_right
69
+
70
+ padded_data = te .compute (
71
+ (batch , inp_channels , padded_width ),
71
72
lambda n , c , x : tvm .tir .if_then_else (
72
- tvm .tir .all (
73
- x >= pad_left ,
74
- x < pad_left + dilated_width ,
75
- tvm .tir .indexmod (x - pad_left , stride ).equal (0 ),
76
- ),
77
- data [n , c , tvm .tir .indexdiv (x - pad_left , stride )],
73
+ tvm .tir .all (x >= pad_left , x < pad_left + inp_width ),
74
+ data [n , c , x - pad_left ],
78
75
tvm .tir .const (0.0 , "float32" ),
79
76
),
80
77
name = "data_pad" ,
81
78
)
82
79
83
- dc = te .reduce_axis ((0 , inp_channels ), name = "dc" )
84
- dw = te .reduce_axis ((0 , kernel_size ), name = "dw" )
80
+ padded_kernel = te .compute (
81
+ (inp_channels , out_channels , kernel_size + stride - 1 ),
82
+ lambda ci , co , k : tvm .tir .if_then_else (
83
+ tvm .tir .all (k < kernel_size ),
84
+ kernel [ci , co , kernel_size - k - 1 ],
85
+ tvm .tir .const (0.0 , "float32" ),
86
+ ),
87
+ name = "kernel_pad" ,
88
+ )
89
+
90
+ ci = te .reduce_axis ((0 , inp_channels ), name = "ci" )
91
+ k = te .reduce_axis ((0 , tvm .tir .indexdiv (kernel_size + stride - 1 , stride )), name = "k" )
92
+ border = pad_left * (stride - 1 )
93
+
94
+ # Skip multiplication by 0 values in the input data inserted when stride is greater than 1.
95
+ # During multiplication of kernel by padded data:
96
+ # Kernel indices are: 0, 1 * stride, 2 * stride, ..., (ceil(kernel_size / stride) - 1) * stride, plus
97
+ # data offset mod stride
85
98
data_out = te .compute (
86
99
(batch , out_channels , out_width ),
87
- lambda b , c , w : te .sum (
88
- data [b , dc , w + dw ].astype (out_dtype )
89
- * kernel [dc , c , kernel_size - 1 - dw ].astype (out_dtype ),
90
- axis = [dc , dw ],
100
+ lambda b , co , w : te .sum (
101
+ padded_data [b , ci , tvm .tir .indexdiv (border + w + stride - 1 , stride ) + k ].astype (
102
+ out_dtype
103
+ )
104
+ * padded_kernel [
105
+ ci , co , k * stride + tvm .tir .indexmod (stride - w - border , stride )
106
+ ].astype (out_dtype ),
107
+ axis = [ci , k ],
91
108
),
92
109
tag = "conv1d_transpose_ncw" ,
93
110
)
@@ -118,8 +135,8 @@ def schedule_conv1d_transpose_ncw(cfg, outs):
118
135
119
136
def _callback (op ):
120
137
if op .tag == "conv1d_transpose_ncw" :
121
- pad_data = op .input_tensors [0 ]
122
- kernel = op .input_tensors [1 ]
138
+ padded_data = op .input_tensors [0 ]
139
+ padded_kernel = op .input_tensors [1 ]
123
140
conv = op .output (0 )
124
141
125
142
##### space definition begin #####
@@ -139,9 +156,6 @@ def _callback(op):
139
156
140
157
##### space definition end #####
141
158
142
- if isinstance (kernel .op , tvm .te .ComputeOp ) and "dilate" in kernel .op .tag :
143
- s [kernel ].compute_inline ()
144
-
145
159
if conv .op in s .outputs :
146
160
output = conv
147
161
OL = s .cache_write (conv , "local" )
@@ -150,10 +164,8 @@ def _callback(op):
150
164
s [conv ].set_scope ("local" )
151
165
OL = conv
152
166
153
- # create cache stage
154
- s [pad_data ].set_scope ("shared" )
155
- AA = pad_data
156
- WW = s .cache_read (kernel , "shared" , [OL ])
167
+ s [padded_kernel ].compute_inline ()
168
+ s [padded_data ].compute_inline ()
157
169
158
170
# tile and bind spatial axes
159
171
n , f , x = s [output ].op .axis
@@ -172,28 +184,13 @@ def _callback(op):
172
184
173
185
s [output ].bind (tx , te .thread_axis ("threadIdx.x" ))
174
186
s [OL ].compute_at (s [output ], tx )
175
- # number of threads
176
- n_tz = cfg ["tile_n" ].size [2 ] * cfg ["tile_f" ].size [2 ]
177
- n_tx = cfg ["tile_x" ].size [2 ]
178
187
179
188
# tile reduction axes
180
189
n , f , x = s [OL ].op .axis
181
190
rc , rx = s [OL ].op .reduce_axis
182
191
rco , rcm , rci = cfg ["tile_rc" ].apply (s , OL , rc )
183
192
s [OL ].reorder (rco , rcm , rx , rci , n , f , x )
184
193
185
- s [AA ].compute_at (s [OL ], rx )
186
- s [WW ].compute_at (s [OL ], rx )
187
-
188
- # cooperative fetching
189
- for load in [AA , WW ]:
190
- n , f , x = s [load ].op .axis
191
- fused = s [load ].fuse (f , x )
192
- tz , fused = s [load ].split (fused , nparts = n_tz )
193
- tx , fused = s [load ].split (fused , nparts = n_tx )
194
- s [load ].bind (tz , te .thread_axis ("threadIdx.y" ))
195
- s [load ].bind (tx , te .thread_axis ("threadIdx.x" ))
196
-
197
194
s [output ].pragma (kernel_scope , "auto_unroll_max_step" , cfg ["auto_unroll_max_step" ].val )
198
195
s [output ].pragma (kernel_scope , "unroll_explicit" , cfg ["unroll_explicit" ].val )
199
196
0 commit comments