Added the SkipConnection layer and constructor #446

Merged · 5 commits · Jun 5, 2019
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,5 +1,6 @@
# v0.9.0
* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.

# v0.8.0

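As a rough sketch of the residual-network use the NEWS entry mentions (illustrative only, not part of this diff; the layer sizes are invented), a ResNet-style block could look like:

```julia
using Flux

# A ResNet-style residual block: the SkipConnection adds the block's
# input back onto the output of the wrapped Chain.
block = SkipConnection(
  Chain(Dense(64, 64, relu), Dense(64, 64)),
  (a, b) -> a + b)

x = randn(64, 16)   # 64 features, batch of 16
y = block(x)        # same size as x
```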
1 change: 1 addition & 0 deletions docs/src/models/layers.md
@@ -37,6 +37,7 @@ But in contrast to the layers described in the other sections are not readily gr

```@docs
Maxout
SkipConnection
```

## Activation Functions
3 changes: 2 additions & 1 deletion src/Flux.jl
@@ -7,7 +7,8 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        SkipConnection,
        params, mapleaves, cpu, gpu, f32, f64

@reexport using NNlib
33 changes: 33 additions & 0 deletions src/layers/basic.jl
@@ -189,3 +189,36 @@ end
function (mo::Maxout)(input::AbstractArray)
  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end

"""
    SkipConnection(layers, connection)

Creates a skip connection, which consists of a layer or `Chain` of consecutive layers
and a shortcut connection that combines the block's input with its output through a
user-supplied binary callable.

The `connection` receives the output of `layers` and the original input, so for an
additive connection such as `+` the block's output dimensions must match its input.

A 'ResNet'-type skip-connection with identity shortcut would simply be
```julia
SkipConnection(layer, (a,b) -> a + b)
```
"""

struct SkipConnection
layers
connection #user can pass arbitrary connections here, such as (a,b) -> a + b
end

@treelike SkipConnection

function (skip::SkipConnection)(input)
  # Apply the wrapped layers to the input, then combine their output with the original input
  skip.connection(skip.layers(input), input)
end

function Base.show(io::IO, b::SkipConnection)
  print(io, "SkipConnection(")
  join(io, b.layers, ", ")
  print(io, ")")
end
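To illustrate the arbitrary `connection` the docstring describes, here is a brief DenseNet-style sketch (not part of this diff; sizes are invented) that concatenates the block's output with its input along the feature dimension rather than adding them:

```julia
using Flux

# Concatenating skip connection: the result stacks the layer's 10 output
# features on top of the original 10 input features.
sc = SkipConnection(Dense(10, 10), (a, b) -> cat(a, b, dims = 1))

x = randn(10, 5)    # 10 features, batch of 5
size(sc(x))         # (20, 5)
```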
12 changes: 12 additions & 0 deletions test/layers/basic.jl
@@ -72,4 +72,16 @@ import Flux: activations
      @test length(ps) == 8  # 4 alts, each with weight and bias
    end
  end

@testset "SkipConnection" begin
@testset "zero sum" begin
input = randn(10, 10, 10, 10)
@test SkipConnection(x -> zeros(size(x)), (a,b) -> a + b)(input) == input
end

@testset "concat size" begin
input = randn(10, 2)
@test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
end
end
end
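A final hedged sketch (not part of the diff) of how the layer composes: since `SkipConnection` is marked `@treelike`, a `Chain` containing one should expose the wrapped layer's parameters through `params`. The model shape below is illustrative only:

```julia
using Flux

# SkipConnection composes with Chain like any other layer, and its
# parameters are collected by params() thanks to @treelike.
model = Chain(
  Dense(10, 10, relu),
  SkipConnection(Dense(10, 10), (a, b) -> a + b),
  Dense(10, 2),
  softmax)

ps = params(model)  # includes the weights and biases of the wrapped Dense
```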