Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing new apply interface for Flux.Chain #5

Merged
merged 8 commits
Mar 26, 2023
3 changes: 3 additions & 0 deletions src/Fluxperimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ export Split, Join
include("train.jl")
export shinkansen!


include("chain.jl")

include("compact.jl")

end # module Fluxperimental
49 changes: 49 additions & 0 deletions src/chain.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@

import Flux: ChainRulesCore
# Some experiments with chain to start removing the need for recur to be mutable.
# As per the conversation in the recurrent network rework issue.

# Main difference between this and the _applychain function is we return a new chain
# with the internal state modified as well as the output of applying x to the chain.
# Apply `chain` to the input `x`, returning BOTH a rebuilt chain and the
# output.  Unlike the internal `_applychain`, the returned chain carries any
# internal state the layers produced, removing the need for mutable layers
# (e.g. a mutable `Recur`).
function apply(chain::Flux.Chain, x)
    new_layers, y = _apply(chain.layers, x)
    return Flux.Chain(new_layers), y
end

_apply(layers::NamedTuple{NMS, TPS}, x) where {NMS, TPS} = begin
mkschleg marked this conversation as resolved.
Show resolved Hide resolved
layers, out = _apply(Tuple(layers), x)
NamedTuple{NMS}(layers), out
end

# Fold the input `x` through a vector of layers, collecting the (possibly
# state-updated) layer returned by `_apply_to_layer` at each step.
# Returns `(new_layers, output)`.
function _scan(layers::AbstractVector, x)
    # `similar` is the idiomatic way to get a fresh same-type/same-axes
    # container; `typeof(layers)(undef, length(layers))` assumed the vector
    # type had an `(::UndefInitializer, ::Int)` constructor and 1-based
    # indexing, which not every `AbstractVector` provides.
    new_layers = similar(layers)
    for i in eachindex(layers)
        new_layers[i], x = _apply_to_layer(layers[i], x)
    end
    return new_layers, x
end

# `rrule` for `_scan`: the forward pass runs normally, but the pullback is
# not implemented yet, so differentiating through the vector-of-layers path
# raises an error instead of silently returning wrong gradients.
function ChainRulesCore.rrule(::typeof(_scan), layers, x)
    function _scan_pullback(dy)
        # `error` throws an `ErrorException`; the original `throw("…")`
        # threw a bare `String`, which is not an `Exception` and confuses
        # downstream `catch` handlers and error reporting.
        error("_scan Pullback not implemented")
    end
    return _scan(layers, x), _scan_pullback
end

# Vector-backed chains take the deliberately type-unstable `_scan` path,
# trading inference quality for faster compile times on long chains.
_apply(layers::AbstractVector, x) = _scan(layers, x)

# Unrolled application over a `Tuple` of layers: the generated function emits
# one `_apply_to_layer` call per layer, so the whole pass stays type-stable.
# Returns `(new_layers::Tuple, output)`.
# NOTE: `Vararg{Any,N}` replaces the original `Vararg{<:Any,N}` — wrapping
# `Vararg` directly in a `UnionAll` is deprecated on modern Julia and the two
# spellings select the same methods here.
@generated function _apply(layers::Tuple{Vararg{Any,N}}, x) where {N}
  # Chain the intermediate outputs through gensym'd variables:
  # x -> x1 -> x2 -> … and capture each (possibly updated) layer.
  x_symbols = vcat(:x, [gensym() for _ in 1:N])
  l_symbols = [gensym() for _ in 1:N]
  calls = [:(($(l_symbols[i]), $(x_symbols[i+1])) = _apply_to_layer(layers[$i], $(x_symbols[i]))) for i in 1:N]
  push!(calls, :(return tuple($(l_symbols...)), $(x_symbols[end])))
  return Expr(:block, calls...)
end

