no gradients if we save the Flux.params into a variable #1346

Open
MariusDrulea opened this issue Dec 27, 2022 · 5 comments
Labels
implicit using Params, Grads

Comments

MariusDrulea commented Dec 27, 2022

See the following MWE:

using Flux

model = Dense(2, 2)

xt = rand(Float32, 2, 4) # batch size of 4
yt = rand(Float32, 2, 4)

ps = Flux.params(model)
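# note: loss_fun below sums over this global ps, while loss_fun_explicit and loss_fun_slow use only their argument m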
loss_fun(m, x, y) = 1/2*sum(p->sum(p.^2), ps)

loss_fun_explicit(m, x, y) = 1/2*sum(m.weight.^2) + 1/2*sum(m.bias.^2)

loss_fun_slow(m, x, y) = 1/2*sum(p->sum(p.^2), Flux.params(m))

∇m = gradient(m->loss_fun(m, xt, yt), model)    
∇m_explicit = gradient(m->loss_fun_explicit(m, xt, yt), model)    
∇m_slow = gradient(m->loss_fun_slow(m, xt, yt), model)    

@show ∇m
@show ∇m_explicit
@show ∇m_slow

The gradient values are below. ∇m_explicit and ∇m_slow are equal and correct, but ∇m is nothing.

∇m = (nothing,)
∇m_explicit = ((weight = Float32[0.69311625 -1.0913904; -0.12783962 -0.15561718], bias = Float32[0.0, 0.0], σ = nothing),)
∇m_slow = ((weight = Float32[0.69311625 -1.0913904; -0.12783962 -0.15561718], bias = Float32[0.0, 0.0], σ = nothing),)
MariusDrulea changed the title from "incorrect gradients" to "incorrect gradients for explicit regularisation" on Dec 27, 2022
MariusDrulea changed the title from "incorrect gradients for explicit regularisation" to "incorrect gradients if we save the Flux.params into a variable (edit: might be the desired behaviour)" on Dec 27, 2022
@MariusDrulea (Author)

Edit after posting the issue: the behavior might be the expected one. I think this must be the case, as we want to use only explicit loss functions.

MariusDrulea changed the title from "incorrect gradients if we save the Flux.params into a variable (edit: might be the desired behaviour)" to "no gradients if we save the Flux.params into a variable (edit: it is the desired behaviour)" on Dec 27, 2022
@ToucheSir (Member)

It looks like the explicit parameters version is actually correct and the other two are wrong, because they give the same answer when you remove the regularization term. I'm trying to figure out why no gradient is being returned for params, because FluxML/Flux.jl#2118 was explicitly written to allow AD when using explicit params.

MariusDrulea changed the title from "no gradients if we save the Flux.params into a variable (edit: it is the desired behaviour)" to "no gradients if we save the Flux.params into a variable" on Dec 27, 2022
@MariusDrulea (Author)

@ToucheSir I just noticed the correct way to call the implicit version is ∇m = gradient(() -> loss_fun(), ps): we have to provide a zero-argument function together with the ps variable. If I do so, I get the same gradient values as for the explicit versions.
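
For reference, a minimal sketch of that implicit call, reusing model, ps, xt and yt from the MWE above; Zygote returns a Grads object indexed by the parameter arrays:

gs = gradient(() -> loss_fun(model, xt, yt), ps)  # zero-argument closure plus the ps collection

gs[model.weight]  # matches the weight gradient from the explicit versions
gs[model.bias]    # matches the bias gradient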

@ToucheSir (Member)

Yes, but as seen in your edited example you can also call params on an explicit model. The trouble comes when you try to iterate over an external (in this case global) variable such as ps or model, because Zygote can't see a path back from those to any of the inputs. The question is whether we can catch such accesses and warn/error as appropriate. My only idea so far is to add a warning which shows up when differentiating params in explicit mode that links to a docs section outlining what works and what doesn't.
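
To connect that back to the MWE, here is a minimal sketch of the contrast (reusing ps, model and the squared-norm penalty from above): the gradient only flows when the parameter collection is built from the function argument rather than pulled in from global scope.

gradient(m -> 1/2*sum(p -> sum(p.^2), ps), model)              # closes over the global ps: (nothing,)
gradient(m -> 1/2*sum(p -> sum(p.^2), Flux.params(m)), model)  # params of the argument: real gradients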

@mcabbott (Member)

The answers presently above look correct to me.

Perhaps a simpler example of what they illustrate is this. None of these seem wrong, but the ones mixing explicit arguments and global references are perhaps surprising.

julia> using Zygote, LinearAlgebra, ForwardDiff

julia> v = [2.0, 3.0];

julia> gradient(x -> dot(x,x), v)
([4.0, 6.0],)

julia> gradient(x -> dot(x,v), v)  # one global reference
([2.0, 3.0],)

julia> ForwardDiff.gradient(x -> dot(x,v), v)  # agrees
2-element Vector{Float64}:
 2.0
 3.0

julia> gradient(x -> dot(v,v), v)  # two global references
(nothing,)

julia> ForwardDiff.gradient(x -> dot(v,v), v)  # agrees
2-element Vector{Float64}:
 0.0
 0.0

julia> gradient(() -> dot(v,v), Params([v]))  # implicit mode
Grads(...)

julia> ans[v]  # same answer as the first, but via global ref.
2-element Vector{Float64}:
 4.0
 6.0

mcabbott added the "implicit using Params, Grads" label on Mar 14, 2024