allow unstable Chain{Vector} too
mcabbott committed Jan 11, 2022
1 parent 657e267 commit 4bc55de
Showing 3 changed files with 36 additions and 3 deletions.
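For context, a minimal usage sketch of what this commit enables (the layer sizes are illustrative, not taken from the diff):

```julia
using Flux

layers = [Dense(10, 5, relu), Dense(5, 2)]

m_tuple = Chain(layers...)   # existing type-stable path: Tuple of layers
m_vec   = Chain(layers)      # new path from this commit: Vector of layers

x = randn(Float32, 10)
m_tuple(x) ≈ m_vec(x)        # true: same layers, so the forward pass agrees
```

Splatting the same vector into the old constructor keeps the comparison meaningful; independently constructed layers would have different random weights.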
13 changes: 13 additions & 0 deletions src/layers/basic.jl
@@ -27,6 +27,10 @@ julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)),
julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
```
For large models, there is a special type-unstable path which can reduce compilation
times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
This feature is somewhat experimental, beware!
"""
struct Chain{T}
layers::T
@@ -36,6 +40,7 @@ struct Chain{T}
isempty(kw) && return new{Tuple{}}(())
new{typeof(values(kw))}(values(kw))
end
Chain(xs::AbstractVector) = new{typeof(xs)}(xs) # unstable path, to help compile times
end

@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
@@ -53,6 +58,13 @@ end

applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x)

function applychain(layers::AbstractVector, x)
for f in layers
x = f(x)
end
x
end

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
@@ -64,6 +76,7 @@ function Base.show(io::IO, c::Chain)
end
_show_layers(io, layers::Tuple) = join(io, layers, ", ")
_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")
_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]"))

# This is a temporary and naive implementation
# it might be replaced in the future for better performance
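A standalone sketch (not Flux source) of the trade-off behind the new `applychain(::AbstractVector, x)` loop: a `Tuple` of layers carries one concrete type per layer combination, so Julia compiles a specialised method for each chain, while a `Vector` erases the element types, so one compiled loop serves every chain at the cost of dynamic dispatch per layer:

```julia
# Illustrative only; `applyall` mirrors the loop added in basic.jl above.
applyall(layers, x) = (for f in layers; x = f(x); end; x)

tuple_layers  = (sin, cos, exp)          # Tuple{typeof(sin), typeof(cos), typeof(exp)}
vector_layers = Function[sin, cos, exp]  # Vector{Function}: element types erased

applyall(tuple_layers, 0.5)   # specialised per combination, fully inferred
applyall(vector_layers, 0.5)  # one generic method, dispatches dynamically per call
```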
7 changes: 4 additions & 3 deletions src/layers/show.jl
@@ -14,11 +14,12 @@ for T in [
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
pre, post = obj isa Chain{<:AbstractVector} ? ("[", "]") : ("", "")
children = trainable(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent, name)
else
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(")
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(", pre)
if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
# then we insert names -- can this be done more generically?
for k in Base.keys(obj)
@@ -35,10 +36,10 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
end
end
if indent == 0 # i.e. this is the outermost container
print(io, ")")
print(io, post, ")")
_big_finale(io, obj)
else
println(io, " "^indent, "),")
println(io, " "^indent, post, "),")
end
end
end
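An illustrative REPL sketch of the intended show change (per-layer parameter annotations elided): the square brackets from `pre`/`post` mark the chain as the vector-backed, type-unstable kind:

```julia
julia> Chain([Dense(10, 5, tanh), Dense(5, 2)])
Chain([
  Dense(10, 5, tanh),
  Dense(5, 2),
])
```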
19 changes: 19 additions & 0 deletions test/layers/basic.jl
@@ -26,6 +26,8 @@ import Flux: activations
@test m[1:2] == m

@test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name

@test_nowarn Chain([Dense(10, 5, σ), Dense(5, 2)])(randn(Float32, 10)) # vector of layers
end

@testset "Activations" begin
@@ -274,6 +276,10 @@ end
m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
@test Zygote.hessian_dual(sum∘m1, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1, [1,2,3])

m1v = Chain([m1[1], m1[2]]) # vector of layers
@test Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_dual(sum∘m1, [1,2,3])
@test_broken Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1v, [1,2,3])

# NNlib's softmax gradient writes in-place
m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax)
@test_broken Zygote.hessian_dual(sum∘m2, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m2, [1,2,3])
@@ -284,3 +290,16 @@ end
@test_broken Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3)
end

@testset "gradients of Chain{Vector}" begin
m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
m1v = Chain([m1[1], m1[2]])
@test sum(length, params(m1)) == sum(length, params(m1v))

x1 = randn(Float32,3,5)
@test m1(x1) ≈ m1v(x1)
y1 = rand(Bool,2,5)
g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1))
g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v))
@test g1[m1[1].weight] ≈ g1v[m1v[1].weight]
@test g1[m1[2].bias] ≈ g1v[m1v[2].bias]
end
