-
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Testing new apply interface for Flux.Chain (#5)
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
- Loading branch information
Showing
4 changed files
with
157 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
|
||
import Flux: ChainRulesCore | ||
# Some experiments with chain to start removing the need for recur to be mutable. | ||
# As per the conversation in the recurrent network rework issue. | ||
|
||
# Main difference between this and the _applychain function is we return a new chain | ||
# with the internal state modified as well as the output of applying x to the chain. | ||
function apply(chain::Flux.Chain, x) | ||
layers, out = _apply(chain.layers, x) | ||
Flux.Chain(layers), out | ||
end | ||
|
||
function _apply(layers::NamedTuple{NMS, TPS}, x) where {NMS, TPS} | ||
layers, out = _apply(Tuple(layers), x) | ||
NamedTuple{NMS}(layers), out | ||
end | ||
|
||
function _scan(layers::AbstractVector, x) | ||
new_layers = typeof(layers)(undef, length(layers)) | ||
for (idx, f) in enumerate(layers) | ||
new_layers[idx], x = _apply(f, x) | ||
end | ||
new_layers, x | ||
end | ||
|
||
# Reverse rule for _scan | ||
# example pulled from https://github.com/mcabbott/Flux.jl/blob/chain_rrule/src/cuda/cuda.jl | ||
function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig, ::typeof(_scan), layers, x) | ||
duo = accumulate(layers; init=((nothing, x), nothing)) do ((pl, input), _), cur_layer | ||
out, back = ChainRulesCore.rrule_via_ad(cfg, _apply, cur_layer, input) | ||
end | ||
outs = map(first, duo) | ||
backs = map(last, duo) | ||
|
||
function _scan_pullback(dy) | ||
multi = accumulate(reverse(backs); init=(nothing, dy)) do (_, delta), back | ||
dapply, dlayer, din = back(delta) | ||
return dapply, (dlayer, din) | ||
end | ||
layergrads = reverse(map(first, multi)) | ||
xgrad = last(multi[end]) | ||
return (ChainRulesCore.NoTangent(), layergrads, xgrad) | ||
end | ||
return (map(first, outs), last(outs[end])), _scan_pullback | ||
end | ||
|
||
function _apply(layers::AbstractVector, x) # type-unstable path, helps compile times | ||
_scan(layers, x) | ||
end | ||
|
||
# Generated function returns a tuple of args and the last output of the network. | ||
@generated function _apply(layers::Tuple{Vararg{<:Any,N}}, x) where {N} | ||
x_symbols = vcat(:x, [gensym() for _ in 1:N]) | ||
l_symbols = [gensym() for _ in 1:N] | ||
calls = [:(($(l_symbols[i]), $(x_symbols[i+1])) = _apply(layers[$i], $(x_symbols[i]))) for i in 1:N] | ||
push!(calls, :(return tuple($(l_symbols...)), $(x_symbols[end]))) | ||
Expr(:block, calls...) | ||
end | ||
|
||
_apply(layer, x) = layer, layer(x) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Checking if the two grad structures are equal. Simplifies tests below. | ||
function _grads_equal(grads1, grads2) | ||
if length(keys(grads1)) != length(keys(grads2)) | ||
return false | ||
end | ||
ret = true | ||
for weights in keys(grads1) | ||
if grads1[weights] isa AbstractArray | ||
ret = ret && all(grads1[weights] .== grads2[weights]) | ||
elseif isnothing(grads1[weights]) | ||
ret = ret && isnothing(grads2[weights]) | ||
else | ||
throw("Grad returned type $(typeof(grads1[weights]))") | ||
end | ||
end | ||
return ret | ||
end | ||
|
||
@testset "Applying the Chain!" begin | ||
@testset "Forward pass" begin | ||
x = rand(Float32, 3, 1) | ||
l1 = Flux.Dense(3, 4) | ||
l2 = Flux.Dense(4, 1) | ||
truth = l2(l1(x)) | ||
|
||
t_c = Flux.Chain(l1, l2) # tuple Chain | ||
new_t_c, out = Fluxperimental.apply(t_c, x) | ||
@test new_t_c[1] === l1 && new_t_c[2] === l2 | ||
@test all(out .== truth) | ||
|
||
|
||
nt_c = Flux.Chain(l1=l1, l2=l2) # namedtuple Chain | ||
new_nt_c, out = Fluxperimental.apply(nt_c, x) | ||
@test new_nt_c[:l1] === l1 && new_nt_c[:l2] === l2 | ||
@test all(out .== truth) | ||
|
||
|
||
v_c = Flux.Chain([l1, l2]) # vector Chain | ||
new_v_c, out = Fluxperimental.apply(v_c, x) | ||
@test new_v_c.layers[1] === l1 && new_v_c.layers[2] === l2 | ||
@test all(out .== truth) | ||
end # @testset "Forward Pass" | ||
|
||
@testset "Backward pass" begin | ||
x = rand(Float32, 3, 1) | ||
l1 = Flux.Dense(3, 4) | ||
l2 = Flux.Dense(4, 1) | ||
|
||
@test begin # Test Tuple Chain Gradients | ||
t_c = Flux.Chain(l1, l2) # tuple Chain | ||
grads_truth = Flux.gradient(Flux.params(t_c)) do | ||
sum(t_c(x)) | ||
end | ||
|
||
grads_tuple = Flux.gradient(Flux.params(t_c)) do | ||
sum(Fluxperimental.apply(t_c, x)[end]) | ||
end | ||
|
||
_grads_equal(grads_tuple, grads_truth) | ||
end | ||
|
||
@test begin # Test Named Tuple's Gradients | ||
nt_c = Flux.Chain(l1=l1, l2=l2) # named tuple Chain | ||
grads_truth = Flux.gradient(Flux.params(nt_c)) do | ||
sum(nt_c(x)) | ||
end | ||
|
||
grads_tuple = Flux.gradient(Flux.params(nt_c)) do | ||
sum(Fluxperimental.apply(nt_c, x)[end]) | ||
end | ||
|
||
_grads_equal(grads_tuple, grads_truth) | ||
end | ||
|
||
@test begin # Test Vector Gradient | ||
c = Flux.Chain([l1, l2]) # named tuple Chain | ||
grads_truth = Flux.gradient(Flux.params(c)) do | ||
sum(c(x)) | ||
end | ||
|
||
grads_tuple = Flux.gradient(Flux.params(c)) do | ||
sum(Fluxperimental.apply(c, x)[end]) | ||
end | ||
|
||
_grads_equal(grads_tuple, grads_truth) | ||
end | ||
end # @testset "Backward Pass" | ||
end # @testset "Applying the Chain!" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters