Add destructure, take II #54

Merged: 13 commits, Feb 14, 2022

Conversation

@mcabbott (Member) commented Feb 6, 2022

This adds a destructure function, like Flux's, which should handle awkward cases. Requires FluxML/Functors.jl#37.

Alternative to #40. This seems like a tidier approach, although not quite as short and elegant as it was before I started adding tests. The key idea is that, on the first walk over the model to flatten it, you can build a tree of vector offsets, which simplifies the reconstruction step and the gradients. The gradient of reconstruction isn't an fmap walk, but because it already knows the offsets, it does not care if the walks' orders don't match.

Should work with numbers too, if isnumeric is widened to allow them. Should also work with mixed element types: promote for the flat vector, project back during reconstruction.
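
For orientation, a minimal sketch of the intended API (the toy model and values are illustrative, not taken from this PR):

using Optimisers  # this PR adds destructure here

m = (x = [1.0, 2.0], y = (sin, [3.0 4.0]))  # any Functors-compatible structure
flat, re = destructure(m)                   # flat should be [1.0, 2.0, 3.0, 4.0]
m2 = re(2 .* flat)                          # same structure back, numeric leaves rescaled; sin untouched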

The reason to put it here rather than in Functors is that this package already depends on ChainRulesCore, and that this builds in trainable deeply enough that having another version without it would be a pain. Also, Functors at present doesn't have isnumeric.

Closes #40, closes FluxML/Functors.jl#31 if it can.

src/destructure.jl (outdated review thread, resolved)
Comment on lines 38 to 66
len = Ref(0)
off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
  push!(arrays, vec(y))  # collect each trainable leaf, flattened
  o = len[]              # this leaf's offset into the final flat vector
  len[] = o + length(y)  # running total of elements seen so far
  o                      # the walk's output is a tree of offsets
end
@ToucheSir (Member) commented Feb 7, 2022

Off-topic, but this (and gamma!) is a concrete example of where something like FluxML/Functors.jl#32 or fmapreduce would help. Instead of writing to a ref, you'd just carry an accumulated total through the traversal.
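
To make that concrete, a purely hypothetical one-liner (fmapreduce is a proposal, not an existing Functors function):

# Hypothetical API: the walk itself carries the running total, no Ref needed.
total = fmapreduce(length, +, x; exclude = isnumeric, init = 0)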

mcabbott (Member, Author):

Earlier commits, FWIW, just append!ed to a vector and used its length instead. The approach here is faster, and delays working out the promoted type, instead of needing what is basically another fmapreduce pass. A fancy way to pass the integer forwards might avoid the Ref, but we don't want to do pairwise vcat array-by-array.
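
Roughly, the earlier append!-based version looked something like this sketch (not the exact committed code):

flat = Float32[]  # element type must be chosen (or computed by another pass) up front
off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
  o = length(flat)       # offset = number of elements stored so far
  append!(flat, vec(y))  # grow the single flat vector leaf by leaf
  o
end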

ToucheSir (Member):

Yes, I was referring to just the total instead of using carried state for the actual array of arrays (or single array). I assume the single array path was nixed because of not knowing the container type without traversing?

mcabbott (Member, Author):

Yes, getting the type needed another walk: 5a7bfc8#diff-2fc6059e247338b5ac149900b866865ae69bdf3693d9ce37cef19f230ddb8e30L104
(And forgetting to make that walk non-differentiable gave some surprising bugs...)

Comment on lines 26 to 50
struct Restructure{T,S}
  model::T
  offsets::S
  length::Int
end
Member:

I'm really happy this is being reified. The current implementation in Flux tries to dance around having to explicitly represent the state, but that's a big part of why it's so inflexible.

mcabbott (Member, Author):

My main reason to replace this:

function destructure(x)
  flat, off, len = alpha(x)
  flat, v -> beta(x, off, v; len)
end

with 7 more lines was that the anonymous function's type contains the huge offset struct's type, so it fills your screen...

Glad you don't object though! This struct is very much internal still, of course.
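
For illustration, a minimal sketch of the struct-based version being discussed, reusing the placeholder names alpha and beta from the closure above (the real field and helper names may differ):

struct Restructure{T,S}
  model::T
  offsets::S
  length::Int
end

function destructure(x)
  flat, off, len = alpha(x)  # same placeholder flattening step as above
  flat, Restructure(x, off, len)
end

# Calling the struct rebuilds the model from a flat vector:
(re::Restructure)(v::AbstractVector) = beta(re.model, re.offsets, v; len = re.length)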

Member:

Nice, that'll be a great QoL change too.

@darsnack (Member) commented Feb 7, 2022

This should also make it easier to port some of the speedy stuff (if we want to) in DiffEqFlux for FastChain etc. We can consider adding

(re::Restructure)(x, flat) = re(flat)(x)

for that.
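
If that method were added, usage might look something like this sketch (the model and input are illustrative):

using Flux, Optimisers

model = Chain(Dense(2, 3, tanh), Dense(3, 1))
flat, re = destructure(model)

x = rand(Float32, 2, 8)
re(x, flat)  # rebuild with parameters flat, then apply to x, i.e. re(flat)(x)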

mcabbott (Member, Author):

No downsides to that.

One question here is whether we want overwriting versions. Should there be a destructure! whose re writes back into the same model? And if we want (c::Chain)(x, p::Vector) to call something like that, but never make the vector, can these share code?

Member:

I've been thinking a bit around that in light of FluxML/Flux.jl#1855 and FluxML/Flux.jl#1809 (comment), but it's not clear to me how it could be done somewhat generically within Flux's existing design. This would be a lot easier if layers didn't own their trainable parameters, Flax-style.

@darsnack (Member) commented Feb 7, 2022

Just a +1 comment from me. Brian already covered anything I would have.

Comment on lines +122 to +128
@testset "second derivative" begin
  @test_broken gradient(collect(1:6.0)) do y
    sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
  end[1] ≈ [8,16,24,0,0,0]
  # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
  # with Zygote, which can be fixed by:
  # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,)
mcabbott (Member, Author):

I think the remaining question on this PR is whether and how much to care about 2nd derivatives. Some work, some don't. I convinced myself there is no bug in the basic logic. But in the details of when to wrap what in a Tangent, or unwrap it for Zygote, there might be bugs, here or upstream.

If we want to be pedantic we could make all 2nd derivatives an error, rather than risk any being wrong. Or a warning.

Member:

At least a warning sounds good to me.

@mcabbott (Member, Author) commented Feb 13, 2022

Done!

All warnings are maxlog=3, so as not to be too annoying if something does actually work.

mcabbott (Member, Author):

Good to go?


# These are only needed for 2nd derivatives:
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
  @warn "second derivatives of Restructure may not work yet, sorry!" maxlog=3
Member:

Given these were completely busted most of the time before, I don't think we need to apologize so profusely 😆

mcabbott (Member, Author):

Heh. May as well be friendly!

Also, I think the point is that they ought to work; the structure does allow for them. It just has bugs right now.
