Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support groups in DenseConvDims #289

Merged
merged 36 commits into from
Jul 16, 2021
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
5ffabbc
add groups to DenseConvDims:
Mar 9, 2021
7db08a9
check dims match
Mar 9, 2021
e3f2816
groups -> groupcount
Mar 9, 2021
3c977ac
add check for groups
Mar 9, 2021
31f30ce
add groups to ConvDims
Mar 9, 2021
eb3b60d
assert both channels are divisible
Mar 9, 2021
cb80891
correct groups check
Mar 10, 2021
e5d3a03
correct name
Mar 11, 2021
b782848
check groups in various places
Mar 11, 2021
c4fe97f
pick groups with local memory
Mar 13, 2021
882c343
cleanup
Mar 13, 2021
115784b
working version
Mar 13, 2021
563488b
use non allocating version
Mar 25, 2021
50c5e89
use threads optionally
Mar 25, 2021
092c49b
eval nameof cdims
DhairyaLGandhi Mar 29, 2021
6b439c0
spawn on conv
DhairyaLGandhi Mar 29, 2021
f0bb94d
explicitly return out
DhairyaLGandhi Apr 5, 2021
d660b20
check groupcount
DhairyaLGandhi Apr 6, 2021
a5bb54a
add depthwise dispatches
DhairyaLGandhi Apr 10, 2021
cf1b457
cleanup
DhairyaLGandhi Apr 10, 2021
de2a3cf
correct names
DhairyaLGandhi Apr 12, 2021
bba416b
add backwards pass support
DhairyaLGandhi May 17, 2021
45663cf
add dense conv dims + groups tests
DhairyaLGandhi Jun 9, 2021
8d1398c
Add grouped conv tests
DhairyaLGandhi Jun 16, 2021
0c1dcff
fix tests
DhairyaLGandhi Jun 16, 2021
c95866f
fix tests
DhairyaLGandhi Jun 16, 2021
0b1d843
reviews
DhairyaLGandhi Jul 15, 2021
b301225
conflicts
DhairyaLGandhi Jul 15, 2021
4288612
Merge branch 'master' into dg/groups
DhairyaLGandhi Jul 15, 2021
b97574d
use basetype
DhairyaLGandhi Jul 15, 2021
161d137
Merge branch 'dg/groups' of https://github.com/FluxML/NNlib.jl into d…
DhairyaLGandhi Jul 15, 2021
af16713
add even more tests
DhairyaLGandhi Jul 15, 2021
c6548b6
whitespace + docs
DhairyaLGandhi Jul 15, 2021
80f70e1
reviews
DhairyaLGandhi Jul 15, 2021
e014700
revert channels_out
DhairyaLGandhi Jul 15, 2021
618817d
whitespace
DhairyaLGandhi Jul 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 96 additions & 11 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!

########## STEP 1 ############
"""
conv(x, w; stride=1, pad=0, dilation=1, flipped=false)
conv(x, w; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1)

Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
in 1d/2d/3d convolutions respectively.
"""
function conv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N}
function conv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false, groups = 1) where {T, N}
stride = expand(Val(N-2), stride)
pad = expand(Val(N-2), pad)
dilation = expand(Val(N-2), dilation)
cdims = DenseConvDims(x, w; stride=stride, padding=pad, dilation=dilation, flipkernel=flipped)
cdims = DenseConvDims(x, w; stride=stride, padding=pad, dilation=dilation, flipkernel=flipped, groups = groups)
return conv(x, w, cdims)
end

Expand Down Expand Up @@ -97,9 +97,10 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
@eval begin
function $(Symbol("$(name)$(backend)"))(
dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
cdims::ConvDims; kwargs...) where {yT, wT, N}
cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}
dx = similar(dy, input_size(cdims)..., channels_in(cdims),
size(dy, N))

return $(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...)
end
end
Expand All @@ -111,8 +112,9 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
function $(Symbol("∇conv_filter$(backend)"))(
x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},
cdims::ConvDims; kwargs...) where {xT, yT, N}
dw = similar(dy, kernel_size(cdims)..., channels_in(cdims),
dw = similar(dy, kernel_size(cdims)..., channels_in(cdims) ÷ groupcount(cdims),
channels_out(cdims))

return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...)
end
end
Expand Down Expand Up @@ -145,6 +147,7 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter,
y::AbstractArray{yT,$N}, x::AbstractArray{xT,$N},
w::AbstractArray{wT,$N}, cdims::ConvDims;
kwargs...) where {yT, xT, wT}

$(Symbol("$(front_name)$(backend)!"))(
insert_singleton_spatial_dimension(y, $(5 - N)),
insert_singleton_spatial_dimension(x, $(5 - N)),
Expand All @@ -161,6 +164,7 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter,
end
end
end

#######################################


Expand All @@ -169,25 +173,106 @@ end
# First, we will define mappings from the generic API names to our accelerated backend
# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using
# im2col + GEMM. Do so in a loop, here:

# These are the GEMM types we will accelerate with `im2col`
const G = Union{[x[2] for x in gemm_datatype_mappings]...}

for (front_name, backend) in (
# This maps from public, front-facing name, to internal backend name
:conv => :im2col,
:∇conv_data => :im2col,
:∇conv_filter => :im2col,
)

# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
@eval begin
# im2col-accelerated function forwarding definition
function $(Symbol("$(front_name)!"))(
out::AbstractArray{T,5}, in1::AbstractArray{T,5},
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: $G, C <: ConvDims}

x_cs = Iterators.partition(1:size(in1, 4),
channels_in(cdims) ÷ groupcount(cdims))
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
w_cs = Iterators.partition(1:size(in2, 5),
channels_out(cdims) ÷ groupcount(cdims))
cdims2 = basetype(C)(cdims,
G = 1,
C_in = channels_in(cdims) ÷ groupcount(cdims),
C_out = channels_out(cdims) ÷ groupcount(cdims))

Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
w = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
y = @view out[ntuple(i -> i == 4 ? wc : Colon(), 5)...]
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...)
end

return out
DhairyaLGandhi marked this conversation as resolved.
Show resolved Hide resolved
end
end
end

# im2col-accelerated function forwarding definition
function ∇conv_data!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: G, C <: ConvDims}

dx_cs = Iterators.partition(1:size(out, 4),
channels_in(cdims) ÷ groupcount(cdims))
w_cs = Iterators.partition(1:size(in2, 5),
channels_out(cdims) ÷ groupcount(cdims))
dy_cs = Iterators.partition(1:size(in1, 4),
channels_out(cdims) ÷ groupcount(cdims))
cdims2 = basetype(C)(cdims,
G = 1,
C_in = channels_in(cdims) ÷ groupcount(cdims),
C_out = channels_out(cdims) ÷ groupcount(cdims))

Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
Threads.@spawn ∇conv_data_im2col!(dxv, dyv, wv, cdims2; kwargs...)
end
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

return out
end

function ∇conv_filter!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: G, C <: ConvDims}

dw_cs = Iterators.partition(1:size(out, 5),
channels_out(cdims) ÷ groupcount(cdims))
dy_cs = Iterators.partition(1:size(in2, 4),
channels_out(cdims) ÷ groupcount(cdims))
x_cs = Iterators.partition(1:size(in1, 4),
channels_in(cdims) ÷ groupcount(cdims))
cdims2 = basetype(C)(cdims,
G = 1,
C_in = channels_in(cdims) ÷ groupcount(cdims),
C_out = channels_out(cdims) ÷ groupcount(cdims))

Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
Threads.@spawn ∇conv_filter_im2col!(dw, x, dy, cdims2; kwargs...)
end

return out
end


for (front_name, backend) in (
# This maps from public, front-facing name, to internal backend name
:depthwiseconv => :im2col,
:∇depthwiseconv_data => :im2col,
:∇depthwiseconv_filter => :im2col,
)

# These are the GEMM types we will accelerate with `im2col`
G = Union{[x[2] for x in gemm_datatype_mappings]...}

# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
@eval begin
# im2col-accelerated function forwarding definition
function $(Symbol("$(front_name)!"))(
out::AbstractArray{T,5}, in1::AbstractArray{T,5},
in2::AbstractArray{T,5}, cdims::ConvDims; kwargs...) where {T <: $G}
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: $G, C <: ConvDims}
$(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
end
end
Expand Down
4 changes: 3 additions & 1 deletion src/dim_helpers/ConvDims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ stride(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = S
padding(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = P
dilation(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = D
flipkernel(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = F
groupcount(c::ConvDims) = 1

"""
im2col_dims(c::ConvDims)
Expand Down Expand Up @@ -131,5 +132,6 @@ function Base.show(io::IO, cdims::C) where {C <: ConvDims}
P = padding(cdims)
D = dilation(cdims)
F = flipkernel(cdims)
print(io, "$(basetype(C)): $I * $K -> $O, stride: $S, pad: $P, dil: $D, flip: $F")
G = groupcount(cdims)
print(io, "$(basetype(C)): $I * $K -> $O, stride: $S, pad: $P, dil: $D, flip: $F, groups: $G")
end
33 changes: 22 additions & 11 deletions src/dim_helpers/DenseConvDims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,43 @@ export DenseConvDims

Concrete subclass of `ConvDims` for a normal, dense, conv2d/conv3d.
"""
struct DenseConvDims{N,K,C_in,C_out,S,P,D,F} <: ConvDims{N,S,P,D,F}
struct DenseConvDims{N,K,C_in,C_out,G,S,P,D,F} <: ConvDims{N,S,P,D,F}
I::NTuple{N,Int}
end

# Getters for the fields
input_size(c::DenseConvDims) = c.I
kernel_size(c::DenseConvDims{N,K,C_in,C_out}) where {N,K,C_in,C_out} = K
channels_in(c::DenseConvDims{N,K,C_in,C_out}) where {N,K,C_in,C_out} = C_in::Int
channels_out(c::DenseConvDims{N,K,C_in,C_out}) where {N,K,C_in,C_out} = C_out::Int
channels_out(c::DenseConvDims{N,K,C_in,C_out,G}) where {N,K,C_in,C_out,G} = (C_out * G)::Int
groupcount(c::DenseConvDims{N,K,C_in,C_out,G}) where {N,K,C_in,C_out,G} = G::Int

# Convenience wrapper to create DenseConvDims objects
function DenseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M
stride=1, padding=0, dilation=1, flipkernel::Bool=false, groups = 1) where M

# Do common parameter validation
stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation)

# Ensure channels are equal
if x_size[end-1] != w_size[end-1]
if x_size[end-1] != w_size[end-1] * groups
xs = x_size[end-1]
ws = w_size[end-1]
throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)"))
end


# Ensure groups are valid
if x_size[end-1] % w_size[end-1] != 0 || w_size[end] % groups != 0
throw(DimensionMismatch("Group count should be divisble by input and output channels ($groups vs. $(w_size[end-1:end]))"))
end

# The type parameters are what
return DenseConvDims{
M - 2,
w_size[1:end-2],
x_size[end-1],
w_size[end],
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
groups,
stride,
padding,
dilation,
Expand All @@ -56,22 +64,25 @@ end
# from the original progenitor object that it inherits shapes from.
function DenseConvDims(c::ConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
C_in=channels_in(c), C_out=channels_out(c), S=stride(c),
P=padding(c), D=dilation(c), F=flipkernel(c))
return DenseConvDims{N, K, C_in, C_out, S, P, D, F}(I)
P=padding(c), D=dilation(c), F=flipkernel(c), G=groupcount(c))
return DenseConvDims{N, K, C_in, C_out, G, S, P, D, F}(I)
end

function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M}
# First, check that channel counts are all correct:
@assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
@assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
@assert w[M-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)))")
@assert w[M] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))")
@assert x[M-1] * groupcount(cdims) == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
@assert y[M-1] == channels_out(cdims) ÷ groupcount(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
@assert w[M-1] * groupcount(cdims) == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)))")
@assert w[M] * groupcount(cdims) == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))")

# Next, check that the spatial dimensions match up
@assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
@assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
@assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")

