Normalization Layers not interacting well with destructure/restructure #1727
Copying from Slack. The issue is that the gradient comes back as

```julia
dm = (layers = ((weight = Float32[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], σ = nothing),
       Base.RefValue{Any}((λ = nothing, β = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], γ = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = -0.0f0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing))),)
```

Looking at the blame for the relevant code, it's possible that this has always been an issue and it only came up now since we added the parameter length checks. It's also possible that Zygote changes introduced the `Ref`s.
```julia
julia> re(2 .* ones(length(p)))[2].μ
10-element Vector{Float64}:
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
```

so the trouble is in the adjoint; it can be made to traverse down the refs.
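For context, a minimal sketch of the setup the snippet above assumes (the model shape is inferred from the gradient dump in the first comment; `m`, `p` and `re` are not defined anywhere in this copy of the thread):

```julia
using Flux

# Inferred setup: a Dense(2, 10) layer followed by BatchNorm(10), flattened
# into a parameter vector `p` and a rebuild closure `re` by destructure.
m = Chain(Dense(2, 10), BatchNorm(10))
p, re = Flux.destructure(m)
```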
The question I have is: have the refs always been there, or is that due to #1397? If they've always been there, then it seems the adjoint was silently erroring until we added the warning.
Does this mean making Functors understand `Ref`s?
Refs probably should act like leaves. Adding

```julia
function destructure(m)
  xs = Zygote.Buffer([])
  fmap(m) do x
    addparams!(xs, x)
    return x
  end
  return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
end

addparams!(xs, x::AbstractArray) = push!(xs, x)
addparams!(xs, x::Base.RefValue) = fmap(x -> addparams!(xs, x), x[])
addparams!(xs, x) = nothing
```

works:

```julia
julia> back(a)
┌ Warning: Expected 70 params, got 50
└ @ Flux ~/Downloads/forks/Flux.jl/src/utils.jl:629
(Float32[0.0; 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
```

The warning is because the gradients w.r.t. the moving stats are `nothing`.
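For a rough accounting of the 70 vs. 50 (assuming the Dense(2, 10) + BatchNorm(10) shape inferred above): destructure collects weight (20) + bias (10) + β (10) + γ (10) + μ (10) + σ² (10) = 70 numbers, while the gradient only carries the 50 trainable ones, since μ and σ² come back as `nothing`.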
I think the culprit isn't anything in Flux itself, but Zygote:

```julia
using Zygote

struct IFoo; inner; end
mutable struct MFoo; inner; end

(f::IFoo)(x) = x + f.inner
(f::MFoo)(x) = x + f.inner
```

```julia
julia> gradient(f -> f(1), IFoo(1))
((inner = 1,),)

julia> gradient(f -> f(1), MFoo(1))
(Base.RefValue{Any}((inner = 1,)),)
```
This is expected. Zygote needs to do that since mutable structs need a representation that we can use in the backwards pass reliably. Similar reasoning as to why mutation would copy in reverse mode.
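For illustration, a small hedged helper (not part of Flux or Zygote) that strips the `RefValue` wrappers from such a gradient so that it mirrors the model's structure again:

```julia
# Hypothetical helper: recursively unwrap the RefValue boxes Zygote inserts for
# mutable structs, leaving a plain nested NamedTuple/Tuple gradient.
unwrap_refs(x::Base.RefValue) = unwrap_refs(x[])
unwrap_refs(x::NamedTuple)    = map(unwrap_refs, x)
unwrap_refs(x::Tuple)         = map(unwrap_refs, x)
unwrap_refs(x)                = x
```

With the example above, `unwrap_refs(gradient(f -> f(1), MFoo(1))[1])` gives `(inner = 1,)`, matching the immutable case.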
I figured; I'm more surprised it hasn't come up already (or if it has, more often).
Using Ref like a singleton seems desirable to mark leaf nodes. We can alternatively use mutation, something like

```julia
l.μ .= reshape(μnew, :)
l.σ² .= reshape(σ²new, :)
```

in #1509.
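A minimal sketch of what the "Ref as leaf" idea could look like (hypothetical, not an API either package currently exposes; it assumes `Functors.isleaf` and the `exclude` keyword of `fmap` as the hooks):

```julia
using Functors

# Hypothetical: stop recursion at RefValues as well as the usual leaves.
refleaf(x) = x isa Base.RefValue || Functors.isleaf(x)

# Collect the array leaves of a structured value, treating Refs as opaque leaves.
function collect_arrays(x)
    out = Any[]
    fmap(x; exclude = refleaf) do leaf
        leaf isa AbstractArray && push!(out, leaf)
        leaf
    end
    return out
end
```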
Why would they not be supported? They work just fine. As an example, we already train and update models with them.
Normal training is fine, but anything that intends to run functor code on structured gradients will not be. Crucially, this includes Optimisers.jl.
I think I'd run into it while doing data-parallel training. I have code to check for Refs updating specifically; it builds on FluxML/Optimisers.jl#24.
In this case, yes, they should be zeros.
I was wondering what motivated those commits. The offer of a quick re-review when you think #24 is ready still stands, BTW.
MWE:
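The original MWE is not preserved in this copy of the thread; a plausible reconstruction, assuming the Dense(2, 10) + BatchNorm(10) model inferred from the gradient dump above:

```julia
using Flux

# Hypothetical reconstruction of the MWE. Differentiating through the
# restructure closure of a model containing a norm layer is what produces the
# Ref-shaped gradient and the "Expected 70 params, got 50" warning.
m = Chain(Dense(2, 10), BatchNorm(10))
p, re = Flux.destructure(m)
x = ones(Float32, 2, 3)

g = gradient(p -> sum(re(p)(x)), p)[1]
```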