Thoughts on zipped tree traversal #2
Maybe something like:

```julia
# Proposed addition inside the StructWalk module (it relies on StructWalk's
# `WalkStyle` and `walkstyle`).
abstract type AlignedStyle{W<:WalkStyle} end
const Align = Union{AlignedStyle, Type{AlignedStyle}}

WalkStyle(::AlignedStyle) = WalkStyle(AlignedStyle)
WalkStyle(::Type{AlignedStyle}) = WalkStyle

# For each input, get its children under the underlying WalkStyle and count them,
# then zip the children of all inputs, truncating to the input with the fewest.
function alignedstyle(style::Align, xs...)
    fns = map(xs) do x
        S = walkstyle(WalkStyle(style), x)
        _, fields = S
        isnontuple = length(S) <= 2 ? false : S[3]
        nchild = isnontuple ? sum(length, fields) : length(fields)
        (isnontuple ? Iterators.flatten(fields) : fields), nchild
    end
    fs = map(x -> x[1], fns)
    n = minimum(x -> x[2], fns)
    return identity, zip(fs...), n
end

zippedwalk(f, style::Align, inner_walk, xs...) = zippedwalk(f, f, style, inner_walk, xs...)
function zippedwalk(f, g, style::Align, inner_walk, xs...)
    T, a, n = alignedstyle(style, xs...)
    isleaf = iszero(n)   # the group is a leaf as soon as one input has no children
    if isleaf
        return f(xs)
    else
        return g(T(map(inner_walk, a)))
    end
end

zippedpostwalk(f, xs...) = zippedpostwalk(f, AlignedStyle, xs...)
zippedpostwalk(f, style::Align, xs...) = zippedwalk(f, style, x -> zippedpostwalk(f, style, x...), xs...)
```
```julia
julia> a
(x = 0, y = (w = 0, b = 0))

julia> b
(0.5, (0.3, 0.5))

julia> c
(1, 1, 1)

julia> StructWalk.zippedpostwalk(identity, a, b)
2-element Vector{Any}:
 (0, 0.5)
 [(0, 0.3), (0, 0.5)]

julia> StructWalk.zippedpostwalk(identity, a, c)
2-element Vector{Tuple{Any, Int64}}:
 (0, 1)
 ((w = 0, b = 0), 1)
```
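For readers without StructWalk at hand, here is a dependency-free sketch of the same alignment rule: children are zipped up to the shortest input, a group counts as a leaf as soon as any input has no children, and `f` is applied to leaves and to each rebuilt node. It is hypothetical: `aligned_postwalk` and `children` are illustrative names, and only `Tuple`/`NamedTuple` are treated as containers, unlike StructWalk's `WalkStyle`-driven traversal.

```julia
# Hypothetical, dependency-free sketch of the proposed alignment rule.
# Only Tuple and NamedTuple are containers here; everything else is a leaf.
children(x::Tuple) = x
children(x::NamedTuple) = Tuple(x)   # a NamedTuple's children are its values
children(x) = ()                     # leaves have no children

function aligned_postwalk(f, xs...)
    kids = map(children, xs)
    n = minimum(length, kids)        # align to the input with the fewest children
    iszero(n) && return f(xs)        # any leaf in the group => the group is a leaf
    # recurse on the i-th child of every input, then apply f to the rebuilt node
    return f([aligned_postwalk(f, (k[i] for k in kids)...) for i in 1:n])
end

a = (x = 0, y = (w = 0, b = 0))
b = (0.5, (0.3, 0.5))
c = (1, 1, 1)

aligned_postwalk(identity, a, b)   # == [(0, 0.5), [(0, 0.3), (0, 0.5)]]
aligned_postwalk(identity, a, c)   # == [(0, 1), ((w = 0, b = 0), 1)]
```

With `identity` this reproduces the two REPL results above, including the truncation of `c` to two children.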
I haven't thought much about the design yet, but basically the extensible mechanism here comes from the
Can you elaborate more on this? Assuming that we can load the "aligned"-able part, I wonder what we want to do with the others. The
My mistake, I'd misremembered. A better example would be how JAX frameworks deserialize weights, or how Optax (https://github.com/deepmind/optax) runs over multiple trees, much like Optimisers.jl. In both cases, I believe they take the easy path and implicitly enforce structural similarity through flattening.
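A minimal sketch of the "flatten, and implicitly require the same structure" approach, written in Julia rather than JAX: each tree is reduced to a structure descriptor plus a flat list of leaves, and leaves are only paired up when all descriptors match. The names `structure`, `leaves!`, `leaves` and `flat_zip` are hypothetical, this is not JAX's or Optax's API, and only `Tuple`/`NamedTuple` count as containers.

```julia
# Hypothetical sketch: describe each tree's shape, flatten its leaves, and
# pair leaves up only when all shapes are identical.
# Only Tuple and NamedTuple are containers; anything else is a leaf.
structure(x::NamedTuple) = (keys(x), map(structure, Tuple(x)))
structure(x::Tuple) = (:tuple, map(structure, x))
structure(x) = :leaf

# Collect leaves depth-first; iterating a NamedTuple yields its values.
function leaves!(out, x)
    if x isa Union{Tuple,NamedTuple}
        foreach(c -> leaves!(out, c), x)
    else
        push!(out, x)
    end
    return out
end
leaves(x) = leaves!(Any[], x)

function flat_zip(xs...)
    s = structure(first(xs))
    all(x -> structure(x) == s, xs) ||
        throw(ArgumentError("trees do not have the same structure"))
    return collect(zip(map(leaves, xs)...))
end

params = (x = 0.0, y = (w = 1.0, b = 2.0))
grads  = (x = 0.1, y = (w = 0.2, b = 0.3))
flat_zip(params, grads)   # == [(0.0, 0.1), (1.0, 0.2), (2.0, 0.3)]
```

The trade-off is visible in the last check: any structural difference, even a `NamedTuple` versus a plain `Tuple` with the same leaves, is rejected outright rather than partially aligned.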
Should be fine, but I have some small concerns:
A limitation of
BTW, will you be making the PR for this?
I'm not familiar with JAX. Are you talking about pytree and
and they seem to generally work on a flattened array with a tree-structure encoding
@ToucheSir any updates?
Sorry, nothing from my end yet. I'll need to take a solid chunk of time to sit down and figure out the details of the default align and the exact-same-structure-enforcing align for a PR. Until then, I think we're basically in agreement on your points above.
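One way the two variants mentioned here (a default align and one that enforces exactly the same structure) might be expressed as dispatchable policies, in the same spirit as `WalkStyle`. All names (`AlignPolicy`, `TruncateAlign`, `StrictAlign`, `nchildren`, `kids`, `policy_postwalk`) are hypothetical, and only tuples and named tuples are treated as containers.

```julia
# Hypothetical alignment policies, dispatched on like a WalkStyle: each one
# decides how many children of a zipped node get visited.
abstract type AlignPolicy end
struct TruncateAlign <: AlignPolicy end   # default: zip up to the shortest input
struct StrictAlign <: AlignPolicy end     # require identical child counts

nchildren(::TruncateAlign, lens) = minimum(lens)
function nchildren(::StrictAlign, lens)
    all(==(first(lens)), lens) ||
        throw(ArgumentError("child counts differ: $(lens)"))
    return first(lens)
end

# Only Tuple and NamedTuple are containers in this sketch.
kids(x::Union{Tuple,NamedTuple}) = Tuple(x)
kids(x) = ()

function policy_postwalk(f, policy::AlignPolicy, xs...)
    ks = map(kids, xs)
    n = nchildren(policy, map(length, ks))
    iszero(n) && return f(xs)   # the policy reports no children: treat as a leaf
    return f([policy_postwalk(f, policy, (k[i] for k in ks)...) for i in 1:n])
end

policy_postwalk(identity, TruncateAlign(), (1, 2, 3), (4, 5))   # == [(1, 4), (2, 5)]
# policy_postwalk(identity, StrictAlign(), (1, 2, 3), (4, 5)) throws an ArgumentError
```

New behaviours would then be added by defining another `AlignPolicy` subtype and its `nchildren` method, without touching the walk itself.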
@ToucheSir I slightly changed the code and will merge it in #4.
So this is implemented by the aligned walk and can be closed?
The aligned walk is one possible implementation, but we are not fully convinced it's the best way to do it, so I'll leave this issue open.
I wanted to pick up from our discussion in FluxML/Functors.jl#32, in particular around your comments in FluxML/Functors.jl#32 (comment). I think one concrete example to work off of is a function like `torch.nn.Module.load_state_dict`[1]. In other words, something that prefers, but does not require, similarly structured trees (with `strict=False`, it returns the missing and unexpected keys instead of raising an error). I feel like this is absolutely something StructWalk could excel at, but I agree with your point that creating a one-size-fits-all mechanism would be difficult. My question is whether there might be some extensible mechanism to handle this, just like how `WalkStyle` allows for highly customizable traversal behaviour over individual nodes.

Footnotes
[1] This came up in the discussion around restoring saved model weights; see @darsnack's post at https://github.com/FluxML/Flux.jl/issues/1027#issuecomment-1034226504.
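A rough sketch of the `load_state_dict`-like behaviour described in the opening post: prefer aligned trees, but collect discrepancies instead of failing. It assumes parameters are stored as nested `NamedTuple`s; `load_aligned` and its report strings are hypothetical and not part of StructWalk or Functors.

```julia
# Hypothetical `load_state_dict`-style loader over nested NamedTuples:
# take values from `src` where the paths line up, keep `dst`'s value
# otherwise, and record each discrepancy instead of throwing.
function load_aligned(dst::NamedTuple, src::NamedTuple, path = (), report = String[])
    vals = map(keys(dst)) do k
        p = join((path..., k), ".")
        if !haskey(src, k)
            push!(report, "missing in source: $p")
            dst[k]
        elseif dst[k] isa NamedTuple && src[k] isa NamedTuple
            first(load_aligned(dst[k], src[k], (path..., k), report))
        elseif dst[k] isa NamedTuple || src[k] isa NamedTuple
            push!(report, "structure mismatch at: $p")
            dst[k]
        else
            src[k]   # aligned leaf: load it
        end
    end
    for k in keys(src)
        q = join((path..., k), ".")
        haskey(dst, k) || push!(report, "unexpected in source: $q")
    end
    return NamedTuple{keys(dst)}(vals), report
end

dst = (w = 0.0, layer = (a = 0.0, b = 0.0))
src = (w = 1.0, layer = (a = 2.0,), extra = 3.0)
model, report = load_aligned(dst, src)
# model  == (w = 1.0, layer = (a = 2.0, b = 0.0))
# report == ["missing in source: layer.b", "unexpected in source: extra"]
```

Keeping the destination's value on a mismatch mirrors the "prefers but does not require" behaviour; a strict variant could simply throw whenever `report` is non-empty.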