Commit

# This is a combination of 4 commits.
# This is the 1st commit message:

added GroupedConvolutions

# This is the commit message #2:

Working implementation of channel shuffling

# This is the commit message #3:

grouped convolutions can now act on the whole input

# This is the commit message #4:

updated documentation
gartangh committed Jan 18, 2020
1 parent d1edd9b commit f27480b
Showing 4 changed files with 614 additions and 3 deletions.
3 changes: 3 additions & 0 deletions docs/src/models/layers.md
@@ -38,6 +38,9 @@ But in contrast to the layers described in the other sections are not readily grouped
```@docs
Maxout
SkipConnection
GroupedConvolutions
ChannelShuffle
ShuffledGroupedConvolutions
```

## Activation Functions
3 changes: 2 additions & 1 deletion src/Flux.jl
@@ -11,7 +11,8 @@ export gradient

export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-      SkipConnection, params, fmap, cpu, gpu, f32, f64
+      SkipConnection, GroupedConvolutions, ChannelShuffle, ShuffledGroupedConvolutions,
+      params, fmap, cpu, gpu, f32, f64

include("optimise/Optimise.jl")
using .Optimise
274 changes: 273 additions & 1 deletion src/layers/basic.jl
@@ -214,7 +214,7 @@ size(sm(x)) == (5, 5, 11, 10)
"""
struct SkipConnection
  layers
-  connection #user can pass arbitrary connections here, such as (a,b) -> a + b
+  connection # user can pass arbitrary connections here, such as (a,b) -> a + b
end

@functor SkipConnection
@@ -226,3 +226,275 @@ end
function Base.show(io::IO, b::SkipConnection)
  print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
end

"""
GroupedConvolutions(connection, paths, split)
Creates a group of convolutions from a set of layers or chains of consecutive layers.
Proposed in [Alexnet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networ).
The connection function will combine the results of each paths, to give the final output.
If split is false, each path acts on all feature maps of the input.
If split is true, the feature maps of the input are evenly distributed across all paths.
Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
The names of the variables are consistent accross all examples:
`i` stands for input,
`a` and `b`, `c`, and `d` are `Chains`,
`g` represents a `GroupedConvolutions`,
`s` is a `SkipConnection`,
and `o` is the output.
Examples A, B, and C show how to use grouped convolutions in practice for [ResNeXt](https://arxiv.org/abs/1611.05431).
Batch Normalization and ReLU activations are left out for simplicity.
**Example A**: ResNeXt block without splitting.
```
i = randn(7,7,256,16)
a() = Chain(Conv((1,1), 256=>4 , pad=(0,0)),
Conv((3,3), 4 =>4 , pad=(1,1)),
Conv((1,1), 4 =>256, pad=(0,0)))
g = GroupedConvolutions(+, [a() for _ = 1:32]..., split=false)
s = SkipConnection(g, +)
o = s(i)
```
**Example B**: ResNeXt block without splitting and early concatenation.
```
i = randn(7,7,256,16)
a() = Chain(Conv((1,1), 256=>4, pad=(0,0)),
Conv((3,3), 4 =>4, pad=(1,1)))
b = Chain(GroupedConvolutions((results...) -> cat(results..., dims=3), [a() for _ = 1:32]..., split=false),
Conv((1,1), 128=>256, pad=(0,0)))
s = SkipConnection(b, +)
o = s(i)
```
**Example C**: ResNeXt block with splitting (and concatentation).
```
i = randn(7,7,256,16)
b = Chain(Conv((1,1), 256=>128, pad=(0,0)),
GroupedConvolutions((results...) -> cat(results..., dims=3), [Conv((3,3), 4=>4, pad=(1,1)) for _ = 1:32]..., split=true),
Conv((1,1), 128=>256, pad=(0,0)))
s = SkipConnection(b, +)
o = s(i)
```
Example D shows how to use grouped convolutions in practice for [Inception v1 (GoogLeNet)](https://research.google/pubs/pub43022/).
**Example D**: Inception v1 (GoogLeNet) block
(The numbers used in this example come from Inception block 3a.)
```
i = randn(28,28,192,16)
a = Conv( (1,1), 192=>64, pad=(0,0), relu)
b = Chain(Conv( (1,1), 192=>96, pad=(0,0), relu), Conv((3,3), 96 =>128, pad=(1,1), relu))
c = Chain(Conv( (1,1), 192=>16, pad=(0,0), relu), Conv((5,5), 16 =>32 , pad=(2,2), relu))
d = Chain(MaxPool((3,3), stride=1, pad=(1,1) ), Conv((1,1), 192=>32 , pad=(0,0), relu))
g = GroupedConvolutions((results...) -> cat(results..., dims=3), a, b, c, d, split=false)
o = g(i)
```
"""
struct GroupedConvolutions{T<:Tuple}
  connection # user can pass arbitrary connections here, such as (a,b) -> a + b
  paths::T
  split::Bool

  function GroupedConvolutions(connection, paths...; split::Bool=false)
    npaths = length(paths)
    npaths > 1 || error("the number of paths (", npaths, ") is not greater than 1")
    new{typeof(paths)}(connection, paths, split)
  end

  function GroupedConvolutions(connection, paths::Tuple; split::Bool=false)
    npaths = length(paths)
    npaths > 1 || error("the number of paths (", npaths, ") is not greater than 1")
    new{typeof(paths)}(connection, paths, split)
  end
end
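
As a quick illustration of the constructor guard (a hypothetical call, not part of this commit), a single path is rejected:
```
GroupedConvolutions(+, Conv((1,1), 8=>8))
# ERROR: the number of paths (1) is not greater than 1
```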

@functor GroupedConvolutions

function (group::GroupedConvolutions)(input)
  # get the size of the input
  w, h, c, n = size(input)
  # number of feature maps in the input
  nmaps = c
  # number of paths of the GroupedConvolutions
  npaths = length(group.paths)

  if group.split
    # distribute the feature maps of the input over the paths
    # throw an error if the number of feature maps is not divisible by the number of paths
    mod(nmaps, npaths) == 0 || error("the number of feature maps in the input (", nmaps, ") is not divisible by the number of paths of the GroupedConvolutions (", npaths, ")")

    # number of feature maps per path
    nmaps_per_path = div(nmaps, npaths)

    # calculate the output of the grouped convolutions
    group.connection([path(input[:, :, _start_index(path_index, nmaps_per_path):_stop_index(path_index, nmaps_per_path), :]) for (path_index, path) in enumerate(group.paths)]...)
  else
    # use the complete input for each path
    group.connection([path(input) for path in group.paths]...)
  end
end
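
A minimal shape check of the split behaviour (a sketch; the layers and channel counts are illustrative, not from this commit): with `split=true`, 32 input maps are divided over 4 paths of 8 maps each, and concatenation restores all 32.
```
using Flux

x = randn(Float32, 16, 16, 32, 2)
g = GroupedConvolutions((results...) -> cat(results..., dims=3),
                        [Conv((3,3), 8=>8, pad=(1,1)) for _ in 1:4]...,
                        split=true)
size(g(x)) == (16, 16, 32, 2)  # each path sees 32 ÷ 4 = 8 feature maps
```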

# calculates the start index of the feature maps for a path
_start_index(path_index::Int, nmaps_per_path::Int) = (path_index - 1) * nmaps_per_path + 1
# calculates the stop index of the feature maps for a path
_stop_index(path_index::Int, nmaps_per_path::Int) = path_index * nmaps_per_path
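
For example (illustrative values), with 4 feature maps per path, path 3 covers channels 9 through 12:
```
_start_index(3, 4)  # == 9
_stop_index(3, 4)   # == 12
```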

function Base.show(io::IO, group::GroupedConvolutions)
  print(io, "GroupedConvolutions(", group.connection, ", ", group.paths, ", split=", group.split, ")")
end

"""
ChannelShuffle(ngroups)
Creates a layer that shuffles feature maps by each time taking the first channel of each group.
Proposed in [ShuffleNet](https://arxiv.org/abs/1707.01083).
The number of channels in the input must be divisible by the square of the number of groups.
(Each group must have a multiple of the number of groups channels.)
Examples of channel shuffling:
* (4 channels, 2 groups) **ab,cd -> ac,bd**
* (8 channels, 2 groups) **abcd,efgh -> aebf,cgdh**
* (16 channels, 2 groups) **abcdefgh,ijklmnop -> aibjckdl,emfngohp**
* (9 channels, 3 groups) **abc,def,ghi -> adg,beh,cfi**
* (16 channels, 4 groups) **abcd,efgh,ijkl,mnop -> aeim,bfjn,cgko,dhlp**
Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
The names of the variables are consistent accross all examples:
`i` stands for input,
`a`, `b`, and `c` are `Chains`,
`g` represents a `GroupedConvolutions`,
`s` is a `SkipConnection`,
and `o` is the output.
Examples A and B show how to use channel shuffling in practice for [ShuffleNet](https://arxiv.org/abs/1707.01083).
Batch Normalization and ReLU activations are left out for simplicity.
**Example A**: ShuffleNet v1 unit with stride=1.
(The numbers used in this example come from stage 2 and using 2 groups.)
```
i = randn(28,28,200,16)
c = Chain(GroupedConvolutions(+, [Conv((1,1), 200=>64, pad=(0,0)) for _ in 1:2]..., split=false),
ChannelShuffle(2),
DepthwiseConv((3,3), 64=>64, pad=(1,1), stride=(1,1)),
GroupedConvolutions(+, [Conv((1,1), 64=>200, pad=(0,0)) for _ in 1:2]..., split=false))
s = SkipConnection(c, +)
o = s(i)
```
**Example B**: ShuffleNet v1 unit with stride=2.
(The numbers used in this example come from stage 2 and using 2 groups.)
This example shows the use of nested grouped convolutions as well.
```
i = randn(56,56,24,16)
a = MeanPool((3,3), pad=(1,1), stride=(2,2))
b = Chain(GroupedConvolutions(+, [Conv((1,1), 24=>64 , pad=(0,0)) for _ in 1:2]..., split=false),
ChannelShuffle(2),
DepthwiseConv((3,3), 64=>64, pad=(1,1), stride=(2,2)),
GroupedConvolutions(+, [Conv((1,1), 64=>176, pad=(0,0)) for _ in 1:2]..., split=false))
g = GroupedConvolutions((results...) -> cat(results..., dims=3), a, b, split=false)
o = g(i)
```
"""
struct ChannelShuffle
  ngroups::Int

  function ChannelShuffle(ngroups::Int)
    ngroups > 1 || error("the number of groups (", ngroups, ") is not greater than 1")
    new(ngroups)
  end
end

@functor ChannelShuffle

function (shuffle::ChannelShuffle)(input)
  # get the size of the input
  w, h, c, n = size(input)
  # number of feature maps in the input
  nmaps = c
  # number of groups of the ChannelShuffle
  ngroups = shuffle.ngroups
  # throw an error if the number of feature maps is not divisible by the square of the number of groups
  mod(nmaps, ngroups * ngroups) == 0 || error("the number of feature maps in the input (", nmaps, ") is not divisible by the square of the number of groups of the ChannelShuffle (", ngroups * ngroups, ")")

  # number of feature maps per group
  nmaps_per_group = div(nmaps, ngroups)

  # split the channel dimension into (channels per group, groups)
  input = reshape(input, (w, h, nmaps_per_group, ngroups, n))
  # swap the group dimension and the within-group channel dimension
  input = permutedims(input, [1, 2, 4, 3, 5])
  # flatten the result back to the original dimensions
  reshape(input, (w, h, c, n))
end
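
A toy check of the first shuffling pattern listed in the docstring, **ab,cd -> ac,bd** (the numeric channel values are illustrative):
```
x = reshape(Float32[1, 2, 3, 4], 1, 1, 4, 1)  # channels: a=1, b=2, c=3, d=4
s = ChannelShuffle(2)
vec(s(x)) == Float32[1, 3, 2, 4]              # shuffled to a, c, b, d
```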

function Base.show(io::IO, shuffle::ChannelShuffle)
  print(io, "ChannelShuffle(", shuffle.ngroups, ")")
end

"""
ShuffledGroupedConvolutions(connection, paths, split)
ShuffledGroupedConvolutions(group, shuffle)
A wrapper around a subsequent `GroupedConvolutions` and `ChannelShuffle`.
Takes the number of paths in the grouped convolutions to be the number of groups in the channel shuffling operation.
Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
The names of the variables are consistent accross all examples:
`i` stands for input,
`a` and `b` are `Chains`,
`g` represents a `GroupedConvolutions`,
`s` is a `SkipConnection`,
and `o` is the output.
Example A shows how to use shuffled grouped convolutions in practice for [ShuffleNet](https://arxiv.org/abs/1707.01083).
Batch Normalization and ReLU activations are left out for simplicity.
**Example A**: ShuffleNet v1 unit with stride=1.
(The numbers used in this example come from stage 2 and using 2 groups.)
```
i = randn(28, 28, 200, 16)
c = Chain(ShuffledGroupedConvolutions(+, [Conv((1,1), 200=>64, pad=(0,0)) for _ in 1:2]..., split=false),
#ShuffledGroupedConvolutions(GroupedConvolutions(+, [Conv((1,1), 200=>64, pad=(0,0)) for _ in 1:2]..., split=false),
# ChannelShuffle(2)),
DepthwiseConv((3,3), 64=>64, pad=(1,1), stride=(1,1)),
GroupedConvolutions(+, [Conv((1,1), 64=>200, pad=(0,0)) for _ in 1:2]..., split=false))
s = SkipConnection(c, +)
o = s(i)
```
"""
struct ShuffledGroupedConvolutions
  group::GroupedConvolutions
  shuffle::ChannelShuffle

  function ShuffledGroupedConvolutions(group::GroupedConvolutions, shuffle::ChannelShuffle)
    shuffle.ngroups == length(group.paths) || error("the number of groups in the ChannelShuffle layer (", shuffle.ngroups, ") is not equal to the number of paths in the GroupedConvolutions (", length(group.paths), ")")
    new(group, shuffle)
  end

  ShuffledGroupedConvolutions(connection, paths...; split::Bool=false) =
    new(GroupedConvolutions(connection, paths, split=split), ChannelShuffle(length(paths)))
  ShuffledGroupedConvolutions(connection, paths::Tuple; split::Bool=false) =
    new(GroupedConvolutions(connection, paths, split=split), ChannelShuffle(length(paths)))
end

@functor ShuffledGroupedConvolutions

function (shuffled::ShuffledGroupedConvolutions)(input)
  shuffled.shuffle(shuffled.group(input))
end
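
A minimal end-to-end sketch (the channel counts are assumptions for illustration): four 1×1 convolution paths are summed, and the 16 resulting maps are shuffled across 4 groups.
```
x  = randn(Float32, 8, 8, 16, 1)
sg = ShuffledGroupedConvolutions(+, [Conv((1,1), 16=>16) for _ in 1:4]..., split=false)
size(sg(x)) == (8, 8, 16, 1)  # shuffling requires 16 to be divisible by 4^2
```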

function Base.show(io::IO, shuffled::ShuffledGroupedConvolutions)
  print(io, "ShuffledGroupedConvolutions(", shuffled.group, ", ", shuffled.shuffle, ")")
end