Testing new apply interface for Flux.Chain (#5)
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
mkschleg and ToucheSir authored Mar 26, 2023
1 parent d917e17 commit 19a9205
Showing 4 changed files with 157 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/Fluxperimental.jl
@@ -8,6 +8,9 @@ export Split, Join
include("train.jl")
export shinkansen!


include("chain.jl")

include("compact.jl")

end # module Fluxperimental
62 changes: 62 additions & 0 deletions src/chain.jl
@@ -0,0 +1,62 @@

import Flux: ChainRulesCore
# Some experiments with Chain, working toward removing the need for Recur to be mutable,
# as per the conversation in the recurrent-network rework issue.

# The main difference between this and Flux's internal _applychain is that we return a new chain,
# with its internal state updated, as well as the output of applying x to the chain.
function apply(chain::Flux.Chain, x)
layers, out = _apply(chain.layers, x)
Flux.Chain(layers), out
end
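# A minimal usage sketch (illustrative only, mirroring the tests in test/chain.jl):
#
#   d1, d2 = Flux.Dense(3, 4), Flux.Dense(4, 1)
#   new_chain, y = apply(Flux.Chain(d1, d2), rand(Float32, 3, 1))
#
# For stateless layers such as Dense, new_chain holds the same layer objects and y matches the
# plain chain call; a stateful layer would instead come back inside new_chain as an updated copy.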

function _apply(layers::NamedTuple{NMS, TPS}, x) where {NMS, TPS}
layers, out = _apply(Tuple(layers), x)
NamedTuple{NMS}(layers), out
end
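# Note: for a named Chain the layer names are preserved; the NamedTuple is unpacked into a plain
# Tuple for the recursion above and rebuilt with the same names NMS on the way out.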

function _scan(layers::AbstractVector, x)
new_layers = typeof(layers)(undef, length(layers))
for (idx, f) in enumerate(layers)
new_layers[idx], x = _apply(f, x)
end
new_layers, x
end

# Reverse rule for _scan
# example pulled from https://github.com/mcabbott/Flux.jl/blob/chain_rrule/src/cuda/cuda.jl
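# Forward sweep: for each layer, store the primal result (updated_layer, output) together with
# the pullback returned by rrule_via_ad. The custom pullback below then replays those stored
# pullbacks in reverse order, threading the incoming gradient back through each layer and
# collecting the per-layer gradients.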
function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig, ::typeof(_scan), layers, x)
duo = accumulate(layers; init=((nothing, x), nothing)) do ((pl, input), _), cur_layer
out, back = ChainRulesCore.rrule_via_ad(cfg, _apply, cur_layer, input)
end
outs = map(first, duo)
backs = map(last, duo)

function _scan_pullback(dy)
multi = accumulate(reverse(backs); init=(nothing, dy)) do (_, delta), back
dapply, dlayer, din = back(delta)
return dapply, (dlayer, din)
end
layergrads = reverse(map(first, multi))
xgrad = last(multi[end])
return (ChainRulesCore.NoTangent(), layergrads, xgrad)
end
return (map(first, outs), last(outs[end])), _scan_pullback
end

function _apply(layers::AbstractVector, x) # type-unstable path, helps compile times
_scan(layers, x)
end

# Generated function: returns a tuple of the (possibly updated) layers and the final output of the network.
@generated function _apply(layers::Tuple{Vararg{<:Any,N}}, x) where {N}
x_symbols = vcat(:x, [gensym() for _ in 1:N])
l_symbols = [gensym() for _ in 1:N]
calls = [:(($(l_symbols[i]), $(x_symbols[i+1])) = _apply(layers[$i], $(x_symbols[i]))) for i in 1:N]
push!(calls, :(return tuple($(l_symbols...)), $(x_symbols[end])))
Expr(:block, calls...)
end
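# For N = 2 the generated body is roughly (gensym'd names written out for readability):
#
#   (l1_new, x1) = _apply(layers[1], x)
#   (l2_new, x2) = _apply(layers[2], x1)
#   return tuple(l1_new, l2_new), x2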

# Base case for a single layer: a stateless layer is returned unchanged together with its output.
_apply(layer, x) = layer, layer(x)


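For context, a small end-to-end sketch of the interface added above (the layer sizes and the sum loss are arbitrary choices for illustration, mirroring the tests added below):

using Flux, Fluxperimental

x  = rand(Float32, 3, 1)
l1 = Flux.Dense(3, 4)
l2 = Flux.Dense(4, 1)
chain = Flux.Chain(l1, l2)

# Forward pass: returns the (possibly updated) chain alongside the output.
new_chain, out = Fluxperimental.apply(chain, x)

# Gradients taken through apply should agree with gradients through the plain chain call.
grads = Flux.gradient(Flux.params(chain)) do
    sum(Fluxperimental.apply(chain, x)[end])
end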
88 changes: 88 additions & 0 deletions test/chain.jl
@@ -0,0 +1,88 @@
# Check whether two gradient structures are equal. Simplifies the tests below.
function _grads_equal(grads1, grads2)
if length(keys(grads1)) != length(keys(grads2))
return false
end
ret = true
for weights in keys(grads1)
if grads1[weights] isa AbstractArray
ret = ret && all(grads1[weights] .== grads2[weights])
elseif isnothing(grads1[weights])
ret = ret && isnothing(grads2[weights])
else
error("Grad returned unexpected type $(typeof(grads1[weights]))")
end
end
return ret
end

@testset "Applying the Chain!" begin
@testset "Forward pass" begin
x = rand(Float32, 3, 1)
l1 = Flux.Dense(3, 4)
l2 = Flux.Dense(4, 1)
truth = l2(l1(x))

t_c = Flux.Chain(l1, l2) # tuple Chain
new_t_c, out = Fluxperimental.apply(t_c, x)
@test new_t_c[1] === l1 && new_t_c[2] === l2
@test all(out .== truth)


nt_c = Flux.Chain(l1=l1, l2=l2) # namedtuple Chain
new_nt_c, out = Fluxperimental.apply(nt_c, x)
@test new_nt_c[:l1] === l1 && new_nt_c[:l2] === l2
@test all(out .== truth)


v_c = Flux.Chain([l1, l2]) # vector Chain
new_v_c, out = Fluxperimental.apply(v_c, x)
@test new_v_c.layers[1] === l1 && new_v_c.layers[2] === l2
@test all(out .== truth)
end # @testset "Forward pass"

@testset "Backward pass" begin
x = rand(Float32, 3, 1)
l1 = Flux.Dense(3, 4)
l2 = Flux.Dense(4, 1)

@test begin # Test Tuple Chain Gradients
t_c = Flux.Chain(l1, l2) # tuple Chain
grads_truth = Flux.gradient(Flux.params(t_c)) do
sum(t_c(x))
end

grads_tuple = Flux.gradient(Flux.params(t_c)) do
sum(Fluxperimental.apply(t_c, x)[end])
end

_grads_equal(grads_tuple, grads_truth)
end

@test begin # Test Named Tuple's Gradients
nt_c = Flux.Chain(l1=l1, l2=l2) # named tuple Chain
grads_truth = Flux.gradient(Flux.params(nt_c)) do
sum(nt_c(x))
end

grads_tuple = Flux.gradient(Flux.params(nt_c)) do
sum(Fluxperimental.apply(nt_c, x)[end])
end

_grads_equal(grads_tuple, grads_truth)
end

@test begin # Test vector Chain gradients
c = Flux.Chain([l1, l2]) # vector Chain
grads_truth = Flux.gradient(Flux.params(c)) do
sum(c(x))
end

grads_tuple = Flux.gradient(Flux.params(c)) do
sum(Fluxperimental.apply(c, x)[end])
end

_grads_equal(grads_tuple, grads_truth)
end
end # @testset "Backward pass"
end # @testset "Applying the Chain!"
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -3,5 +3,9 @@ using Flux, Fluxperimental

@testset "Fluxperimental.jl" begin
include("split_join.jl")

include("chain.jl")

include("compact.jl")

end
