diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index f51f8911aa..f5058a30f9 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -27,6 +27,10 @@ julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)),
 julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
 true
 ```
+
+For large models, there is a special type-unstable path which can reduce compilation
+times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
+This feature is somewhat experimental, beware!
 """
 struct Chain{T}
   layers::T
@@ -36,6 +40,7 @@ struct Chain{T}
     isempty(kw) && return new{Tuple{}}(())
     new{typeof(values(kw))}(values(kw))
   end
+  Chain(xs::AbstractVector) = new{typeof(xs)}(xs)  # unstable path, to help compile times
 end
 
 @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
@@ -53,6 +58,13 @@ end
 
 applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x)
 
+function applychain(layers::AbstractVector, x)
+  for f in layers
+    x = f(x)
+  end
+  x
+end
+
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
 Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
   Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
@@ -64,6 +76,7 @@ function Base.show(io::IO, c::Chain)
 end
 _show_layers(io, layers::Tuple) = join(io, layers, ", ")
 _show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")
+_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]"))
 
 # This is a temporary and naive implementation
 # it might be replaced in the future for better performance
diff --git a/src/layers/show.jl b/src/layers/show.jl
index 791d2511ca..819ef82d75 100644
--- a/src/layers/show.jl
+++ b/src/layers/show.jl
@@ -14,11 +14,12 @@ for T in [
 end
 
 function _big_show(io::IO, obj, indent::Int=0, name=nothing)
+  pre, post = obj isa Chain{<:AbstractVector} ? ("[", "]") : ("", "")
   children = trainable(obj)
   if all(_show_leaflike, children)
     _layer_show(io, obj, indent, name)
   else
-    println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(")
+    println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(", pre)
     if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
       # then we insert names -- can this be done more generically?
       for k in Base.keys(obj)
@@ -35,10 +36,10 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
       end
     end
     if indent == 0  # i.e. this is the outermost container
-      print(io, ")")
+      print(io, post, ")")
       _big_finale(io, obj)
     else
-      println(io, " "^indent, "),")
+      println(io, " "^indent, post, "),")
     end
   end
 end
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 0b9f340142..128bcb5367 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -26,6 +26,8 @@ import Flux: activations
     @test m[1:2] == m
 
     @test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity)  # reserved name
+
+    @test_nowarn Chain([Dense(10, 5, σ), Dense(5, 2)])(randn(Float32, 10))  # vector of layers
   end
 
   @testset "Activations" begin
@@ -274,6 +276,10 @@
   m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
   @test Zygote.hessian_dual(sum∘m1, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1, [1,2,3])
 
+  m1v = Chain([m1[1], m1[2]])  # vector of layers
+  @test Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_dual(sum∘m1, [1,2,3])
+  @test_broken Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1v, [1,2,3])
+
   # NNlib's softmax gradient writes in-place
   m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax)
   @test_broken Zygote.hessian_dual(sum∘m2, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m2, [1,2,3])
@@ -284,3 +290,16 @@
   @test_broken Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3)
 end
 
+@testset "gradients of Chain{Vector}" begin
+  m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
+  m1v = Chain([m1[1], m1[2]])
+  @test sum(length, params(m1)) == sum(length, params(m1v))
+
+  x1 = randn(Float32,3,5)
+  @test m1(x1) ≈ m1v(x1)
+  y1 = rand(Bool,2,5)
+  g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1))
+  g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v))
+  @test g1[m1[1].weight] ≈ g1v[m1v[1].weight]
+  @test g1[m1[2].bias] ≈ g1v[m1v[2].bias]
+end
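
Not part of the patch, but a minimal usage sketch of the new vector-of-layers path for anyone reviewing. It assumes this diff is applied on top of Flux; the layer width and count are arbitrary illustrations.

```julia
using Flux

# Same layer objects, two storage strategies.
layers  = [Dense(10, 10, relu) for _ in 1:50]
m_tuple = Chain(layers...)  # existing path: Chain{<:Tuple}, specializes on a 50-tuple of layer types
m_vec   = Chain(layers)     # new path: Chain{<:Vector}, runs the fallback loop in applychain

x = randn(Float32, 10, 16)
@assert m_vec(x) ≈ m_tuple(x)  # identical forward results, different dispatch strategy

# Gradients also flow through the vector fallback:
gs = gradient(() -> sum(m_vec(x)), Flux.params(m_vec))
```

The trade-off is deliberate: the vector path avoids specializing `applychain` on an N-tuple of distinct layer types, which is where the compile-time saving comes from, while with heterogeneous layers the loop becomes type-unstable at runtime, hence the "beware!" in the docstring.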