Merge #446
446: Added the SkipConnection layer and constructor r=MikeInnes a=bhvieira

I added a DenseBlock constructor, which allows one to train DenseNets (you can train ResNets and MixNets with this as well; you only need to change the connection, which is concatenation for DenseNets).
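To make the idea concrete, here is a minimal sketch (not from the PR itself; the layer sizes, `relu`, and padding are illustrative assumptions) of a DenseNet-style block that concatenates along the channel dimension:

```julia
using Flux

# DenseNet-style block: the shortcut concatenates the block's output with
# its input along the channel dimension (dims = 3 for W × H × C × N arrays).
dense_block = SkipConnection(
    Conv((3, 3), 16 => 16, relu, pad = 1),  # pad = 1 preserves spatial size
    (mx, x) -> cat(mx, x, dims = 3))        # 16 + 16 = 32 output channels

x = randn(Float32, 32, 32, 16, 4)  # W × H × C × N
size(dense_block(x))               # (32, 32, 32, 4)
```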

Disclaimer: I created the block for a 3D U-Net, so the assumption here is that whatever layer is inside the block, its output must have the same spatial dimensions (i.e., all array dimensions excluding the channel and minibatch dimensions) as the input; otherwise the connection wouldn't match. I'm not sure this matches the topology of every DenseNet out there, but I suppose it's a good starting point.
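As a sketch of that assumption (sizes here are made up), an additive 'ResNet'-style shortcut only matches when the wrapped layers preserve every dimension, e.g. via padded convolutions:

```julia
using Flux

# Identity (+) shortcut: the inner Chain must preserve all array dimensions,
# which the padded 3×3 convolutions below do.
residual = SkipConnection(
    Chain(Conv((3, 3), 8 => 8, relu, pad = 1),
          Conv((3, 3), 8 => 8, pad = 1)),
    (mx, x) -> mx + x)

x = randn(Float32, 28, 28, 8, 1)
size(residual(x)) == size(x)  # true
```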

No tests yet; I will add them as the PR evolves.

I'm open to suggestions! :)


Co-authored-by: Bruno Hebling Vieira <bruno.hebling.vieira@usp.br>
Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
3 people committed Jun 5, 2019
2 parents 8ee6af1 + b980758 commit 1902c0e
Showing 5 changed files with 49 additions and 1 deletion.
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,5 +1,6 @@
# v0.9.0
* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.

# v0.8.0

1 change: 1 addition & 0 deletions docs/src/models/layers.md
@@ -37,6 +37,7 @@ But in contrast to the layers described in the other sections are not readily grouped

```@docs
Maxout
SkipConnection
```

## Activation Functions
3 changes: 2 additions & 1 deletion src/Flux.jl
@@ -7,7 +7,8 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection,
params, mapleaves, cpu, gpu, f32, f64

@reexport using NNlib
33 changes: 33 additions & 0 deletions src/layers/basic.jl
@@ -189,3 +189,36 @@ end
function (mo::Maxout)(input::AbstractArray)
  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end

"""
SkipConnection(layers...)
Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers
and a shortcut connection linking the input to the block to the
output through a user-supplied callable.
`SkipConnection` requires the output dimension to be the same as the input.
A 'ResNet'-type skip-connection with identity shortcut would simply be
```julia
SkipConnection(layer, (a,b) -> a + b)
```
"""

struct SkipConnection
layers
connection #user can pass arbitrary connections here, such as (a,b) -> a + b
end

@treelike SkipConnection

function (skip::SkipConnection)(input)
  # Apply `layers` to the input, then combine the result with the
  # original input via `connection`.
  skip.connection(skip.layers(input), input)
end

function Base.show(io::IO, b::SkipConnection)
  print(io, "SkipConnection(")
  join(io, b.layers, ", ")
  print(io, ")")
end
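A minimal usage sketch of the layer defined above (the model shape is an assumption for illustration, not part of the diff):

```julia
using Flux

model = Chain(
    Dense(10, 10, relu),
    SkipConnection(Dense(10, 10), (mx, x) -> mx .+ x),  # identity shortcut
    Dense(10, 2))

y = model(randn(Float32, 10, 5))  # 2 × 5 output
```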
12 changes: 12 additions & 0 deletions test/layers/basic.jl
@@ -72,4 +72,16 @@ import Flux: activations
      @test length(ps) == 8  # 4 alts, each with weight and bias
    end
  end

@testset "SkipConnection" begin
@testset "zero sum" begin
input = randn(10, 10, 10, 10)
@test SkipConnection(x -> zeros(size(x)), (a,b) -> a + b)(input) == input
end

@testset "concat size" begin
input = randn(10, 2)
@test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
end
end
end
