-
-
Notifications
You must be signed in to change notification settings - Fork 613
Replace unrolled `foldl` used to evaluate `Chain` with a better one
#1809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6918b0e
00ba124
a9bbb0c
585043d
fcb09da
d99d7ac
4dfd551
f60da1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,8 +27,12 @@ julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)), | |
julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x) | ||
true | ||
``` | ||
|
||
For large models, there is a special type-unstable path which can reduce compilation | ||
times. To use it, supply the layers as a vector: `Chain([layer1, layer2, ...])`. | ||
This feature is somewhat experimental, beware! | ||
""" | ||
struct Chain{T<:Union{Tuple, NamedTuple}} | ||
struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}} | ||
layers::T | ||
end | ||
|
||
|
@@ -44,10 +48,22 @@ end | |
|
||
@functor Chain | ||
|
||
# Base case: with no layers left, the input is returned unchanged. | ||
applychain(::Tuple{}, x) = x | ||
# Recursive case: apply the first layer to `x`, then recurse on `tail(fs)` with that result. | ||
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) | ||
# Calling a Chain passes its `layers` field and the input `x` to `applychain`. | ||
(c::Chain)(x) = applychain(c.layers, x) | ||
|
||
# Apply the `N` layers of a tuple to `x` in order, unrolled at compile time.
#
# For a tuple of length `N` the generated body is a flat block of `N` assignments
#
#     y1 = layers[1](x); y2 = layers[2](y1); ...; yN = layers[N](y{N-1})
#
# followed by the last bound symbol, so the block evaluates to the output of the final
# layer. For an empty tuple that symbol is `x` itself, so the input is returned unchanged
# (a bare empty block would evaluate to `nothing`).
#
# The signature uses `Vararg{Any,N}` rather than `Vararg{<:Any,N}`: the latter wraps
# `Vararg` in a `UnionAll`, which newer Julia versions deprecate.
@generated function applychain(layers::Tuple{Vararg{Any,N}}, x) where {N}
    # symbols[1] is the input `x`; symbols[i + 1] names the output of layer `i`.
    symbols = vcat(:x, [gensym() for _ in 1:N])
    calls = [:($(symbols[i + 1]) = layers[$i]($(symbols[i]))) for i in 1:N]
    # Ending the block with the last symbol makes the result explicit and covers N == 0.
    return Expr(:block, calls..., last(symbols))
end
|
||
# NamedTuple layers are converted with `Tuple(layers)` and passed back to `applychain`. | ||
applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x) | ||
|
||
# Calling a Chain converts its `layers` with `Tuple(...)` before passing them to `applychain`. | ||
# NOTE(review): this page also shows `(c::Chain)(x) = applychain(c.layers, x)`; the two are | ||
# presumably the removed and added sides of the diff -- confirm which one is current. | ||
(c::Chain)(x) = applychain(Tuple(c.layers), x) | ||
# Apply each element of `layers` to `x` in iteration order, feeding every result into the | ||
# next layer, and return the final value. An empty vector returns `x` unchanged. | ||
function applychain(layers::AbstractVector, x) # type-unstable path, helps compile times | ||
for f in layers | ||
# rebind `x` to this layer's output so the next layer receives it | ||
x = f(x) | ||
end | ||
x | ||
end | ||
|
||
# Indexing with an array of indices returns a new Chain built from `c.layers[i]`. | ||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]) | ||
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = | ||
|
@@ -60,6 +76,7 @@ function Base.show(io::IO, c::Chain) | |
end | ||
# Tuple layers: write the layers to `io`, comma-separated. | ||
_show_layers(io, layers::Tuple) = join(io, layers, ", ") | ||
# NamedTuple layers: write each entry as `key = value`, comma-separated. | ||
_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ") | ||
# Vector layers: same as the Tuple case, wrapped in square brackets. | ||
_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]")) | ||
|
||
# This is a temporary and naive implementation | ||
# it might be replaced in the future for better performance | ||
Comment on lines 81 to 82
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. BTW, in addition to a hand-written [...] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I thought [...] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Indeed, you could likewise do [...] |
||
|
Uh oh!
There was an error while loading. Please reload this page.