@@ -46,16 +46,16 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
46
46
47
47
########## STEP 1 ############
"""
    conv(x, w; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1)

Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
in 1d/2d/3d convolutions respectively.

# Keywords
- `stride`, `pad`, `dilation`: each is passed through `expand` for the `N - 2`
  spatial dimensions of `w` before being handed to `DenseConvDims`.
- `flipped`: forwarded to `DenseConvDims` as `flipkernel`.
- `groups`: forwarded to `DenseConvDims` as `groups`.

The result is whatever `conv(x, w, cdims)` returns for the constructed `cdims`.
"""
function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1,
              flipped = false, groups = 1) where {T, N}
    # Number of spatial dimensions, lifted to the type domain for `expand`.
    nspatial = Val(N - 2)
    cdims = DenseConvDims(x, w;
                          stride = expand(nspatial, stride),
                          padding = expand(nspatial, pad),
                          dilation = expand(nspatial, dilation),
                          flipkernel = flipped,
                          groups = groups)
    return conv(x, w, cdims)
end
61
61
@@ -97,9 +97,10 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
97
97
@eval begin
98
98
# Allocating front-end: builds the output buffer `dx` and forwards to the
# in-place `!` variant of the same name/backend.
# NOTE(review): `name` is bound by an enclosing loop that is not visible here.
function $(Symbol("$(name)$(backend)"))(
                dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
                cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}
    # `dx` takes the input spatial size, `channels_in(cdims)` channels, and the
    # same trailing (N-th) dimension as `dy`; element type follows `dy`.
    dx_dims = (input_size(cdims)..., channels_in(cdims), size(dy, N))
    dx = similar(dy, dx_dims)
    return $(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...)
end
105
106
end
@@ -111,8 +112,9 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
111
112
# Allocating front-end for the filter gradient: builds `dw` and forwards to the
# in-place `∇conv_filter$(backend)!` method.
function $(Symbol("∇conv_filter$(backend)"))(
                x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},
                cdims::ConvDims; kwargs...) where {xT, yT, N}
    # With grouped convolution each filter only spans the input channels of
    # its own group, hence the division by `groupcount(cdims)`.
    cin_per_group = channels_in(cdims) ÷ groupcount(cdims)
    dw_dims = (kernel_size(cdims)..., cin_per_group, channels_out(cdims))
    dw = similar(dy, dw_dims)
    return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...)
end
118
120
end
@@ -145,6 +147,7 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter,
145
147
y:: AbstractArray{yT,$N} , x:: AbstractArray{xT,$N} ,
146
148
w:: AbstractArray{wT,$N} , cdims:: ConvDims ;
147
149
kwargs... ) where {yT, xT, wT}
150
+
148
151
$ (Symbol (" $(front_name)$(backend) !" ))(
149
152
insert_singleton_spatial_dimension (y, $ (5 - N)),
150
153
insert_singleton_spatial_dimension (x, $ (5 - N)),
@@ -161,6 +164,7 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter,
161
164
end
162
165
end
163
166
end
167
+
164
168
# ######################################
165
169
166
170
@@ -169,25 +173,106 @@ end
169
173
# First, we will define mappings from the generic API names to our accelerated backend
170
174
# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using
171
175
# im2col + GEMM. Do so in a loop, here:
176
+
177
# Union of the element types that get the `im2col` + GEMM accelerated methods:
# the second entry of every item in `gemm_datatype_mappings`.
const G = Union{map(mapping -> mapping[2], gemm_datatype_mappings)...}
179
+
172
180
# im2col-accelerated forwarding definition for `conv!`.
#
# We only define the 5d (3d convolution) primitive; 1d and 2d convolutions are
# reshaped lower down to reach it.
#
# The channel dimensions are split into `groupcount(cdims)` chunks and each
# chunk is handed to `conv_im2col!` in its own task, using a copy of `cdims`
# rebuilt with `G = 1` and the per-group channel counts.
function conv!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
               in2::AbstractArray{T,5}, cdims::C;
               kwargs...) where {T <: G, C <: ConvDims}
    ngroups = groupcount(cdims)
    cin_per_group = channels_in(cdims) ÷ ngroups
    cout_per_group = channels_out(cdims) ÷ ngroups

    # Channel ranges: dim 4 of the input, dim 5 of the weights.
    x_chunks = Iterators.partition(1:size(in1, 4), cin_per_group)
    w_chunks = Iterators.partition(1:size(in2, 5), cout_per_group)

    # Dims describing a single group.
    group_cdims = basetype(C)(cdims,
                              G = 1,
                              C_in = cin_per_group,
                              C_out = cout_per_group)

    Threads.@sync for (x_range, w_range) in zip(x_chunks, w_chunks)
        x = view(in1, :, :, :, x_range, :)
        w = view(in2, :, :, :, :, w_range)
        # The output channels of a group share the index range of its filters.
        y = view(out, :, :, :, w_range, :)
        Threads.@spawn conv_im2col!(y, x, w, group_cdims; kwargs...)
    end

    return out
