
Added the SkipConnection layer and constructor #446

Merged: 5 commits into FluxML:master from the DenseBlock branch on Jun 5, 2019

Conversation

bhvieira (Contributor)

I added a DenseBlock constructor, which allows one to train DenseNets (you can train ResNets and MixNets with this as well; you only need to change the connection, which for DenseNets is concatenation).

Disclaimer: I created the block for a 3D U-Net, so the assumption here is that whatever layer is inside the block produces output with the same spatial dimensions as its input (i.e. all array dimensions except the channel and minibatch dimensions); otherwise the connection wouldn't match. I'm not sure this matches the topology of every DenseNet out there, but I suppose it's a good starting point.
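
To make that shape constraint concrete, here's a small illustration; the layer and sizes are made up for the example, and the padding is what keeps the spatial dimensions intact:

```julia
using Flux

block = Conv((3, 3), 8 => 16, pad = 1)  # 3×3 conv; padding preserves the 32×32 spatial dims
x = rand(Float32, 32, 32, 8, 4)         # H × W × C × N input
y = block(x)                            # 32 × 32 × 16 × 4
z = cat(y, x, dims = 3)                 # 32 × 32 × 24 × 4: spatial and batch dims agree, so cat works
```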

No tests yet; I'll add them as the PR evolves.

I'm open to suggestions! :)

bhvieira (Author)

Oh, there's also a nomenclature issue. When people think of DenseBlocks they are mostly thinking about densely connected convolutional blocks. I suppose we could call this something else, like SkipConnection or ShortcutBlock. Thoughts?

MikeInnes (Member) commented Oct 24, 2018

This seems like a good implementation to me but I wonder if it isn't a bit specific. Maybe others can chime in if this seems like it'd be useful to them. Or perhaps there's some precedent in other frameworks?

I like the idea of calling it SkipConnection, and maybe generalising a bit by making the combiner customisable; e.g. it could be a plain + by default, with documentation on passing in concatenation instead (assuming + is the more common case).
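
A minimal sketch of that generalisation (the field names and the `+` default are assumptions for illustration here, not necessarily the design that gets merged):

```julia
# Sketch: the combiner is a field, with a plain + default.
struct SkipConnection
  layers        # the wrapped block
  connection    # combines the block's output with its input, e.g. + or a cat wrapper
end

SkipConnection(layers) = SkipConnection(layers, +)  # + as the default combiner

(s::SkipConnection)(x) = s.connection(s.layers(x), x)
```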

bhvieira (Author) commented Oct 24, 2018

> This seems like a good implementation to me but I wonder if it isn't a bit specific. Maybe others can chime in if this seems like it'd be useful to them. Or perhaps there's some precedent in other frameworks?

Skip connections and deep supervision are quite commonplace nowadays, but I take the point. I'm not sure about other frameworks; I'll investigate.

That said, Flux already has some domain-specific methods, and skip connections are everywhere in computer vision and segmentation in general.

bhvieira changed the title from "Added the DenseBlock layer and constructor" to "Added the SkipConnection layer and constructor" on Oct 24, 2018
r3tex (Contributor) commented Feb 25, 2019

It would be nice if this could be written so that new arrays don't have to be allocated on every forward pass. Models used for real-time inference would benefit.

MikeInnes (Member)

I'm on board with adding this, with the tests and docs you mentioned.

I'd like to remove the special casing of cat, though; probably better to just make that argument explicit.

bhvieira (Author) commented Feb 26, 2019

> I'd like to remove the special casing of cat, though; probably better to just make that argument explicit.

Oh yeah, you're completely right about that, @MikeInnes. When I started with Flux I had only just picked up Julia and wasn't that familiar with it at all; I can do better now. Would it be better to do it with multiple dispatch (still special-casing, but improved in my eyes) or with a simple keyword argument defaulting to connection = cat?

MikeInnes (Member)

I'm thinking just keep it as a positional argument but without a default, so you can do SkipConnection(m, (a, b) -> cat(a, b, 3)) (or define cat3 or something in your script to make it shorter). I think it's better to be explicit here than to have the meaning of SkipConnection be tied to how we define image models (which might change with autobatching, for example).
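
As a sketch of that usage (`cat3` is the hypothetical helper mentioned above; note that on Julia 1.0 `cat` takes a `dims` keyword rather than a positional dimension):

```julia
using Flux

cat3(a, b) = cat(a, b, dims = 3)  # concatenate along the channel dimension

block = Conv((3, 3), 16 => 16, pad = 1)
res   = SkipConnection(block, +)     # ResNet-style: element-wise sum
dense = SkipConnection(block, cat3)  # DenseNet-style: channel concatenation
```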

bhvieira (Author) commented Apr 9, 2019

> It would be nice if this could be written so that new arrays don't have to be allocated on every forward pass. Models used for real-time inference would benefit.

@r3tex Thanks for the suggestion! My Julia-fu is weak; is the new, streamlined version enough? I could look into types as well, I guess.

Here's proof that gradients propagate correctly:

```julia
using Flux
using Flux.Tracker

m = Chain(Dense(1, 10), Dense(10, 1))
M = SkipConnection(m, (A, B) -> A .+ B)

x = randn(1, 20)
y = ones(1, 20)
loss(x, y) = sum((M(x) .- y).^2)

grads = Tracker.gradient(() -> loss(x, y), params(M))
grads[M.layers[1].W]  # nonzero gradient for the first Dense layer's weights
grads[M.layers[1].b]  # and for its bias
```

bhvieira (Author) commented Apr 9, 2019

Also, is the current show method okay? I haven't invested much time in it, actually.

bhvieira (Author)

Added some tests. They're still superficial; I could add more if asked. Is it okay, @MikeInnes?

MikeInnes (Member)

This is looking nice! A couple of small things: remove the Function annotation, since not all callables are functions (you might want to use a Flux model here, for example). SkipConnection should also be listed in the layers docs somewhere. Also, it would be great to have a news item for this.
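
A quick illustration of the annotation point (a sketch, not code from the PR): a Flux layer is callable but is not a subtype of Function, so a `connection::Function` field would reject it.

```julia
using Flux

isa(+, Function)              # true
isa(Dense(10, 10), Function)  # false: Dense is callable, but not a Function
# Leaving the field untyped (connection rather than connection::Function)
# keeps both plain functions and Flux models usable as the connection.
```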

Commits:

- Added missing export
- Corrected channel placement: dimension 4 cannot be assumed to always be the channel dimension
- Deprecation of `treelike`: the code now uses the `@treelike` macro instead of the deprecated `treelike` function (it worked on my end because I'm on Julia 0.7, while Julia 1.0 deprecated it)
- Update basic.jl: renaming to SkipConnection (updates Flux.jl and basic.jl)
- Updated `SkipConnection` with a `connection` field: I'm pretty sure I broke something now, but this PR should follow along these lines. `cat` needs special treatment (the user can declare their own `concatenate` connection, but I foresee it will be used often, so we can simply define special treatment for it)
- Forgot to remove some rebasing text
- Forgot to remove some more rebasing text
- Removed local copy and default cat method from the function calls
- Adjusted some more types for inference; could improve on this as well
- Re-placed some left-over spaces
bhvieira (Author)

@MikeInnes I added a line to docs/src/models/layers.md, under ## Other General Purpose Layers. Is that what you meant?

MikeInnes (Member)

Looks perfect, thanks @bhvieira!

bors r+

bors bot added a commit that referenced this pull request on Jun 5, 2019:

446: Added the SkipConnection layer and constructor r=MikeInnes a=bhvieira

Co-authored-by: Bruno Hebling Vieira <bruno.hebling.vieira@usp.br>
Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
bors bot commented Jun 5, 2019

Build succeeded

bors bot merged commit b980758 into FluxML:master on Jun 5, 2019
bhvieira deleted the DenseBlock branch on June 6, 2019