
Commit

more... the DimensionMismatch bug is not here
mcabbott committed Feb 11, 2022
1 parent 756b450 commit d95a147
Showing 2 changed files with 33 additions and 15 deletions.
17 changes: 11 additions & 6 deletions src/destructure.jl
@@ -2,6 +2,9 @@
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
const NoT = NoTangent()

+base(dx::Tangent{<:Tangent}) = backing(dx).backing  # might be needed for gradient(gradient(destructure))
+base(dx::Tangent{Any, <:NamedTuple{(:backing,)}}) = base(backing(dx).backing)  # Zygote version

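The two `base` methods added above peel off the extra `Tangent` layer that second-order AD can introduce. A rough sketch of the shapes involved, using only ChainRulesCore's `Tangent` and `backing` (the names `inner`/`outer` and the values are made up for illustration):

    using ChainRulesCore: Tangent, backing

    # first-order cotangent for a tuple of two vectors:
    inner = Tangent{Tuple{Vector{Float64}, Vector{Float64}}}([1.0, 2.0], [3.0, 4.0])

    # a second-order pass may hand _grad! a Tangent wrapping that Tangent's backing field:
    outer = Tangent{typeof(inner)}(; backing = backing(inner))

    backing(outer).backing   # ([1.0, 2.0], [3.0, 4.0]) — what base(outer) recovers
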
"""
destructure(model) -> vector, reconstructor
@@ -55,21 +58,24 @@ Base.length(re::Restructure) = re.length

# This flattens a model, and returns a web of offsets for later use:
function _flatten(x)
-  isnumeric(x) && return vcat(vec(x)), 0, length(x)  # trivial case
+  isnumeric(x) && return vcat(_vec(x)), 0, length(x)  # trivial case
  arrays = AbstractVector[]
  len = Ref(0)
  off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
-    push!(arrays, vec(y))
+    push!(arrays, _vec(y))
    o = len[]
    len[] = o + length(y)
    o
  end
  reduce(vcat, arrays), off, len[]
end

+_vec(x::Number) = LinRange(x,x,1)
+_vec(x::AbstractArray) = vec(x)

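`_vec` widens `vec` to scalar leaves: `Base.vec` has no method for plain numbers, and `LinRange(x, x, 1)` is a cheap one-element `AbstractVector`. Roughly how the two methods behave (REPL-style lines, repeating the definitions so the snippet stands alone):

    _vec(x::Number) = LinRange(x, x, 1)
    _vec(x::AbstractArray) = vec(x)

    _vec(1.5)                          # 1-element LinRange{Float64}: 1.5
    _vec([1 2; 3 4])                   # [1, 3, 2, 4] — column-major flattening
    vcat(_vec(1.5), _vec([2.0, 3.0]))  # [1.5, 2.0, 3.0], one flat vector
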
function ChainRulesCore.rrule(::typeof(_flatten), x)
  flat, off, len = _flatten(x)
-  _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, dflat, len; walk = _Tangent_biwalk, prune = NoT))
+  _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT))
  (flat, off, len), _flatten_back
end

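The added `unthunk(dflat)` matters because ChainRules is allowed to deliver cotangents lazily, and `_rebuild` needs a real array to index into. A minimal illustration (the thunk here is constructed by hand):

    using ChainRulesCore: @thunk, unthunk

    dflat = @thunk ones(3) .* 2   # lazy cotangent, evaluated only on demand
    unthunk(dflat)                # [2.0, 2.0, 2.0] — a plain vector, safe to index
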
@@ -92,7 +98,7 @@ function _trainable_biwalk(f, x, aux)
end

function _trainmap(f, ch, tr, aux)
-  map(ch, tr, aux) do c, t, a  # isnothing(t) indicates non-trainable field, safe given isnumeric(c)??
+  map(ch, tr, aux) do c, t, a  # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
    isnothing(t) ? c : f(t, a)
  end
end
@@ -121,7 +127,7 @@ ChainRulesCore.@non_differentiable _zero(x)
# This is the gradient of model reconstruction, accumulating duplicates:
function _grad!(x, dx, off, flat::AbstractVector)
  x′, _ = functor(typeof(x), x)
-  dx′, _ = functor(typeof(x), dx)
+  dx′, _ = functor(typeof(x), base(dx))
  off′, _ = functor(typeof(x), off)
  foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
  flat
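
"Accumulating duplicates" refers to tied arrays: `destructure` stores a shared array once, and `_grad!` adds every branch's cotangent into the same slice of `flat`. A sketch of the visible effect (names `v`, `tied` invented; assumes Zygote for the gradient):

    using Optimisers, Zygote

    v = [1.0, 2.0]
    tied = (x = v, y = v)          # x and y share one array
    flat, re = destructure(tied)   # flat == [1.0, 2.0], stored once
    gradient(fl -> (m = re(fl); sum(m.x) + 3 * sum(m.y)), flat)[1]
    # == [4.0, 4.0]: the x and y contributions land in, and add into, the same slots
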
@@ -134,7 +140,6 @@ _grad!(x, dx::Zero, off, flat::AbstractVector) = dx
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity

function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
-  println("grad! fwd ", length(flat))
  _grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT)
  _grad!(x, dx, off, flat), _grad_back
end
31 changes: 22 additions & 9 deletions test/destructure.jl
@@ -2,7 +2,7 @@
m1 = collect(1:3.0)
m2 = (collect(1:3.0), collect(4:6.0))
m3 = (x = m1, y = sin, z = collect(4:6.0))
-m4 = (x = m1, y = m1, z = collect(4:6.0))
+m4 = (x = m1, y = m1, z = collect(4:6.0))  # tied
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
m6 = (a = m1, b = [4.0 + im], c = m1)
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
@@ -72,13 +72,24 @@ end
@test g8[3] == [[10.0]]

  @testset "second derivative" begin
-    @test_broken gradient([1,2,3.0]) do v
+    @test gradient([1,2,3.0]) do v
      sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
    end[1] ≈ [8,16,24]
+    # With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx:
+    # off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ...
+    # until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing
+    # With Zygote, instead:
+    # dx = Tangent{Any}(backing = Tangent{Any}([4.0, 8.0, 12.0], ZeroTangent()),)

+    @test gradient([1,2,3.0]) do v
+      sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1])
+    end[1] == [378, 378, 378]

+    @test_skip gradient([1,2,3.0]) do v
+      sum(gradient(m -> sum(destructure(m)[1]), (v, [4,5,6.0]))[1][1])
+    end
    @test_broken gradient([1,2,3.0]) do v
      sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1])
    end[1] ≈ [8,16,24]
+    # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
+    # Diffractor error in perform_optic_transform
  end
end
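
The constants above can be checked by hand: in the first test the inner gradient's v-block is 2v, so the outer function is 4Σvᵢ² with gradient 8v = [8,16,24]. In the cubic test, s = sum(destructure(m)[1]) = 1+2+3+4+5+6 = 21, every slot of the inner gradient is 3s², summing the v-block gives 9s², and d(9s²)/dvᵢ = 18s = 378 for each i.
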

@@ -109,15 +120,17 @@ end
@test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]

  @testset "second derivative" begin
-    # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
    @test_broken gradient(collect(1:6.0)) do y
      sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
    end[1] ≈ [8,16,24,0,0,0]
-    # This fixes it!
+    # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
+    # with Zygote, which can be fixed by:
    # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,)
-    @test_skip gradient(collect(1:6.0)) do y

+    @test_broken gradient(collect(1:6.0)) do y
      sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1])
-    end[1]
+    end[1] ≈ [0,0,0,32,40,48]
+    # Not fixed by this:
+    # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
  end
end
