diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index 5f2ab3cea0..719f2d3c33 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -38,6 +38,9 @@ But in contrast to the layers described in the other sections are not readily gr
 ```@docs
 Maxout
 SkipConnection
+GroupedConvolutions
+ChannelShuffle
+ShuffledGroupedConvolutions
 ```
 
 ## Activation Functions
diff --git a/src/Flux.jl b/src/Flux.jl
index 9969b32346..6ef9fe8fda 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -11,7 +11,8 @@ export gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection, params, fmap, cpu, gpu, f32, f64
+       SkipConnection, GroupedConvolutions, ChannelShuffle, ShuffledGroupedConvolutions,
+       params, fmap, cpu, gpu, f32, f64
 
 include("optimise/Optimise.jl")
 using .Optimise
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 2a46520818..0a9a1f5369 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -214,7 +214,7 @@ size(sm(x)) == (5, 5, 11, 10)
 """
 struct SkipConnection
   layers
-  connection  #user can pass arbitrary connections here, such as (a,b) -> a + b
+  connection  # user can pass arbitrary connections here, such as (a,b) -> a + b
 end
 
 @functor SkipConnection
@@ -226,3 +226,275 @@ end
 function Base.show(io::IO, b::SkipConnection)
   print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
 end
+
+"""
+    GroupedConvolutions(connection, paths...; split=false)
+
+Creates a group of convolutions from a set of layers or chains of consecutive layers.
+Proposed in [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks).
+
+The connection function will combine the results of each path to give the final output.
+If split is false, each path acts on all feature maps of the input.
+If split is true, the feature maps of the input are evenly distributed across all paths.
+
+Data should be stored in WHCN order (width, height, # channels, # batches).
+In other words, a 100×100 RGB image would be a `100×100×3×1` array,
+and a batch of 50 would be a `100×100×3×50` array.
+
+The names of the variables are consistent across all examples:
+`i` stands for input,
+`a`, `b`, `c`, and `d` are `Chains`,
+`g` represents a `GroupedConvolutions`,
+`s` is a `SkipConnection`,
+and `o` is the output.
+
+Examples A, B, and C show how to use grouped convolutions in practice for [ResNeXt](https://arxiv.org/abs/1611.05431).
+Batch Normalization and ReLU activations are left out for simplicity.
+
+**Example A**: ResNeXt block without splitting.
+```
+i = randn(7,7,256,16)
+a() = Chain(Conv((1,1), 256=>4  , pad=(0,0)),
+            Conv((3,3), 4  =>4  , pad=(1,1)),
+            Conv((1,1), 4  =>256, pad=(0,0)))
+g = GroupedConvolutions(+, [a() for _ = 1:32]..., split=false)
+s = SkipConnection(g, +)
+o = s(i)
+```
+
+**Example B**: ResNeXt block without splitting and early concatenation.
+```
+i = randn(7,7,256,16)
+a() = Chain(Conv((1,1), 256=>4, pad=(0,0)),
+            Conv((3,3), 4  =>4, pad=(1,1)))
+b = Chain(GroupedConvolutions((results...) -> cat(results..., dims=3), [a() for _ = 1:32]..., split=false),
+          Conv((1,1), 128=>256, pad=(0,0)))
+s = SkipConnection(b, +)
+o = s(i)
+```
+
+**Example C**: ResNeXt block with splitting (and concatenation).
+```
+i = randn(7,7,256,16)
+b = Chain(Conv((1,1), 256=>128, pad=(0,0)),
+          GroupedConvolutions((results...) -> cat(results..., dims=3), [Conv((3,3), 4=>4, pad=(1,1)) for _ = 1:32]..., split=true),
+          Conv((1,1), 128=>256, pad=(0,0)))
+s = SkipConnection(b, +)
+o = s(i)
+```
+
+Example D shows how to use grouped convolutions in practice for [Inception v1 (GoogLeNet)](https://research.google/pubs/pub43022/).
+
+**Example D**: Inception v1 (GoogLeNet) block
+(The numbers used in this example come from Inception block 3a.)
+```
+i = randn(28,28,192,16)
+a =       Conv(   (1,1), 192=>64, pad=(0,0), relu)
+b = Chain(Conv(   (1,1), 192=>96, pad=(0,0), relu), Conv((3,3), 96 =>128, pad=(1,1), relu))
+c = Chain(Conv(   (1,1), 192=>16, pad=(0,0), relu), Conv((5,5), 16 =>32 , pad=(2,2), relu))
+d = Chain(MaxPool((3,3), stride=1, pad=(1,1)     ), Conv((1,1), 192=>32 , pad=(0,0), relu))
+g = GroupedConvolutions((results...) -> cat(results..., dims=3), a, b, c, d, split=false)
+o = g(i)
+```
+"""
+struct GroupedConvolutions{T<:Tuple}
+  connection  # user can pass arbitrary connections here, such as (a,b) -> a + b
+  paths::T
+  split::Bool
+
+  # also accepts `split` positionally so the layer can be rebuilt from its fields (connection, paths, split)
+  function GroupedConvolutions(connection, paths::Tuple, positional_split::Bool=false;
+                               split::Bool=positional_split)
+    npaths = size(paths, 1)
+    npaths > 1 || error("the number of paths (", npaths, ") is not greater than 1")
+    new{typeof(paths)}(connection, paths, split)
+  end
+
+  # varargs form: GroupedConvolutions(connection, path1, path2, ...; split=false)
+  GroupedConvolutions(connection, paths...; split::Bool=false) =
+    GroupedConvolutions(connection, paths; split=split)
+end
+
+@functor GroupedConvolutions
+
+function (group::GroupedConvolutions)(input)
+  # get input size
+  w::Int, h::Int, c::Int, n::Int = size(input)
+  # number of feature maps in input
+  nmaps::Int = c
+  # number of paths of the GroupedConvolution
+  npaths::Int = size(group.paths, 1)
+
+  if group.split
+    # distributes the feature maps of the input over the paths
+    # throw error if number of feature maps not divisible by number of paths
+    mod(nmaps, npaths) == 0 || error("the number of feature maps in the input (", nmaps, ") is not divisible by the number of paths of the GroupedConvolution (", npaths, ")")
+
+    # number of maps per path
+    nmaps_per_path::Int = div(nmaps, npaths)
+
+    # calculate the output for the grouped convolutions
+    group.connection([path(input[:,:,_start_index(path_index, nmaps_per_path):_stop_index(path_index, nmaps_per_path),:]) for (path_index, path) in enumerate(group.paths)]...)
+  else
+    # uses the complete input for each path
+    group.connection([path(input) for (path) in group.paths]...)
+  end
+end
+
+# calculates the start index of the feature maps for a path
+_start_index(path_index::Integer, nmaps_per_path::Integer) = (path_index - 1) * nmaps_per_path + 1
+# calculates the stop index of the feature maps for a path
+_stop_index(path_index::Integer, nmaps_per_path::Integer) = (path_index) * nmaps_per_path
+
+function Base.show(io::IO, group::GroupedConvolutions)
+  print(io, "GroupedConvolutions(", group.connection, ", ", group.paths, ", split=", group.split, ")")
+end
+
+"""
+    ChannelShuffle(ngroups)
+
+Creates a layer that shuffles feature maps by each time taking the first channel of each group.
+Proposed in [ShuffleNet](https://arxiv.org/abs/1707.01083).
+
+The number of channels in the input must be divisible by the square of the number of groups.
+(Each group must contain a number of channels that is a multiple of the number of groups.)
+
+Examples of channel shuffling:
+* (4 channels, 2 groups) **ab,cd -> ac,bd**
+* (8 channels, 2 groups) **abcd,efgh -> aebf,cgdh**
+* (16 channels, 2 groups) **abcdefgh,ijklmnop -> aibjckdl,emfngohp**
+* (9 channels, 3 groups) **abc,def,ghi -> adg,beh,cfi**
+* (16 channels, 4 groups) **abcd,efgh,ijkl,mnop -> aeim,bfjn,cgko,dhlp**
+
+Data should be stored in WHCN order (width, height, # channels, # batches).
+In other words, a 100×100 RGB image would be a `100×100×3×1` array,
+and a batch of 50 would be a `100×100×3×50` array.
+
+The names of the variables are consistent across all examples:
+`i` stands for input,
+`a`, `b`, and `c` are `Chains`,
+`g` represents a `GroupedConvolutions`,
+`s` is a `SkipConnection`,
+and `o` is the output.
+
+Examples A and B show how to use channel shuffling in practice for [ShuffleNet](https://arxiv.org/abs/1707.01083).
+Batch Normalization and ReLU activations are left out for simplicity.
+
+**Example A**: ShuffleNet v1 unit with stride=1.
+(The numbers used in this example come from stage 2 and using 2 groups.)
+```
+i = randn(28,28,200,16)
+c = Chain(GroupedConvolutions(+, [Conv((1,1), 200=>64, pad=(0,0)) for _ in 1:2]..., split=false),
+          ChannelShuffle(2),
+          DepthwiseConv((3,3), 64=>64, pad=(1,1), stride=(1,1)),
+          GroupedConvolutions(+, [Conv((1,1), 64=>200, pad=(0,0)) for _ in 1:2]..., split=false))
+s = SkipConnection(c, +)
+o = s(i)
+```
+
+**Example B**: ShuffleNet v1 unit with stride=2.
+(The numbers used in this example come from stage 2 and using 2 groups.)
+This example shows the use of nested grouped convolutions as well.
+```
+i = randn(56,56,24,16)
+a = MeanPool((3,3), pad=(1,1), stride=(2,2))
+b = Chain(GroupedConvolutions(+, [Conv((1,1), 24=>64 , pad=(0,0)) for _ in 1:2]..., split=false),
+          ChannelShuffle(2),
+          DepthwiseConv((3,3), 64=>64, pad=(1,1), stride=(2,2)),
+          GroupedConvolutions(+, [Conv((1,1), 64=>176, pad=(0,0)) for _ in 1:2]..., split=false))
+g = GroupedConvolutions((results...) -> cat(results..., dims=3), a, b, split=false)
+o = g(i)
+```
+"""
+struct ChannelShuffle
+  ngroups::Int
+
+  function ChannelShuffle(ngroups::Int)
+    ngroups > 1 || error("the number of groups (", ngroups, ") is not greater than 1")
+    new(ngroups)
+  end
+end
+
+@functor ChannelShuffle
+
+function (shuffle::ChannelShuffle)(input)
+  # get input size
+  w::Int, h::Int, c::Int, n::Int = size(input)
+  # number of feature maps in input
+  nmaps::Int = c
+  # number of groups of the ChannelShuffle
+  ngroups::Int = shuffle.ngroups
+  # throw error if number of feature maps not divisible by the square of the number of groups
+  mod(nmaps, ngroups*ngroups) == 0 || error("the number of feature maps in the input (", nmaps, ") is not divisible by the square of the number of groups of the ChannelShuffle (", ngroups*ngroups, ")")
+
+  # number of maps per group
+  nmaps_per_group::Int = div(nmaps, ngroups)
+
+  # split up dimension of feature maps
+  input = reshape(input, (w, h, nmaps_per_group, ngroups, n))
+  # transpose the newly created dimensions, but not recursively
+  input = permutedims(input, [1, 2, 4, 3, 5])
+  # flatten the result back to the original dimensions
+  reshape(input, (w, h, c, n))
+end
+
+function Base.show(io::IO, shuffle::ChannelShuffle)
+  print(io, "ChannelShuffle(", shuffle.ngroups, ")")
+end
+
+"""
+    ShuffledGroupedConvolutions(connection, paths...; split=false)
+    ShuffledGroupedConvolutions(group, shuffle)
+
+A wrapper around a subsequent `GroupedConvolutions` and `ChannelShuffle`.
+Takes the number of paths in the grouped convolutions to be the number of groups in the channel shuffling operation.
+
+Data should be stored in WHCN order (width, height, # channels, # batches).
+In other words, a 100×100 RGB image would be a `100×100×3×1` array,
+and a batch of 50 would be a `100×100×3×50` array.
+
+The names of the variables are consistent across all examples:
+`i` stands for input,
+`a` and `b` are `Chains`,
+`g` represents a `GroupedConvolutions`,
+`s` is a `SkipConnection`,
+and `o` is the output.
+
+Example A shows how to use shuffled grouped convolutions in practice for [ShuffleNet](https://arxiv.org/abs/1707.01083).
+Batch Normalization and ReLU activations are left out for simplicity.
+
+**Example A**: ShuffleNet v1 unit with stride=1.
+(The numbers used in this example come from stage 2 and using 2 groups.)
+```
+i = randn(28, 28, 200, 16)
+c = Chain(ShuffledGroupedConvolutions(+, [Conv((1,1), 200=>64, pad=(0,0)) for _ in 1:2]..., split=false),
+          #ShuffledGroupedConvolutions(GroupedConvolutions(+, [Conv((1,1), 200=>64, pad=(0,0)) for _ in 1:2]..., split=false),
+          #                            ChannelShuffle(2)),
+          DepthwiseConv((3,3), 64=>64, pad=(1,1), stride=(1,1)),
+          GroupedConvolutions(+, [Conv((1,1), 64=>200, pad=(0,0)) for _ in 1:2]..., split=false))
+s = SkipConnection(c, +)
+o = s(i)
+```
+"""
+struct ShuffledGroupedConvolutions
+  group::GroupedConvolutions
+  shuffle::ChannelShuffle
+
+  function ShuffledGroupedConvolutions(group::GroupedConvolutions, shuffle::ChannelShuffle)
+    shuffle.ngroups == size(group.paths, 1) || error("the number of groups in the ChannelShuffle layer (", shuffle.ngroups, ") is not equal to the number of paths in the GroupedConvolutions (", size(group.paths, 1), ")")
+    new(group, shuffle)
+  end
+
+  ShuffledGroupedConvolutions(connection, paths...; split::Bool=false) = new(GroupedConvolutions(connection, paths, split=split), ChannelShuffle(size(paths, 1)))
+  ShuffledGroupedConvolutions(connection, paths::Tuple; split::Bool=false) = new(GroupedConvolutions(connection, paths, split=split), ChannelShuffle(size(paths, 1)))
+end
+
+@functor ShuffledGroupedConvolutions
+
+function (shuffled::ShuffledGroupedConvolutions)(input)
+  shuffled.shuffle(shuffled.group(input))
+end
+
+function Base.show(io::IO, shuffled::ShuffledGroupedConvolutions)
+  print(io, "ShuffledGroupedConvolutions(", shuffled.group, ", ", shuffled.shuffle, ")")
+end
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 0ff1776db8..1f78294d47 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -28,7 +28,7 @@ import Flux: activations
   end
 
   @testset "Dense" begin
-    @test  length(Dense(10, 5)(randn(10))) == 5
+    @test length(Dense(10, 5)(randn(10))) == 5
     @test_throws DimensionMismatch Dense(10, 5)(randn(1))
     @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
     @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
@@ -92,4 +92,339 @@ import Flux: activations
       @test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
     end
   end
+
+  @testset "GroupedConvolutions" begin
+    input256 = randn(7, 7, 256, 16)
+
+    @testset "constructor" begin
+      path1 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+      path2 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+      path3 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+      path4 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+
+      # the number of paths is not greater than 1
+      @test_throws ErrorException GroupedConvolutions(+)
+      @test_throws ErrorException GroupedConvolutions(+, split=false)
+      @test_throws ErrorException GroupedConvolutions(+, split=true)
+      @test_throws ErrorException GroupedConvolutions(+, path1)
+      @test_throws ErrorException GroupedConvolutions(+, path1, split=false)
+      @test_throws ErrorException GroupedConvolutions(+, path1, split=true)
+      @test_throws ErrorException GroupedConvolutions(+, ())
+      @test_throws ErrorException GroupedConvolutions(+, (), split=false)
+      @test_throws ErrorException GroupedConvolutions(+, (), split=true)
+      @test_throws ErrorException GroupedConvolutions(+, (path1))
+      @test_throws ErrorException GroupedConvolutions(+, (path1), split=false)
+      @test_throws ErrorException GroupedConvolutions(+, (path1), split=true)
+
+      # tuple
+      group3 = GroupedConvolutions(+, (path1, path2, path3), split=true)
+      @test size(group3.paths, 1) == 3
+      @test group3.split == true
+      group4 = GroupedConvolutions(+, (path1, path2, path3, path4), split=true)
+      @test size(group4.paths, 1) == 4
+      @test group4.split == true
+
+      # varargs
+      group3 = GroupedConvolutions(+, path1, path2, path3, split=true)
+      @test size(group3.paths, 1) == 3
+      @test group3.split == true
+      group4 = GroupedConvolutions(+, path1, path2, path3, path4, split=true)
+      @test size(group4.paths, 1) == 4
+      @test group4.split == true
+    end
+
+    @testset "sum split" begin
+      path1 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+      path2 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+      path3 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+      path4 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+      result1 = path1(input256[:,:,1:64,:])
+      result2 = path2(input256[:,:,65:128,:])
+      result3 = path3(input256[:,:,129:192,:])
+      result4 = path4(input256[:,:,193:256,:])
+      group3 = GroupedConvolutions(+, (path1, path2, path3), split=true)
+      group4 = GroupedConvolutions(+, (path1, path2, path3, path4), split=true)
+
+      # summation for 3 paths
+      # the number of feature maps in the input (256) is not divisible by the number of paths of the GroupedConvolution (3)
+      @test_throws ErrorException group3(input256)
+
+      # summation for 4 paths
+      result = group4(input256)
+      @test size(result) == size(input256)
+      @test result == result1 + result2 + result3 + result4
+    end
+
+    @testset "sum no split" begin
+      path1 = Chain(
+        Conv((1,1), 256 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+      path2 = Chain(
+        Conv((1,1), 256 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+      path3 = Chain(
+        Conv((1,1), 256 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+      path4 = Chain(
+        Conv((1,1), 256 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+      )
+      result1 = path1(input256)
+      result2 = path2(input256)
+      result3 = path3(input256)
+      result4 = path4(input256)
+      group3 = GroupedConvolutions(+, (path1, path2, path3))
+      group4 = GroupedConvolutions(+, (path1, path2, path3, path4))
+
+      # summation for 3 paths
+      # does not throw exception anymore
+      result = group3(input256)
+      @test size(result) == size(input256)
+      @test result == result1 + result2 + result3
+
+      # summation for 4 paths
+      result = group4(input256)
+      @test size(result) == size(input256)
+      @test result == result1 + result2 + result3 + result4
+    end
+
+    @testset "cat split" begin
+      path1 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1))
+      )
+      path2 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1))
+      )
+      path3 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1))
+      )
+      path4 = Chain(
+        Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1))
+      )
+      result1 = path1(input256[:,:,1:64,:])
+      result2 = path2(input256[:,:,65:128,:])
+      result3 = path3(input256[:,:,129:192,:])
+      result4 = path4(input256[:,:,193:256,:])
+      group3 = GroupedConvolutions((a,b,c) -> cat(a, b, c, dims=3), (path1, path2, path3), split=true)
+      group4 = GroupedConvolutions((a,b,c,d) -> cat(a, b, c, d, dims=3), (path1, path2, path3, path4), split=true)
+      result = group4(input256)
+
+      # concatenation for 3 paths
+      # the number of feature maps in the input (256) is not divisible by the number of paths of the GroupedConvolution (3)
+      @test_throws ErrorException group3(input256)
+
+      # concatenation for 4 paths
+      @test size(result) == (7, 7, 4*4, 16)
+      @test result == cat(result1, result2, result3, result4, dims=3)
+    end
+
+    @testset "cat no split" begin
+      path1 = Chain(
+        Conv((1,1), 256 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1))
+      )
+      path2 = Chain(
+        Conv((1,1), 256 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1))
+      )
+      path3 = Chain(
+        Conv((1,1), 256 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1))
+      )
+      path4 = Chain(
+        Conv((1,1), 256 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1))
+      )
+      result1 = path1(input256)
+      result2 = path2(input256)
+      result3 = path3(input256)
+      result4 = path4(input256)
+      group3 = GroupedConvolutions((a,b,c) -> cat(a, b, c, dims=3), (path1, path2, path3))
+      group4 = GroupedConvolutions((a,b,c,d) -> cat(a, b, c, d, dims=3), (path1, path2, path3, path4))
+
+      # concatenation for 3 paths
+      # does not throw exception anymore
+      result = group3(input256)
+      @test size(result) == (7, 7, 3*4, 16)
+      @test result == cat(result1, result2, result3, dims=3)
+
+      # concatenation for 4 paths
+      result = group4(input256)
+      @test size(result) == (7, 7, 4*4, 16)
+      @test result == cat(result1, result2, result3, result4, dims=3)
+    end
+
+    @testset "mixed paths" begin
+      path1 = Conv((1,1), 128=>64, pad=(0, 0), stride=(1, 1))
+      path2 = Chain(
+        Conv((1,1), 128 => 4, pad=(0, 0), stride=(1, 1)),
+        Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+        Conv((1,1), 4 => 64, pad=(0, 0), stride=(1, 1))
+      )
+      result1 = path1(input256[:,:,1:128,:])
+      result2 = path2(input256[:,:,129:256,:])
+      group2 = GroupedConvolutions(+, (path1, path2), split=true)
+      result = group2(input256)
+
+      # summation for 2 different paths
+      @test size(result) == (7, 7, 64, 16)
+      @test result == result1 + result2
+    end
+  end
+
+  @testset "ChannelShuffle" begin
+    @testset "constructor" begin
+      # the number of groups is not greater than 1
+      @test_throws ErrorException ChannelShuffle(-1)
+      @test_throws ErrorException ChannelShuffle(0)
+      @test_throws ErrorException ChannelShuffle(1)
+    end
+
+    @testset "channel shuffling" begin
+      input3 = reshape(collect(1:1*1*3*1),(1,1,3,1))
+      input4 = reshape(collect(1:1*1*4*1),(1,1,4,1))
+      input8 = reshape(collect(1:1*1*8*1),(1,1,8,1))
+      input9 = reshape(collect(1:1*1*9*1),(1,1,9,1))
+      input16 = reshape(collect(1:1*1*16*1),(1,1,16,1))
+      input256 = reshape(collect(1:7*7*256*16),(7,7,256,16))
+      shuffle2 = ChannelShuffle(2)
+      shuffle3 = ChannelShuffle(3)
+      shuffle4 = ChannelShuffle(4)
+      shuffle8 = ChannelShuffle(8)
+
+      # the number of feature maps in the input is not divisible by the square of the number of groups of the ChannelShuffle
+      @test_throws ErrorException shuffle3(input3)
+      @test_throws ErrorException shuffle3(input4)
+      @test_throws ErrorException shuffle4(input8)
+
+      # ab,cd -> ac,bd (2 groups)
+      # 12,34 -> 13,24 (2 groups)
+      @test shuffle2(input4)[1,1,:,1] == [1, 3, 2, 4]
+
+      # abcd,efgh -> aebf,cgdh (2 groups)
+      # 1234,5678 -> 1526,3748 (2 groups)
+      @test shuffle2(input8)[1,1,:,1] == [1, 5, 2, 6, 3, 7, 4, 8]
+
+      # abcdefgh,ijklmnop -> aibjckdl,emfngohp (2 groups)
+      # 12345678,9... -> 192.3.4.,5.6.7.8. (2 groups)
+      @test shuffle2(input16)[1,1,:,1] == [1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16]
+
+      # abc,def,ghi -> adg,beh,cfi (3 groups)
+      # 123,456,789 -> 147,258,369 (3 groups)
+      @test shuffle3(input9)[1,1,:,1] == [1, 4, 7, 2, 5, 8, 3, 6, 9]
+
+      # abcd,efgh,ijkl,mnop -> aeim,bfjn,cgko,dhlp (4 groups)
+      # 1234,5678,9... -> 159.,26..,37..,48.. (4 groups)
+      @test shuffle4(input16)[1,1,:,1] == [1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16]
+
+      # bigger arrays
+      @test size(shuffle8(input256)) == size(input256)
+    end
+  end
+
+  @testset "ShuffledGroupedConvolutions" begin
+    input256 = randn(7, 7, 256, 16)
+    path1 = Chain(
+      Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+      Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+      Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+    )
+    path2 = Chain(
+      Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+      Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+      Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+    )
+    path3 = Chain(
+      Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+      Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+      Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+    )
+    path4 = Chain(
+      Conv((1,1), 64 => 4, pad=(0, 0), stride=(1, 1)),
+      Conv((3,3), 4 => 4, pad=(1, 1), stride=(1, 1)),
+      Conv((1,1), 4 => 256, pad=(0, 0), stride=(1, 1))
+    )
+    group4 = GroupedConvolutions(+, (path1, path2, path3, path4), split=true)
+    result1 = path1(input256[:,:,1:64,:])
+    result2 = path2(input256[:,:,65:128,:])
+    result3 = path3(input256[:,:,129:192,:])
+    result4 = path4(input256[:,:,193:256,:])
+    shuffle3 = ChannelShuffle(3)
+    shuffle4 = ChannelShuffle(4)
+
+    @testset "constructor" begin
+      # the number of groups in the ChannelShuffle layer (3) is not equal to the number of paths in the GroupedConvolutions (4)
+      @test_throws ErrorException ShuffledGroupedConvolutions(group4, shuffle3)
+
+      # tuple
+      shuffled_group3 = ShuffledGroupedConvolutions(+, (path1, path2, path3), split=true)
+      @test size(shuffled_group3.group.paths, 1) == 3
+      @test shuffled_group3.group.split == true
+      @test shuffled_group3.shuffle.ngroups == 3
+      shuffled_group4 = ShuffledGroupedConvolutions(+, (path1, path2, path3, path4), split=true)
+      @test size(shuffled_group4.group.paths, 1) == 4
+      @test shuffled_group4.group.split == true
+      @test shuffled_group4.shuffle.ngroups == 4
+
+      # varargs
+      shuffled_group3 = ShuffledGroupedConvolutions(+, path1, path2, path3, split=true)
+      @test size(shuffled_group3.group.paths, 1) == 3
+      @test shuffled_group3.group.split == true
+      @test shuffled_group3.shuffle.ngroups == 3
+      shuffled_group4 = ShuffledGroupedConvolutions(+, path1, path2, path3, path4, split=true)
+      @test size(shuffled_group4.group.paths, 1) == 4
+      @test shuffled_group4.group.split == true
+      @test shuffled_group4.shuffle.ngroups == 4
+    end
+
+    @testset "shuffled grouped convolutions" begin
+      shuffle_group4 = ShuffledGroupedConvolutions(group4, shuffle4)
+      result = shuffle_group4(input256)
+      @test size(result) == size(input256)
+      @test result == shuffle4(group4(input256))
+    end
+  end
 end