Normalization Layers not interacting well with destructure/restructure #1727

Closed
avik-pal opened this issue Sep 30, 2021 · 14 comments · Fixed by #1901

Comments

@avik-pal
Member

MWE:

using Flux

model = Chain(Dense(2, 10, relu), BatchNorm(10))
p, re = Flux.destructure(model)

x = rand(Float32, 2, 1)

a, back = Flux.pullback(x, p) do _x, _p
    vec(re(_p)(_x))
end

back(a)
# ┌ Warning: Expected 70 params, got 30
# └ @ Flux ~/.julia/packages/Flux/ZnXxS/src/utils.jl:623
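# (70 = 30 from Dense(2, 10)'s weight and bias plus 40 from BatchNorm(10)'s
#  β, γ, μ and σ²; only the 30 Dense gradients made it into the flat vector)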
@darsnack
Member

Copying from Slack. The issue is that the dm passed to the pullback of _restructure looks like:

dm = (layers = ((weight = Float32[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], σ = nothing), Base.RefValue{Any}((λ = nothing, β = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], γ = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = -0.0f0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing))),)

Since fmap won't enter Base.RefValue, destructure(dm) will not create a flat vector of parameters that is the same as destructure(m) (it leaves out the norm layers).
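
To see this concretely (a minimal check, assuming Functors.jl's children and isleaf helpers):

using Functors

ref = Base.RefValue{Any}((β = zeros(Float32, 2),))

Functors.children(ref)   # () : a bare RefValue exposes no functor children
Functors.isleaf(ref)     # true, so fmap treats it as a leaf and never recurses inside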

Looking at the blame for the relevant code, it's possible that this has always been an issue and only came up now because we added the parameter length checks to _restructure and its pullback. Printing out dm with Flux v0.11.6 will tell us whether the RefValue has always been around or whether it was introduced in #1397.

It's also possible that Zygote changes introduced the RefValue, but I don't know enough to comment.

@DhairyaLGandhi
Member

re(p) works as expected

julia> re(2 .* ones(length(p)))[2].μ
10-element Vector{Float64}:
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0

so the trouble is in the adjoint; it can be made to traverse down the refs.

@darsnack
Member

The question I have is: have the refs always been there, or are they due to #1397? If they've always been there, then it seems the adjoint was silently failing until we added the warning.

it can be made to traverse down the refs

Does this mean making Functors understand Ref?

@DhairyaLGandhi
Member

DhairyaLGandhi commented Sep 30, 2021

Refs should probably act like leaves, similar to how wrapping a value in Ref for broadcasting makes it behave like a singleton.
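
(Compare broadcasting: in.([1, 3], Ref([1, 2])) gives [true, false] because the Ref makes the inner vector behave like a scalar.)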

function destructure(m)
  xs = Zygote.Buffer([])
  fmap(m) do x
    addparams!(xs, x)   # collect every parameter array reachable from the model
    return x
  end
  return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
end

addparams!(xs, x::AbstractArray) = push!(xs, x)
# unwrap a RefValue and keep collecting from whatever it holds
addparams!(xs, x::Base.RefValue) = fmap(x -> addparams!(xs, x), x[])
addparams!(xs, x) = nothing   # ignore non-array leaves (functions, numbers, nothing, ...)

works

julia> back(a)
┌ Warning: Expected 70 params, got 50
└ @ Flux ~/Downloads/forks/Flux.jl/src/utils.jl:629
(Float32[0.0; 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

The warning is because the gradients w.r.t. the moving stats are nothing (μ and σ² account for the missing 20 of the expected 70 params). @avik-pal, would you expect them to be filled in with zeros in that case?

@ToucheSir
Member

I think the culprit isn't anything in Flux itself, but Zygote:

using Zygote

struct IFoo; inner; end
mutable struct MFoo; inner; end

(f::IFoo)(x) = x + f.inner
(f::MFoo)(x) = x + f.inner

julia> gradient(f -> f(1), IFoo(1))   # immutable struct: a plain NamedTuple gradient
((inner = 1,),)

julia> gradient(f -> f(1), MFoo(1))   # mutable struct: the gradient is wrapped in a RefValue
(Base.RefValue{Any}((inner = 1,)),)

@DhairyaLGandhi
Member

This is expected. Zygote needs to do that, since mutable structs need a gradient representation that can be relied on in the backwards pass. It's the same reasoning as why mutation triggers a copy in reverse mode.

@ToucheSir
Member

I figured; I'm more surprised it hasn't come up already (or, if it has, more often). Adding @functor RefValue in Functors seems pretty straightforward.
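
A rough sketch of that idea (hypothetical, not what Functors.jl currently defines; note the reconstructed Ref can lose its {Any} type parameter):

using Functors

Functors.@functor Base.RefValue   # expose the single field `x` as a child

g = Base.RefValue{Any}((β = zeros(Float32, 3), γ = ones(Float32, 3)))
fmap(x -> x isa AbstractArray ? 2 .* x : x, g)   # now recurses into the Ref's contents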

@DhairyaLGandhi
Member

DhairyaLGandhi commented Sep 30, 2021

Using Ref like a singleton to mark leaf nodes seems desirable. Alternatively, we could use something like

l.μ .= reshape(μnew, :)
l.σ² .= reshape(σ²new, :)

in #1509 and make BatchNorm immutable. I guess active would be a holdout, but maybe that can be handled separately via NormConfig.
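
Roughly, that could look something like this sketch (a hypothetical SimpleNorm, not Flux's BatchNorm; the update rule is only illustrative):

struct SimpleNorm{A}   # immutable, but the arrays it holds can still be written to
  β::A
  γ::A
  μ::A
  σ²::A
end

function update_stats!(l::SimpleNorm, μnew, σ²new, momentum)
  # write the new running statistics into the existing arrays,
  # so the struct itself never needs to be mutable
  l.μ  .= (1 - momentum) .* l.μ  .+ momentum .* reshape(μnew, :)
  l.σ² .= (1 - momentum) .* l.σ² .+ momentum .* reshape(σ²new, :)
  return l
end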

@ToucheSir
Member

ToucheSir commented Sep 30, 2021

active would still need to be mutable, so maybe a Ref. Also, we'd need to add a huge blinking warning to the docs telling people that mutable structs aren't supported in many parts of Flux. And mutating the params in place would preclude using something like StaticArrays with norm layers.

@DhairyaLGandhi
Member

DhairyaLGandhi commented Sep 30, 2021

Why would they not be supported? They work just fine. As an example, we train and update models with BatchNorm (a mutable struct) very regularly. This is a bug in destructure.

@ToucheSir
Member

Normal training is fine, but anything that intends to run functor code on structured gradients will not be. Crucially, this includes Optimisers.jl.

@DhairyaLGandhi
Member

DhairyaLGandhi commented Sep 30, 2021

I think I'd run into it while doing data-parallel work. I have code specifically to check that Refs get updated. It builds on FluxML/Optimisers.jl#24

@avik-pal
Member Author

avik-pal commented Oct 1, 2021

The warning is because the gradients w.r.t. the moving stats are nothing. @avik-pal, would you expect them to be filled in with zeros in that case?

In this case, yes, they should be zeros.
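
One way that could look in the flattening step (a sketch with a hypothetical helper, not Flux's actual code):

# substitute a zero array for a `nothing` gradient, so the flat gradient
# always has the same length as the flat parameter vector
_grad_or_zero(x::AbstractArray, dx::AbstractArray) = dx
_grad_or_zero(x::AbstractArray, ::Nothing) = zero(x)

_grad_or_zero(ones(Float32, 3), nothing)   # Float32[0.0, 0.0, 0.0]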

@ToucheSir
Member

I think I'd run into it while doing data parallel. I have code to check for Refs updating specifically. It builds on FluxML/Optimisers.jl#24

I was wondering what motivated those commits. The offer of a quick re-review when you think #24 is ready still stands, BTW.
