diff --git a/src/Fluxperimental.jl b/src/Fluxperimental.jl
index 948d5ed..91438b0 100644
--- a/src/Fluxperimental.jl
+++ b/src/Fluxperimental.jl
@@ -8,6 +8,9 @@ export Split, Join
 include("train.jl")
 export shinkansen!
 
+
+include("chain.jl")
+
 include("compact.jl")
 
 end # module Fluxperimental
diff --git a/src/chain.jl b/src/chain.jl
new file mode 100644
index 0000000..0095de3
--- /dev/null
+++ b/src/chain.jl
@@ -0,0 +1,62 @@
+
+import Flux: ChainRulesCore
+# Some experiments with Chain to start removing the need for Recur to be mutable,
+# as per the conversation in the recurrent-network rework issue.
+
+# The main difference between this and Flux's _applychain is that we return a new Chain
+# (with any internal layer state updated) as well as the output of applying x to the chain.
+function apply(chain::Flux.Chain, x)
+  layers, out = _apply(chain.layers, x)
+  Flux.Chain(layers), out
+end
+
+function _apply(layers::NamedTuple{NMS, TPS}, x) where {NMS, TPS}
+  layers, out = _apply(Tuple(layers), x)
+  NamedTuple{NMS}(layers), out
+end
+
+function _scan(layers::AbstractVector, x)
+  new_layers = typeof(layers)(undef, length(layers))
+  for (idx, f) in enumerate(layers)
+    new_layers[idx], x = _apply(f, x)
+  end
+  new_layers, x
+end
+
+# Reverse rule for _scan.
+# Example pulled from https://github.com/mcabbott/Flux.jl/blob/chain_rrule/src/cuda/cuda.jl
+function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig, ::typeof(_scan), layers, x)
+  duo = accumulate(layers; init=((nothing, x), nothing)) do ((pl, input), _), cur_layer
+    out, back = ChainRulesCore.rrule_via_ad(cfg, _apply, cur_layer, input)
+  end
+  outs = map(first, duo)
+  backs = map(last, duo)
+
+  function _scan_pullback(dy)
+    multi = accumulate(reverse(backs); init=(nothing, dy)) do (_, delta), back
+      dapply, dlayer, din = back(delta)
+      return dapply, (dlayer, din)
+    end
+    layergrads = reverse(map(first, multi))
+    xgrad = last(multi[end])
+    return (ChainRulesCore.NoTangent(), layergrads, xgrad)
+  end
+  return (map(first, outs), last(outs[end])), _scan_pullback
+end
+
+function _apply(layers::AbstractVector, x)  # type-unstable path, helps compile times
+  _scan(layers, x)
+end
+
+# Generated function; returns a tuple of the updated layers and the final output of the network.
+@generated function _apply(layers::Tuple{Vararg{<:Any,N}}, x) where {N}
+  x_symbols = vcat(:x, [gensym() for _ in 1:N])
+  l_symbols = [gensym() for _ in 1:N]
+  calls = [:(($(l_symbols[i]), $(x_symbols[i+1])) = _apply(layers[$i], $(x_symbols[i]))) for i in 1:N]
+  push!(calls, :(return tuple($(l_symbols...)), $(x_symbols[end])))
+  Expr(:block, calls...)
+end
+
+_apply(layer, x) = layer, layer(x)
+
+
diff --git a/test/chain.jl b/test/chain.jl
new file mode 100644
index 0000000..073173b
--- /dev/null
+++ b/test/chain.jl
@@ -0,0 +1,88 @@
+# Checks whether two grad structures are equal. Simplifies the tests below.
+function _grads_equal(grads1, grads2)
+  if length(keys(grads1)) != length(keys(grads2))
+    return false
+  end
+  ret = true
+  for weights in keys(grads1)
+    if grads1[weights] isa AbstractArray
+      ret = ret && all(grads1[weights] .== grads2[weights])
+    elseif isnothing(grads1[weights])
+      ret = ret && isnothing(grads2[weights])
+    else
+      throw("Grad returned type $(typeof(grads1[weights]))")
+    end
+  end
+  return ret
+end
+
+@testset "Applying the Chain!" begin
+  @testset "Forward pass" begin
+    x = rand(Float32, 3, 1)
+    l1 = Flux.Dense(3, 4)
+    l2 = Flux.Dense(4, 1)
+    truth = l2(l1(x))
+
+    t_c = Flux.Chain(l1, l2) # tuple Chain
+    new_t_c, out = Fluxperimental.apply(t_c, x)
+    @test new_t_c[1] === l1 && new_t_c[2] === l2
+    @test all(out .== truth)
+
+    nt_c = Flux.Chain(l1=l1, l2=l2) # namedtuple Chain
+    new_nt_c, out = Fluxperimental.apply(nt_c, x)
+    @test new_nt_c[:l1] === l1 && new_nt_c[:l2] === l2
+    @test all(out .== truth)
+
+    v_c = Flux.Chain([l1, l2]) # vector Chain
+    new_v_c, out = Fluxperimental.apply(v_c, x)
+    @test new_v_c.layers[1] === l1 && new_v_c.layers[2] === l2
+    @test all(out .== truth)
+  end # @testset "Forward pass"
+
+  @testset "Backward pass" begin
+    x = rand(Float32, 3, 1)
+    l1 = Flux.Dense(3, 4)
+    l2 = Flux.Dense(4, 1)
+
+    @test begin # Test tuple Chain gradients
+      t_c = Flux.Chain(l1, l2) # tuple Chain
+      grads_truth = Flux.gradient(Flux.params(t_c)) do
+        sum(t_c(x))
+      end
+
+      grads_tuple = Flux.gradient(Flux.params(t_c)) do
+        sum(Fluxperimental.apply(t_c, x)[end])
+      end
+
+      _grads_equal(grads_tuple, grads_truth)
+    end
+
+    @test begin # Test namedtuple Chain gradients
+      nt_c = Flux.Chain(l1=l1, l2=l2) # namedtuple Chain
+      grads_truth = Flux.gradient(Flux.params(nt_c)) do
+        sum(nt_c(x))
+      end
+
+      grads_tuple = Flux.gradient(Flux.params(nt_c)) do
+        sum(Fluxperimental.apply(nt_c, x)[end])
+      end
+
+      _grads_equal(grads_tuple, grads_truth)
+    end
+
+    @test begin # Test vector Chain gradients
+      c = Flux.Chain([l1, l2]) # vector Chain
+      grads_truth = Flux.gradient(Flux.params(c)) do
+        sum(c(x))
+      end
+
+      grads_tuple = Flux.gradient(Flux.params(c)) do
+        sum(Fluxperimental.apply(c, x)[end])
+      end
+
+      _grads_equal(grads_tuple, grads_truth)
+    end
+  end # @testset "Backward pass"
+end # @testset "Applying the Chain!"
diff --git a/test/runtests.jl b/test/runtests.jl
index 7c804c4..55315cc 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -3,5 +3,9 @@ using Flux, Fluxperimental
 
 @testset "Fluxperimental.jl" begin
   include("split_join.jl")
+
+  include("chain.jl")
+
   include("compact.jl")
+
 end
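
For context (not part of the diff), a minimal usage sketch of the new apply interface, following the call pattern used in the tests above. The chain, layer sizes, and input shape here are illustrative assumptions; for stateless layers such as Dense the returned chain simply reuses the original layers, as the forward-pass tests assert.

using Flux, Fluxperimental

x = rand(Float32, 3, 1)
c = Flux.Chain(Flux.Dense(3, 4), Flux.Dense(4, 1))

# apply returns a new Chain (whose layers would carry any updated internal state)
# together with the output, instead of mutating state in place.
c1, y1 = Fluxperimental.apply(c, x)
c2, y2 = Fluxperimental.apply(c1, x)  # thread the returned chain into the next call

# Gradients flow through apply, matching those of calling the chain directly.
grads = Flux.gradient(Flux.params(c)) do
  sum(Fluxperimental.apply(c, x)[end])
end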