Thoughts on zipped tree traversal #2

Open
ToucheSir opened this issue Feb 11, 2022 · 8 comments

Comments

@ToucheSir

I wanted to pick up from our discussion in FluxML/Functors.jl#32, in particular around your comments in FluxML/Functors.jl#32 (comment). I think one concrete example to work off of is a function like torch.nn.Module.load_state_dict [1]. In other words, something that prefers but does not require similarly structured trees (it returns a list of discrepancies if they aren't).

I feel like this is absolutely something StructWalk could excel at, but I agree with your point that creating a one-size-fits-all mechanism would be difficult. My question would be whether there might be some extensible mechanism to handle this, just like how WalkStyle allows for highly customizable traversal behaviour over individual nodes.
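To make "prefers but does not require similarly structured trees" concrete, here is a minimal sketch of a walk that reports discrepancies instead of failing. It assumes trees are plain nested NamedTuples; `structure_diff` and the message strings are hypothetical names for illustration, not part of StructWalk or Functors.

```julia
# Hypothetical helper, not part of StructWalk or Functors: walk two nested
# NamedTuples and collect the key paths where their structure disagrees.
function structure_diff(a::NamedTuple, b::NamedTuple, path::String = "")
    diffs = String[]
    for k in keys(a)
        p = isempty(path) ? string(k) : string(path, ".", k)
        if !haskey(b, k)
            push!(diffs, "missing in second tree: " * p)
        elseif a[k] isa NamedTuple && b[k] isa NamedTuple
            # both sides are inner nodes: recurse and keep the nested findings
            append!(diffs, structure_diff(a[k], b[k], p))
        elseif a[k] isa NamedTuple || b[k] isa NamedTuple
            push!(diffs, "node/leaf mismatch: " * p)
        end
    end
    for k in keys(b)
        if !haskey(a, k)
            p = isempty(path) ? string(k) : string(path, ".", k)
            push!(diffs, "unexpected in second tree: " * p)
        end
    end
    return diffs
end

model = (x = 0, y = (w = 0, b = 0))
saved = (x = 0.5, y = (w = 0.3,), z = 1.0)
structure_diff(model, saved)
# 2-element Vector{String}:
#  "missing in second tree: y.b"
#  "unexpected in second tree: z"
```

The aligned part of the two trees could still be processed as usual; the returned list only records what was skipped.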

Footnotes

  1. this came up in the discussion around restoring saved model weights, ref. @darsnack's post at https://github.com/FluxML/Flux.jl/issues/1027#issuecomment-1034226504.

@chengchingwen
Owner

Maybe something like:

abstract type AlignedStyle{W<:WalkStyle} end

const Align = Union{AlignedStyle, Type{AlignedStyle}}

WalkStyle(::AlignedStyle) = WalkStyle(AlignedStyle)
WalkStyle(::Type{AlignedStyle}) = WalkStyle

function alignedstyle(style::Align, xs...)
    fns = map(xs) do x
        S = walkstyle(WalkStyle(style), x)
        _, fields = S
        isnontuple = length(S) <= 2 ? false : S[3]
        nchild = isnontuple ? sum(length, fields) : length(fields)
        (isnontuple ? Iterators.flatten(fields) : fields), nchild
    end
    fs = map(x->x[1], fns)
    n = minimum(x->x[2], fns)
    return identity, zip(fs...), n
end

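# f is applied at leaves and g at inner nodes; g defaults to f.
zippedwalk(f, style::Align, inner_walk, xs...) = zippedwalk(f, f, style, inner_walk, xs...)
function zippedwalk(f, g, style::Align, inner_walk, xs...)
    T, a, n = alignedstyle(style, xs...)
    # n is the smallest child count among xs, so n == 0 means at least one input is a leaf
    isleaf = iszero(n)
    if isleaf
        return f(xs)  # f receives the tuple of aligned inputs
    else
        return g(T(map(inner_walk, a)))  # walk each zipped group of children, then apply g
    end
end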

zippedpostwalk(f, xs...) = zippedpostwalk(f, AlignedStyle, xs...)
zippedpostwalk(f, style::Align, xs...) = zippedwalk(f, style, x -> zippedpostwalk(f, style, x...), xs...)


julia> a
(x = 0, y = (w = 0, b = 0))

julia> b
(0.5, (0.3, 0.5))

julia> c
(1, 1, 1)

julia> StructWalk.zippedpostwalk(identity, a, b)
2-element Vector{Any}:
 (0, 0.5)
 [(0, 0.3), (0, 0.5)]

julia> StructWalk.zippedpostwalk(identity, a, c)
2-element Vector{Tuple{Any, Int64}}:
 (0, 1)
 ((w = 0, b = 0), 1)

I haven't thought too much about the design, but the extensible mechanism here comes from alignedstyle, which could dispatch on all kinds of argument combinations if needed and falls back to zipping the children extracted with a linked WalkStyle. Things like field names can be taken into account with a different WalkStyle.
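For readers without StructWalk at hand, the fallback behaviour (zip the children, stop at the shortest input, call the function on a tuple at the leaves) can be sketched in a few self-contained lines. `kids` and `zipped_postwalk` are hypothetical stand-ins that only understand tuples and named tuples; they are not the package's implementation and ignore WalkStyle entirely.

```julia
# Hypothetical stand-ins, not StructWalk's implementation.
# Tuples and named tuples have children; everything else is treated as a leaf.
kids(x::Tuple) = x
kids(x::NamedTuple) = values(x)
kids(x) = ()

function zipped_postwalk(f, xs...)
    n = minimum(length ∘ kids, xs)   # smallest child count among the inputs
    n == 0 && return f(xs)           # at least one input is a leaf: f gets the tuple of inputs
    # zip stops at the shortest input, so extra children are silently dropped
    return map(cs -> zipped_postwalk(f, cs...), zip(map(kids, xs)...))
end

a = (x = 0, y = (w = 0, b = 0))
b = (0.5, (0.3, 0.5))
c = (1, 1, 1)

zipped_postwalk(identity, a, b)   # [(0, 0.5), [(0, 0.3), (0, 0.5)]]
zipped_postwalk(identity, a, c)   # [(0, 1), ((w = 0, b = 0), 1)]; the third element of c is ignored
```

Dispatching on the combination of argument types, as alignedstyle is meant to allow, would replace the single `kids` fallback used here.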

> In other words, something that prefers but does not require similarly structured trees (it returns a list of discrepancies if they aren't).

Can you elaborate on this? Assuming we can load the alignable part, what do we want to do with the rest? torch.nn.Module.load_state_dict is actually quite narrow: it only takes two arguments, which have a relatively clear hierarchy (state_dict is a flat dictionary with string keys, and models are all subclasses of torch.nn.Module).

@ToucheSir
Author

alignedstyle and co. look very interesting, thanks for writing this out! I played around and managed to get zippedpostwalk(+, a, b) working just by changing to f(xs...) in zippedwalk. Do you foresee any issues with that change?

> Can you elaborate on this? Assuming we can load the alignable part, what do we want to do with the rest? torch.nn.Module.load_state_dict is actually quite narrow: it only takes two arguments, which have a relatively clear hierarchy (state_dict is a flat dictionary with string keys, and models are all subclasses of torch.nn.Module).

My mistake, I'd misremembered. A better example would be how JAX frameworks deserialize weights or how https://github.com/deepmind/optax runs over multiple trees much like Optimisers.jl. In both cases, I believe they take the easy path and implicitly enforce structural similarity through flattening.

@chengchingwen
Owner

> do you foresee any issues with that change?

Should be fine, but I have some small concerns:

  • Returning a tuple is somewhat more zip-like.
  • In zippedwalk, I use f and g as the leaf function and node function respectively, and g defaults to f. I use f(xs) instead of f(xs...) because I want to distinguish calls on leaves from calls on nodes. But maybe I just shouldn't use f as the default for g; identity is probably enough.
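The difference between the two calling conventions is easiest to see on a single pair of aligned leaves. This is a plain Julia illustration; `leaves` and `splat_plus` are hypothetical names, and nothing here calls into StructWalk.

```julia
leaves = (0, 0.5)             # what a leaf call sees when two trees are walked together

# f(xs): f gets one tuple argument, so identity returns the pair, zip-style.
identity(leaves)              # (0, 0.5)

# f(xs...): the tuple is splatted, so a binary function such as + works directly.
+(leaves...)                  # 0.5

# With f(xs), a small wrapper gives the same result without changing the walk.
splat_plus = xs -> +(xs...)
splat_plus(leaves)            # 0.5
```

Keeping f(xs) leaves the choice to the caller, at the cost of writing the wrapper whenever an n-ary function is wanted.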

A limitation of alignedstyle and zippedwalk is that they cannot handle unaligned children other than by ignoring them.

BTW, will you be making the PR for this?

> A better example would be how JAX frameworks deserialize weights

I'm not familiar with JAX. Are you talking about pytree and jax.tree_multimap? It seems to only allow trees with exactly the same structure:

> For tree_multimap, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc.

and they seem to generally work on flattened arrays with a tree structure encoding (treedef). flax.from_state_dict looks more like how torch.nn.Module.load_state_dict works. IIUC, that means torch.nn.Module.load_state_dict probably handles the unaligned part more explicitly.
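For comparison, the "structures must exactly match" behaviour described in the quoted documentation could look roughly like this in Julia. This is a sketch under the assumption that trees are nested tuples and named tuples; `strict_map` and `isnode` are hypothetical names, not StructWalk or JAX API.

```julia
# Hypothetical strict variant: map f over the leaves of two trees and
# throw as soon as their structures differ.
isnode(x) = x isa Union{Tuple, NamedTuple}

function strict_map(f, a::NamedTuple, b::NamedTuple)
    keys(a) == keys(b) || throw(ArgumentError("key mismatch: $(keys(a)) vs $(keys(b))"))
    return map((x, y) -> strict_map(f, x, y), a, b)
end

function strict_map(f, a::Tuple, b::Tuple)
    length(a) == length(b) || throw(ArgumentError("length mismatch: $(length(a)) vs $(length(b))"))
    return map((x, y) -> strict_map(f, x, y), a, b)
end

# Fallback: both arguments must be leaves, otherwise the shapes disagree.
function strict_map(f, a, b)
    (isnode(a) || isnode(b)) && throw(ArgumentError("node/leaf or container type mismatch"))
    return f(a, b)
end

strict_map(+, (x = 0, y = (w = 0, b = 0)), (x = 0.5, y = (w = 0.3, b = 0.5)))
# (x = 0.5, y = (w = 0.3, b = 0.5))
```

Unlike the zip-based fallback, nothing is silently dropped here: `strict_map(+, (1, 2), (1, 2, 3))` throws an ArgumentError.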

@chengchingwen
Owner

@ToucheSir any updates?

@ToucheSir
Author

ToucheSir commented Feb 21, 2022

Sorry, nothing from my end yet. I'll need to take a solid chunk of time to sit down and figure out the details of the default align and of an align that enforces exactly the same structure for a PR. Until then, I think we're basically in agreement on your points above.

@chengchingwen
Owner

@ToucheSir I slightly changed the code and will merge it in #4.

@CarloLucibello
Contributor

So this is implemented by the aligned walk and can be closed?

@chengchingwen
Owner

Aligned walk is one possible implementation, but we are not fully convinced it's the best way to do this, so I'm leaving this issue open.


3 participants