"""
    _apply_to_layer(layer, x)

Run `layer` on `x` and return `(layer, layer(x))`.  Returning the layer
alongside the output lets stateful layer types override this to hand back an
updated layer instead.
"""
function _apply_to_layer(layer, x)
    y = layer(x)
    return layer, y
end


85 changes: 85 additions & 0 deletions test/chain.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@

import Flux, Fluxperimental

mkschleg marked this conversation as resolved.
Show resolved Hide resolved
# Compare two gradient collections key-by-key.  Returns `true` iff both have
# the same number of keys and, for every key, the entries are either
# elementwise-equal arrays or `nothing` in both.
# Throws `ArgumentError` for any other gradient entry type.
function _grads_equal(grads1, grads2)
    if length(keys(grads1)) != length(keys(grads2))
        return false
    end
    for weights in keys(grads1)
        g1 = grads1[weights]
        if g1 isa AbstractArray
            all(g1 .== grads2[weights]) || return false
        elseif isnothing(g1)
            isnothing(grads2[weights]) || return false
        else
            # A proper exception type; the original threw a bare `String`.
            throw(ArgumentError("Grad returned type $(typeof(g1))"))
        end
    end
    return true
end

@testset "Applying the Chain!" begin

  @testset "Forward pass" begin
    x = rand(Float32, 3, 1)
    l1 = Flux.Dense(3, 4)
    l2 = Flux.Dense(4, 1)
    truth = l2(l1(x))

    # Tuple-backed chain: the layers must come back identical (`===`)
    # and the output must match plain chain application.
    tuple_chain = Flux.Chain(l1, l2)
    new_tuple_chain, out = Fluxperimental.apply(tuple_chain, x)
    @test new_tuple_chain[1] === l1 && new_tuple_chain[2] === l2
    @test all(out .== truth)

    # NamedTuple-backed chain: same checks, accessed by layer name.
    named_chain = Flux.Chain(l1=l1, l2=l2)
    new_named_chain, out = Fluxperimental.apply(named_chain, x)
    @test new_named_chain[:l1] === l1 && new_named_chain[:l2] === l2
    @test all(out .== truth)

    # Vector-backed chain: goes through the `_scan` path.
    vector_chain = Flux.Chain([l1, l2])
    new_vector_chain, out = Fluxperimental.apply(vector_chain, x)
    @test new_vector_chain.layers[1] === l1 && new_vector_chain.layers[2] === l2
    @test all(out .== truth)
  end

  @testset "Backward pass" begin
    x = rand(Float32, 3, 1)
    l1 = Flux.Dense(3, 4)
    l2 = Flux.Dense(4, 1)

    # Gradients through `apply` must agree with gradients through the
    # plain chain call for the tuple-backed chain.
    @test begin
      tuple_chain = Flux.Chain(l1, l2)
      grads_truth = Flux.gradient(Flux.params(tuple_chain)) do
        sum(tuple_chain(x))
      end

      grads_tuple = Flux.gradient(Flux.params(tuple_chain)) do
        sum(Fluxperimental.apply(tuple_chain, x)[end])
      end

      _grads_equal(grads_tuple, grads_truth)
    end

    # NOTE(review): the vector-chain gradient test is disabled because the
    # `_scan` pullback is not implemented yet; re-enable once it is.
    # @test begin
    #   v_c = Flux.Chain([l1, l2]) # vector Chain
    #   grads_v_truth = Flux.gradient(Flux.params(v_c)) do
    #     sum(v_c(x))
    #   end
    #   grads_vector = Flux.gradient(Flux.params(v_c)) do
    #     sum(Fluxperimental.apply(v_c, x)[end])
    #   end
    #   _grads_equal(grads_vector, grads_v_truth)
    # end skip=true
  end

end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,9 @@ using Flux, Fluxperimental

@testset "Fluxperimental.jl" begin
  # Each included file contributes its own nested testsets.
  for test_file in ("split_join.jl", "chain.jl", "compact.jl")
    include(test_file)
  end
end