# Check the groups match
@assert channels_in(cdims) % groupcount(cdims) == 0 DimensionMismatch("Groups ($(groupcount(cdims))) should be divisble by input channels $(channels_in(cdims))")
DhairyaLGandhi marked this conversation as resolved.
Show resolved Hide resolved

# Finally, check that the batch size matches
@assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
end
30 changes: 29 additions & 1 deletion test/conv.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using NNlib, Test
using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multiplier,
stride, padding, dilation, flipkernel, output_size
stride, padding, dilation, flipkernel, output_size,
groupcount

@testset "ConvDims" begin
for T in (DenseConvDims, DepthwiseConvDims)
Expand Down Expand Up @@ -648,6 +649,33 @@ else
@info "Skipping Depthwise Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them"
end

@testset "Grouped Convolutions" begin
x′ = rand(Float32, 28, 28, 100, 2)
w′ = rand(Float32, 3, 3, 20, 15)

@test_throws DimensionMismatch DenseConvDims(x′, w′)
cdims = DenseConvDims(x′, w′, groups = 5)

@test groupcount(cdims) == 5

y = conv(x′, w′, cdims)
_, back = Zygote.pullback((x, w) -> sum(conv(x, w, cdims)), x′, w′)
gs_x, gs_w = back(1.f0)


ips = Iterators.partition(1:100, 20)
ops = Iterators.partition(1:15, 3)
for (i,o) in zip(ips,ops)
_, back_reg = Zygote.pullback((x, w) -> sum(conv(x, w)), x′[:,:,i,:], w′[:,:,:,o])
gs_x_reg, gs_w_reg = back_reg(1.f0)
@test conv(x′[:,:,i,:], w′[:,:,:,o]) ≈ y[:,:,o,:]
@test gs_x_reg ≈ gs_x[:,:,i,:]
@test gs_w_reg ≈ gs_w[:,:,:,o]
end

# Currently hangs due to a FiniteDifferences issue
@test_skip gradtest((x, w) -> sum(conv(x, w, cdims)), x′, w′)
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
end

@testset "conv_wrapper" begin
x = rand(10, 10, 3, 10)
Expand Down
2 changes: 2 additions & 0 deletions test/padding.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using NNlib: pad_constant, pad_repeat, pad_zeros, pad_reflect

@testset "padding constant" begin
x = rand(2, 2, 2)

Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ end
include("utils.jl")
end


if VERSION >= v"1.6" && CUDA.functional()
if get(ENV, "NNLIB_TEST_CUDA", "false") == "true"
import Pkg
Expand Down
20 changes: 10 additions & 10 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ given by Zygote. `f` has to be a scalar valued function.

Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly defined.
"""
function gradtest(f, xs...; atol=1e-6, rtol=1e-6, fkwargs=NamedTuple(),
check_rrule=false,
fdm=:central,
check_broadcast=false,
skip=false, broken=false)
function gradtest(f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs=NamedTuple(),
check_rrule = false,
fdm = :central,
check_broadcast = false,
skip = false, broken = false)
# TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166
# is merged
if check_rrule
test_rrule(f, xs...; fkwargs=fkwargs)
test_rrule(f, xs...; fkwargs = fkwargs)
end

if check_broadcast
Expand All @@ -43,14 +43,14 @@ function gradtest(f, xs...; atol=1e-6, rtol=1e-6, fkwargs=NamedTuple(),
y_ad, pull = Zygote.pullback(h, xs...)
gs_ad = pull(one(y_ad))

@test y_true ≈ y_ad atol=atol rtol=rtol
@test y_true ≈ y_ad atol = atol rtol = rtol
for (g_ad, g_fd) in zip(gs_ad, gs_fd)
if skip
@test_skip g_ad ≈ g_fd atol=atol rtol=rtol
@test_skip g_ad ≈ g_fd atol = atol rtol = rtol
elseif broken
@test_broken g_ad ≈ g_fd atol=atol rtol=rtol
@test_broken g_ad ≈ g_fd atol = atol rtol = rtol
else
@test g_ad ≈ g_fd atol=atol rtol=rtol
@test g_ad ≈ g_fd atol = atol rtol = rtol
end
end
return true
Expand Down