
Commit ce1045e

Merge pull request #289 from FluxML/dg/groups
Support groups in DenseConvDims

2 parents: 77fd3bf + 618817d

7 files changed: 162 additions, 35 deletions

src/conv.jl (96 additions, 11 deletions)

@@ -46,16 +46,16 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
 
 ########## STEP 1 ############
 """
-    conv(x, w; stride=1, pad=0, dilation=1, flipped=false)
+    conv(x, w; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1)
 
 Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
 in 1d/2d/3d convolutions respectively.
 """
-function conv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N}
+function conv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false, groups = 1) where {T, N}
     stride = expand(Val(N-2), stride)
     pad = expand(Val(N-2), pad)
     dilation = expand(Val(N-2), dilation)
-    cdims = DenseConvDims(x, w; stride=stride, padding=pad, dilation=dilation, flipkernel=flipped)
+    cdims = DenseConvDims(x, w; stride=stride, padding=pad, dilation=dilation, flipkernel=flipped, groups = groups)
     return conv(x, w, cdims)
 end
 

@@ -97,9 +97,10 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
     @eval begin
         function $(Symbol("$(name)$(backend)"))(
                         dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
-                        cdims::ConvDims; kwargs...) where {yT, wT, N}
+                        cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}
             dx = similar(dy, input_size(cdims)..., channels_in(cdims),
                          size(dy, N))
+
             return $(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...)
         end
     end

@@ -111,8 +112,9 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
         function $(Symbol("∇conv_filter$(backend)"))(
                         x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},
                         cdims::ConvDims; kwargs...) where {xT, yT, N}
-            dw = similar(dy, kernel_size(cdims)..., channels_in(cdims),
+            dw = similar(dy, kernel_size(cdims)..., channels_in(cdims) ÷ groupcount(cdims),
                          channels_out(cdims))
+
             return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...)
         end
     end

@@ -145,6 +147,7 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter,
                         y::AbstractArray{yT,$N}, x::AbstractArray{xT,$N},
                         w::AbstractArray{wT,$N}, cdims::ConvDims;
                         kwargs...) where {yT, xT, wT}
+
                 $(Symbol("$(front_name)$(backend)!"))(
                     insert_singleton_spatial_dimension(y, $(5 - N)),
                     insert_singleton_spatial_dimension(x, $(5 - N)),

@@ -161,6 +164,7 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter,
         end
     end
 end
+
 #######################################
 
 

@@ -169,25 +173,106 @@ end
 # First, we will define mappings from the generic API names to our accelerated backend
 # implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using
 # im2col + GEMM. Do so in a loop, here:
+
+# These are the GEMM types we will accelerate with `im2col`
+const G = Union{[x[2] for x in gemm_datatype_mappings]...}
+
 for (front_name, backend) in (
     # This maps from public, front-facing name, to internal backend name
     :conv => :im2col,
-    :∇conv_data => :im2col,
-    :∇conv_filter => :im2col,
+)
+
+    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
+    @eval begin
+        # im2col-accelerated function forwarding definition
+        function $(Symbol("$(front_name)!"))(
+                        out::AbstractArray{T,5}, in1::AbstractArray{T,5},
+                        in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: $G, C <: ConvDims}
+
+            x_cs = Iterators.partition(1:size(in1, 4),
+                                       channels_in(cdims) ÷ groupcount(cdims))
+            w_cs = Iterators.partition(1:size(in2, 5),
+                                       channels_out(cdims) ÷ groupcount(cdims))
+            cdims2 = basetype(C)(cdims,
+                                 G = 1,
+                                 C_in = channels_in(cdims) ÷ groupcount(cdims),
+                                 C_out = channels_out(cdims) ÷ groupcount(cdims))
+
+            Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
+                x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
+                w = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
+                y = @view out[ntuple(i -> i == 4 ? wc : Colon(), 5)...]
+                Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...)
+            end
+
+            return out
+        end
+    end
+end
+
+# im2col-accelerated function forwarding definition
+function ∇conv_data!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
+                     in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: G, C <: ConvDims}
+
+    dx_cs = Iterators.partition(1:size(out, 4),
+                                channels_in(cdims) ÷ groupcount(cdims))
+    w_cs = Iterators.partition(1:size(in2, 5),
+                               channels_out(cdims) ÷ groupcount(cdims))
+    dy_cs = Iterators.partition(1:size(in1, 4),
+                                channels_out(cdims) ÷ groupcount(cdims))
+    cdims2 = basetype(C)(cdims,
+                         G = 1,
+                         C_in = channels_in(cdims) ÷ groupcount(cdims),
+                         C_out = channels_out(cdims) ÷ groupcount(cdims))
+
+    Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
+        dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
+        dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
+        wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
+        Threads.@spawn ∇conv_data_im2col!(dxv, dyv, wv, cdims2; kwargs...)
+    end
+
+    return out
+end
+
+function ∇conv_filter!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
+                       in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: G, C <: ConvDims}
+
+    dw_cs = Iterators.partition(1:size(out, 5),
+                                channels_out(cdims) ÷ groupcount(cdims))
+    dy_cs = Iterators.partition(1:size(in2, 4),
+                                channels_out(cdims) ÷ groupcount(cdims))
+    x_cs = Iterators.partition(1:size(in1, 4),
+                               channels_in(cdims) ÷ groupcount(cdims))
+    cdims2 = basetype(C)(cdims,
+                         G = 1,
+                         C_in = channels_in(cdims) ÷ groupcount(cdims),
+                         C_out = channels_out(cdims) ÷ groupcount(cdims))
+
+    Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
+        x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
+        dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
+        dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
+        Threads.@spawn ∇conv_filter_im2col!(dw, x, dy, cdims2; kwargs...)
+    end
+
+    return out
+end
+
+
+for (front_name, backend) in (
+    # This maps from public, front-facing name, to internal backend name
     :depthwiseconv => :im2col,
     :∇depthwiseconv_data => :im2col,
     :∇depthwiseconv_filter => :im2col,
   )
 
-    # These are the GEMM types we will accelerate with `im2col`
-    G = Union{[x[2] for x in gemm_datatype_mappings]...}
-
     # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
     @eval begin
         # im2col-accelerated function forwarding definition
         function $(Symbol("$(front_name)!"))(
                         out::AbstractArray{T,5}, in1::AbstractArray{T,5},
-                        in2::AbstractArray{T,5}, cdims::ConvDims; kwargs...) where {T <: $G}
+                        in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: $G, C <: ConvDims}
             $(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
         end
     end
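
The hunk above implements grouped convolution by slicing the channel dimensions into `groupcount(cdims)` contiguous ranges and spawning one im2col sub-convolution per group. A minimal usage sketch of the new keyword (not part of the diff; shapes borrowed from the test set added below):

using NNlib

# groups = 5 splits the 100 input channels into five independent 20 -> 3 convolutions
x = rand(Float32, 28, 28, 100, 2)   # C_in = 100
w = rand(Float32, 3, 3, 20, 15)     # C_in ÷ groups = 20 filter channels, C_out = 15
y = conv(x, w; groups = 5)
@assert size(y) == (26, 26, 15, 2)

# The channel ranges each spawned task sees, exactly as the hunk partitions them:
collect(Iterators.partition(1:100, 100 ÷ 5))   # input slices 1:20, 21:40, …, 81:100
collect(Iterators.partition(1:15, 15 ÷ 5))     # output slices 1:3, 4:6, …, 13:15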

src/dim_helpers/ConvDims.jl (3 additions, 1 deletion)

@@ -33,6 +33,7 @@ stride(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = S
 padding(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = P
 dilation(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = D
 flipkernel(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = F
+groupcount(c::ConvDims) = 1
 
 """
     im2col_dims(c::ConvDims)

@@ -131,5 +132,6 @@ function Base.show(io::IO, cdims::C) where {C <: ConvDims}
     P = padding(cdims)
     D = dilation(cdims)
     F = flipkernel(cdims)
-    print(io, "$(basetype(C)): $I * $K -> $O, stride: $S, pad: $P, dil: $D, flip: $F")
+    G = groupcount(cdims)
+    print(io, "$(basetype(C)): $I * $K -> $O, stride: $S, pad: $P, dil: $D, flip: $F, groups: $G")
 end
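
The `groupcount` fallback gives every `ConvDims` subtype a group count of 1 unless it overrides the getter, which `DenseConvDims` does in the next file. An illustrative sketch (`groupcount` is unexported, hence the explicit import):

using NNlib
using NNlib: groupcount

cdims = DenseConvDims((28, 28, 100, 2), (3, 3, 20, 15); groups = 5)
@assert groupcount(cdims) == 5

# DepthwiseConvDims carries no group parameter, so the generic fallback answers:
dcdims = DepthwiseConvDims((28, 28, 3, 2), (3, 3, 5, 3))
@assert groupcount(dcdims) == 1

# The extended show method now appends the group count to the printed summary:
show(stdout, cdims)   # "... flip: false, groups: 5"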

src/dim_helpers/DenseConvDims.jl (22 additions, 11 deletions)

@@ -5,35 +5,43 @@ export DenseConvDims
 
 Concrete subclass of `ConvDims` for a normal, dense, conv2d/conv3d.
 """
-struct DenseConvDims{N,K,C_in,C_out,S,P,D,F} <: ConvDims{N,S,P,D,F}
+struct DenseConvDims{N,K,C_in,C_out,G,S,P,D,F} <: ConvDims{N,S,P,D,F}
     I::NTuple{N,Int}
 end
 
 # Getters for the fields
 input_size(c::DenseConvDims) = c.I
 kernel_size(c::DenseConvDims{N,K,C_in,C_out}) where {N,K,C_in,C_out} = K
 channels_in(c::DenseConvDims{N,K,C_in,C_out}) where {N,K,C_in,C_out} = C_in::Int
-channels_out(c::DenseConvDims{N,K,C_in,C_out}) where {N,K,C_in,C_out} = C_out::Int
+channels_out(c::DenseConvDims{N,K,C_in,C_out,G}) where {N,K,C_in,C_out,G} = C_out::Int
+groupcount(c::DenseConvDims{N,K,C_in,C_out,G}) where {N,K,C_in,C_out,G} = G::Int
 
 # Convenience wrapper to create DenseConvDims objects
 function DenseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
-                       stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M
+                       stride=1, padding=0, dilation=1, flipkernel::Bool=false, groups = 1) where M
+
     # Do common parameter validation
     stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation)
 
     # Ensure channels are equal
-    if x_size[end-1] != w_size[end-1]
+    if x_size[end-1] != w_size[end-1] * groups
         xs = x_size[end-1]
         ws = w_size[end-1]
         throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)"))
     end
-
+
+    # Ensure groups are valid
+    if x_size[end-1] % w_size[end-1] != 0 || w_size[end] % groups != 0
+        throw(DimensionMismatch("Group count should be divisble by input and output channels ($groups vs. $(w_size[end-1:end]))"))
+    end
+
     # The type parameters are what
     return DenseConvDims{
         M - 2,
         w_size[1:end-2],
         x_size[end-1],
         w_size[end],
+        groups,
         stride,
         padding,
         dilation,

@@ -56,22 +64,25 @@ end
 # from the original progenitor object that it inherits shapes from.
 function DenseConvDims(c::ConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
                        C_in=channels_in(c), C_out=channels_out(c), S=stride(c),
-                       P=padding(c), D=dilation(c), F=flipkernel(c))
-    return DenseConvDims{N, K, C_in, C_out, S, P, D, F}(I)
+                       P=padding(c), D=dilation(c), F=flipkernel(c), G=groupcount(c))
+    return DenseConvDims{N, K, C_in, C_out, G, S, P, D, F}(I)
 end
 
 function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M}
     # First, check that channel counts are all correct:
-    @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
-    @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
-    @assert w[M-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)))")
-    @assert w[M] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))")
+    @assert x[M-1] * groupcount(cdims) == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
+    @assert y[M-1] == channels_out(cdims) ÷ groupcount(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
+    @assert w[M-1] * groupcount(cdims) == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)))")
+    @assert w[M] * groupcount(cdims) == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))")
 
     # Next, check that the spatial dimensions match up
     @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
     @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
     @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")
 
+    # Check the groups match
+    @assert channels_in(cdims) % groupcount(cdims) == 0 DimensionMismatch("Groups ($(groupcount(cdims))) should be divisble by input channels $(channels_in(cdims))")
+
     # Finally, check that the batch size matches
     @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
 end
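
A sketch of what the new constructor validation accepts and rejects (illustrative, not from the diff): the channel check requires x_size[end-1] == w_size[end-1] * groups, and the group check requires the kernel's output channels to divide evenly into groups.

using NNlib

x_size = (28, 28, 100, 2)   # 100 input channels
w_size = (3, 3, 20, 15)     # 20 filter channels, 15 output channels

# groups defaults to 1, so 100 != 20 * 1 trips the channel check:
try
    DenseConvDims(x_size, w_size)
catch err
    @assert err isa DimensionMismatch
end

# groups = 5 passes both checks: 100 == 20 * 5, and 15 % 5 == 0
cdims = DenseConvDims(x_size, w_size; groups = 5)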

test/conv.jl (29 additions, 1 deletion)

@@ -1,6 +1,7 @@
 using NNlib, Test
 using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multiplier,
-             stride, padding, dilation, flipkernel, output_size
+             stride, padding, dilation, flipkernel, output_size,
+             groupcount
 
 @testset "ConvDims" begin
     for T in (DenseConvDims, DepthwiseConvDims)

@@ -648,6 +649,33 @@ else
     @info "Skipping Depthwise Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them"
 end
 
+@testset "Grouped Convolutions" begin
+    x′ = rand(Float32, 28, 28, 100, 2)
+    w′ = rand(Float32, 3, 3, 20, 15)
+
+    @test_throws DimensionMismatch DenseConvDims(x′, w′)
+    cdims = DenseConvDims(x′, w′, groups = 5)
+
+    @test groupcount(cdims) == 5
+
+    y = conv(x′, w′, cdims)
+    _, back = Zygote.pullback((x, w) -> sum(conv(x, w, cdims)), x′, w′)
+    gs_x, gs_w = back(1.f0)
+
+
+    ips = Iterators.partition(1:100, 20)
+    ops = Iterators.partition(1:15, 3)
+    for (i,o) in zip(ips,ops)
+        _, back_reg = Zygote.pullback((x, w) -> sum(conv(x, w)), x′[:,:,i,:], w′[:,:,:,o])
+        gs_x_reg, gs_w_reg = back_reg(1.f0)
+        @test conv(x′[:,:,i,:], w′[:,:,:,o]) ≈ y[:,:,o,:]
+        @test gs_x_reg ≈ gs_x[:,:,i,:]
+        @test gs_w_reg ≈ gs_w[:,:,:,o]
+    end
+
+    # Currently hangs due to a FiniteDifferences issue
+    @test_skip gradtest((x, w) -> sum(conv(x, w, cdims)), x′, w′)
+end
+
 @testset "conv_wrapper" begin
     x = rand(10, 10, 3, 10)
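
The loop in the new testset validates each group against an ordinary convolution of the matching channel slices. The same equivalence can be stated as a single concatenation; a sketch under the test's shapes (illustrative, not in the diff):

using NNlib

x = rand(Float32, 28, 28, 100, 2)
w = rand(Float32, 3, 3, 20, 15)
cdims = DenseConvDims(x, w; groups = 5)

# A grouped conv equals the channel-wise concatenation of the per-group convs:
ips = Iterators.partition(1:100, 20)   # input-channel slices
ops = Iterators.partition(1:15, 3)     # output-channel slices
y_manual = cat((conv(x[:, :, i, :], w[:, :, :, o]) for (i, o) in zip(ips, ops))...; dims = 3)
@assert conv(x, w, cdims) ≈ y_manual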

test/padding.jl (2 additions, 0 deletions)

@@ -1,3 +1,5 @@
+using NNlib: pad_constant, pad_repeat, pad_zeros, pad_reflect
+
 @testset "padding constant" begin
     x = rand(2, 2, 2)
 

test/runtests.jl (0 additions, 1 deletion)

@@ -57,7 +57,6 @@ end
     include("utils.jl")
 end
 
-
 if VERSION >= v"1.6" && CUDA.functional()
     if get(ENV, "NNLIB_TEST_CUDA", "false") == "true"
         import Pkg

test/test_utils.jl (10 additions, 10 deletions)

@@ -10,15 +10,15 @@ given by Zygote. `f` has to be a scalar valued function.
 
 Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly defined.
 """
-function gradtest(f, xs...; atol=1e-6, rtol=1e-6, fkwargs=NamedTuple(),
-                  check_rrule=false,
-                  fdm=:central,
-                  check_broadcast=false,
-                  skip=false, broken=false)
+function gradtest(f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs=NamedTuple(),
+                  check_rrule = false,
+                  fdm = :central,
+                  check_broadcast = false,
+                  skip = false, broken = false)
     # TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166
     # is merged
     if check_rrule
-        test_rrule(f, xs...; fkwargs=fkwargs)
+        test_rrule(f, xs...; fkwargs = fkwargs)
     end
 
     if check_broadcast

@@ -43,14 +43,14 @@ function gradtest(f, xs...; atol=1e-6, rtol=1e-6, fkwargs=NamedTuple(),
     y_ad, pull = Zygote.pullback(h, xs...)
     gs_ad = pull(one(y_ad))
 
-    @test y_true ≈ y_ad atol=atol rtol=rtol
+    @test y_true ≈ y_ad atol = atol rtol = rtol
     for (g_ad, g_fd) in zip(gs_ad, gs_fd)
         if skip
-            @test_skip g_ad ≈ g_fd atol=atol rtol=rtol
+            @test_skip g_ad ≈ g_fd atol = atol rtol = rtol
         elseif broken
-            @test_broken g_ad ≈ g_fd atol=atol rtol=rtol
+            @test_broken g_ad ≈ g_fd atol = atol rtol = rtol
         else
-            @test g_ad ≈ g_fd atol=atol rtol=rtol
+            @test g_ad ≈ g_fd atol = atol rtol = rtol
         end
     end
     return true
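
For reference, a hypothetical gradtest invocation (shapes invented for illustration; assumes test_utils.jl and its dependencies such as Zygote and FiniteDifferences are already loaded). The helper compares Zygote's gradient of a scalar-valued function against central finite differences within atol/rtol:

gradtest((x, w) -> sum(conv(x, w)), rand(5, 5, 2, 1), rand(3, 3, 2, 4))

# skip = true downgrades the gradient comparisons to @test_skip, which is how the
# grouped-convolution gradtest in test/conv.jl stays disabled while the
# FiniteDifferences hang persists:
gradtest((x, w) -> sum(conv(x, w)), rand(5, 5, 2, 1), rand(3, 3, 2, 4); skip = true)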