end
212
+
213
# im2col-accelerated forwarding definition for `∇conv_data!`.
#
# `out` receives the input gradient, `in1` is the output gradient and `in2`
# the weights. The work is split per group: each group's slice is passed to
# `∇conv_data_im2col!` in its own task, with a copy of `cdims` rebuilt with
# `G = 1` and the per-group channel counts.
function ∇conv_data!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
                     in2::AbstractArray{T,5}, cdims::C;
                     kwargs...) where {T <: G, C <: ConvDims}
    ngroups = groupcount(cdims)
    cin_per_group = channels_in(cdims) ÷ ngroups
    cout_per_group = channels_out(cdims) ÷ ngroups

    # Channel ranges: dim 4 of `out` and `in1`, dim 5 of the weights `in2`.
    dx_chunks = Iterators.partition(1:size(out, 4), cin_per_group)
    w_chunks = Iterators.partition(1:size(in2, 5), cout_per_group)
    dy_chunks = Iterators.partition(1:size(in1, 4), cout_per_group)

    # Dims describing a single group.
    group_cdims = basetype(C)(cdims,
                              G = 1,
                              C_in = cin_per_group,
                              C_out = cout_per_group)

    Threads.@sync for (dx_range, dy_range, w_range) in zip(dx_chunks, dy_chunks, w_chunks)
        dx_view = view(out, :, :, :, dx_range, :)
        dy_view = view(in1, :, :, :, dy_range, :)
        w_view = view(in2, :, :, :, :, w_range)
        Threads.@spawn ∇conv_data_im2col!(dx_view, dy_view, w_view, group_cdims; kwargs...)
    end

    return out
end
237
+
238
# im2col-accelerated forwarding definition for `∇conv_filter!`.
#
# `out` receives the filter gradient, `in1` is the input and `in2` the output
# gradient. The work is split per group: each group's slice is passed to
# `∇conv_filter_im2col!` in its own task, with a copy of `cdims` rebuilt with
# `G = 1` and the per-group channel counts.
function ∇conv_filter!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
                       in2::AbstractArray{T,5}, cdims::C;
                       kwargs...) where {T <: G, C <: ConvDims}
    ngroups = groupcount(cdims)
    cin_per_group = channels_in(cdims) ÷ ngroups
    cout_per_group = channels_out(cdims) ÷ ngroups

    # Channel ranges: dim 5 of `out`, dim 4 of `in1` and `in2`.
    dw_chunks = Iterators.partition(1:size(out, 5), cout_per_group)
    dy_chunks = Iterators.partition(1:size(in2, 4), cout_per_group)
    x_chunks = Iterators.partition(1:size(in1, 4), cin_per_group)

    # Dims describing a single group.
    group_cdims = basetype(C)(cdims,
                              G = 1,
                              C_in = cin_per_group,
                              C_out = cout_per_group)

    Threads.@sync for (dw_range, x_range, dy_range) in zip(dw_chunks, x_chunks, dy_chunks)
        x = view(in1, :, :, :, x_range, :)
        dy = view(in2, :, :, :, dy_range, :)
        # Index `out` with the partition derived from `out` itself; the
        # partition over `in2` was used here before, leaving `dw_range` unused.
        dw = view(out, :, :, :, :, dw_range)
        Threads.@spawn ∇conv_filter_im2col!(dw, x, dy, group_cdims; kwargs...)
    end

    return out
end
261
+
262
+
263
+ for (front_name, backend) in (
264
+ # This maps from public, front-facing name, to internal backend name
177
265
:depthwiseconv => :im2col ,
178
266
:∇depthwiseconv_data => :im2col ,
179
267
:∇depthwiseconv_filter => :im2col ,
180
268
)
181
269
182
- # These are the GEMM types we will accelerate with `im2col`
183
- G = Union{[x[2 ] for x in gemm_datatype_mappings]. .. }
184
-
185
270
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
186
271
@eval begin
187
272
# Forward the generic in-place depthwise entry point to its im2col-accelerated
# backend for the GEMM-compatible element types in `G`.
function $(Symbol("$(front_name)!"))(
                out::AbstractArray{T,5}, in1::AbstractArray{T,5},
                in2::AbstractArray{T,5}, cdims::C;
                kwargs...) where {T <: $G, C <: ConvDims}
    return $(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
end
193
278
end
0 commit comments