
additional arguments to loss function? #1730

Closed
KenZhenLin opened this issue Oct 1, 2021 · 8 comments

@KenZhenLin

How do I use Flux.Optimise.train! when my loss function has arguments in addition to x and y?

The code below works for loss function with x and y being the only arguments, that is, loss(x, y)

data = DataLoader((x, y), batchsize=2, shuffle=true, partial=true)
opt = ADAM(0.01, (0.9, 0.999))
for epoch = 1:100
    Flux.Optimise.train!(loss, params(W, b), data, opt)
    println(loss(x, y))
end

What if my loss function is in the form loss(x, y, z, W, b, c, d) where x, y, z are the data for each batch, and W, b are the parameters I want to train, and c, d are the other fixed constant parameters for the loss function? How do I write/modify the above training code for this loss function?

Thanks!

@DhairyaLGandhi
Member

There are many ways! First, a little bit of context. The data you pass to train! has the form [itr1, itr2, ...], where each itr is a collection of arguments to the loss function. Effectively, train! will call loss(itr...), so you need itr to be of the form (x, y, a, b, ...). If you're using a DataLoader, you want it to produce such an itr on each iteration (hopefully that makes sense).

Another way, in the case you wish to supplement arguments that don't change, you can write your loss function as (x...) -> loss(x..., arg1, arg2...).

See #1530 for how data loaders can be made to produce itr and DataLoaders.jl also.
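Both options might look something like this (a sketch only; the data shapes and the `loss` definition are made up for illustration):

```julia
using Flux

# Made-up data: 10 samples, observations along the last dimension.
x, y, z = rand(4, 10), rand(2, 10), rand(3, 10)

# Option 1: let the DataLoader yield (x, y, z) batches directly,
# so train! calls loss(xb, yb, zb) for every batch.
data3 = DataLoader((x, y, z), batchsize=2, shuffle=true)

# Option 2: keep the fixed arguments out of the data and close over them.
c, d = 0.5, 1.0
loss(xb, yb, c, d) = sum(abs2, xb) * c + sum(abs2, yb) * d  # made-up loss
wrapped = (args...) -> loss(args..., c, d)
```

With option 2, `wrapped` has the two-argument signature that a plain `(x, y)` DataLoader expects, while `c` and `d` stay fixed across batches.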

@KenZhenLin
Author

Another way, in the case you wish to supplement arguments that don't change, you can write your loss function as (x...) -> loss(x..., arg1, arg2...)

Thanks a lot!

Also it seems that the parameters being updated have to be arrays. I observed that if W is an array, but b in params(W, b) is a real number, then params will ignore b, and b will not be updated during the training. I had to make b a 1-d array to make things work. I noticed this by accident, so I wonder if this is the only restriction I need to pay attention to when doing the training.

@darsnack
Member

darsnack commented Oct 1, 2021

Yes, for now all parameters need to be an <:AbstractArray{<:Number} for params specifically. In general, for the optimisers to work, the parameters must be mutable and have basic arithmetic defined (like +). That basically leaves you with <:AbstractArray{<:Number}, though not strictly so. Scalar values are not mutable, so the optimiser won't be able to "update" the value.

All these restrictions will probably go away in the long term, but for now they are there. The only other restriction I can think of is that your loss function must return a scalar value. If it returns a vector value, then you can still use Flux + its AD system + optimizers to update the model, but you won't be able to use train!.
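A quick way to see the array-vs-scalar restriction (a sketch; the behaviour of the scalar case is as reported in this thread):

```julia
using Flux

W = rand(2, 2)
b_scalar = 0.5    # a plain number: silently dropped by params (per this thread)
b_array  = [0.5]  # a 1-element array: tracked and updated normally

ps_bad = Flux.params(W, b_scalar)  # only W ends up in the collection
ps_ok  = Flux.params(W, b_array)   # both W and b_array are collected
```

Wrapping the scalar in a one-element array, as the original poster did, is the usual workaround.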

I am closing this issue, since there doesn't seem to be anything actionable here for Flux development.

@darsnack darsnack closed this as completed Oct 1, 2021
@KenZhenLin
Author

KenZhenLin commented Oct 2, 2021

Yes, for now all parameters need to an <:AbstractArray{<:Number} for params specifically.

Thanks! I just would like to make sure I'm doing the right thing:

Suppose my loss function is loss(x, W, y, b, c), where x and y are the data, W and b are the parameters (arrays that are usually 2d or 1d) to be updated, and c is some constant coefficient of the function.

I'm now able to do the following:

data = DataLoader((x, y), batchsize=32, shuffle=true, partial=true)
opt = ADAM(0.01, (0.9, 0.999))
for epoch = 1:100
    Flux.Optimise.train!((x, y) -> loss(x, W, y, b, c), params(W, b), data, opt)
end

The above code runs, but I am just wondering whether this is actually the correct syntax, whether there are things silently breaking down (like the issue of the scalar being ignored during training), and how I can detect such silent issues in general when there is no warning or error message.

For example, the documentation on training uses params as params(m) for some model m, not in the way I'm doing here: params(W, b), which makes me less confident about my code.

@darsnack
Member

darsnack commented Oct 2, 2021

params(W, b) is not a common use case but absolutely a valid one. As for the scalar in your loss function, this is perfectly fine. The issue is only if the scalar is a parameter that you wish to update. Since you are doing params(W, b) which are both arrays, you should be fine even if c is a scalar. Zygote (the AD system) can differentiate with respect to scalars just fine.

@ToucheSir
Member

Just to hammer Kyle's point home, this:

Suppose my loss function is loss(x, W, y, b, c), where x and y are the data, W and b are the parameters (arrays that are usually 2d or 1d) to be updated, and c is some constant coefficient of the function
...
if there are things silently breaking down like the issue of scalar being ignored during training

Will never happen because it's exactly what an AD framework like Zygote is made to handle. What would happen is that you don't get a gradient for c when using Flux.params (which calls Zygote.Params) because it's a scalar value. But as you said, c is a constant and shouldn't have a gradient anyhow.

Why not allow gradients for scalar-valued parameters, you might ask? That's a great question, and it comes down to design trade-offs. You might have noticed that most ML libraries expect you to use their own custom array/tensor/variable types. This is no accident: they need those types to be able to keep track of parameters for AD. The huge downside, of course, is that it makes them incompatible with external libraries without tedious manual conversion.

As a source-to-source AD, Zygote avoids this problem entirely. You can pass in native Julia values like Arrays and plain numbers, and Zygote will happily calculate gradients for you. These are what we call "explicit" gradients, and you'll notice both the Flux and Zygote tutorials demonstrate how to work with them.
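A minimal example of those explicit gradients, using only plain numbers (the function here is made up for illustration):

```julia
using Zygote

# Explicit gradients: plain Julia numbers work fine as differentiable arguments.
f(x, a, b) = a * x + b

# Differentiate with respect to a and b, holding x = 5.0 fixed.
∇a, ∇b = Zygote.gradient((a, b) -> f(5.0, a, b), 1.0, 1.0)
# ∇a == 5.0 (∂f/∂a = x) and ∇b == 1.0 (∂f/∂b = 1)
```

No special tensor type is needed; the arguments are ordinary `Float64`s.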

But what if you want to train a model that has millions of parameters contained in hundreds of arrays? Passing them in as individual arguments is untenable, so Params exists as a way of collecting a large number of "implicit" parameters for gradient computation. For how this is done under the hood, see this Discourse thread. That post should also give you an intuition of why implicit params don't play well with scalar values: they don't have a stable memory address. In other words, if you write:

f(x, a, b) = a * x + b
a = 1
b = 1
gs = gradient(() -> f(5, a, b), Params([a, b]))
∇a, ∇b = gs[a], gs[b]

It's impossible for gs to distinguish between a and b because they have the same object identity. Moreover, how do you update a and b without access to original variables? If they were fields of an immutable struct, you'd have to create an entirely new struct with just those 2 fields updated. Flux isn't in the business of mandating all your structs should be mutable, and neither should we be!

Now all of the above rarely comes up in most ML models because all of the params are arrays, so implicit params were considered a reasonable trade-off. However, we are working to make all kinds of parameters work as part of https://github.com/FluxML/Optimisers.jl. Feel free to open a discussion topic there if you have any questions!
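For reference, the explicit style that Optimisers.jl is moving towards looks roughly like this (a sketch only; the package was still under development at the time, and names like `setup`/`update` follow its stated design and may change):

```julia
using Optimisers

# Parameters as a plain NamedTuple of arrays, no Params needed.
model = (W = rand(2, 2), b = [0.3])

# Build optimiser state for every parameter leaf in the tree,
# then apply one explicit (non-mutating) update with a gradient
# that mirrors the model's structure.
state = Optimisers.setup(Optimisers.Descent(0.1), model)
grads = (W = ones(2, 2), b = [1.0])  # stand-in gradient for illustration
state, model = Optimisers.update(state, model, grads)
```

Because the update is driven by the structure of `model` rather than by object identity, the aliasing problem with `Params([a, b])` above does not arise.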

@KenZhenLin
Author

Will never happen because it's exactly what an AD framework like Zygote is made to handle. What would happen is that you don't get a gradient for c when using Flux.params (which calls Zygote.Params) because it's a scalar value.

Thanks for the useful comments and information! I will be more careful. The reason I am worried about silent errors is that, for example, I did not get a warning or error message when I did params(W, b) where b is a scalar. I only realized b had been ignored when I accidentally noticed that the updated b was the same as its starting value. So I found that issue by luck, which is why I'm wondering if there are other issues I have not noticed.

@darsnack
Member

darsnack commented Oct 2, 2021

The reason I am worried about silent errors is for example I did not get a warning or error message when I did params(W, b) where b is a scalar.

This is good feedback, and you are not the first user to have been bitten by this silent error. I have created #1731 to fix it.
