Feature Request: make_functional for torch.optim.Optimizer #372

Open
teddykoker opened this issue Jan 1, 2022 · 4 comments

Comments

@teddykoker

Thank you for the great work on this library!

One common pattern I am noticing (a, b) is using the gradients returned by the grad() function to perform the optimization step:

grads = grad(loss_fn)(params, ...)
params = [p - g * learning_rate for p, g in zip(params, grads)]

While this is relatively straightforward to do when vanilla mini-batch gradient descent is desired, there seems to be no way to use other optimization methods without one of the following:

  1. Manually setting the .grad of each parameter and then using a class-based optimizer from torch.optim (see the sketch after this list)
  2. Using the functional interface for the optimizer implemented in torch.optim._functional, while manually initializing and passing the necessary elements of the state
  3. Implementing the optimizer yourself
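
For instance, workaround 1 boils down to something like the sketch below. The params and grads here are illustrative stand-ins for values produced by functorch; this is just what the manual route entails, not a recommended API:

import torch

# Illustrative stand-ins for parameters and functorch-produced gradients.
params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(2)]
grads = [torch.randn_like(p) for p in params]

optimizer = torch.optim.Adam(params, lr=3e-4)

# Workaround 1: manually attach each gradient, then take a stateful step.
for p, g in zip(params, grads):
    p.grad = g
optimizer.step()
optimizer.zero_grad()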

One possible solution to this problem would be to extend make_functional(), or create a make_functional_optimizer(), to support torch.optim.Optimizer. A potential API could look something like:

optimizer = torch.optim.Adam(params, lr=3e-4)

# state contains the optimizer state
# update_fn is a stateless func that will return a new state and params given gradients and a state
opt_state, update_fn = make_functional(optimizer)

grads = grad(loss_fn)(params, ...)

# update params and state using update_fn
params, opt_state = update_fn(grads, opt_state)

I believe the above would be possible using a similar method used already in make_functional with nn.Module(). Obviously there are a number of ways the API could work (e.g JAX and optax both have slightly differently structured functional optimizer APIs), but I thought it would be good to gauge interest and/or see if such a thing would be worth implementing!
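
For concreteness, here is a minimal hand-rolled sketch of what such a stateless update function could look like for SGD with momentum. This is purely hypothetical and not an existing functorch API; note that this version also takes the current params explicitly, unlike the proposal above:

import torch

def init_sgd_state(params):
    # opt_state is just one momentum buffer per parameter
    return [torch.zeros_like(p) for p in params]

def sgd_update_fn(params, grads, opt_state, lr=3e-4, momentum=0.9):
    # returns new params and a new opt_state instead of mutating anything
    new_state = [momentum * buf + g for buf, g in zip(opt_state, grads)]
    new_params = [p - lr * buf for p, buf in zip(params, new_state)]
    return new_params, new_state

# usage with dummy values
params = [torch.randn(3)]
grads = [torch.randn(3)]
opt_state = init_sgd_state(params)
params, opt_state = sgd_update_fn(params, grads, opt_state)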

@zou3519
Contributor

zou3519 commented Jan 4, 2022

Thanks for the issue, @teddykoker. This seems like a reasonable API. We've had some decision paralysis in the past on what the best API for this would look like, but having something working is always better than having something that doesn't work :).

One question is: If I understand the proposal, update_fn(grads, opt_state) does not modify opt_state in-place but rather returns a new opt_state. Do you think it'll be a problem that this doesn't do the in-place modification? I don't have a sense of how large optimizer states can be in general.

@teddykoker
Author

Thanks for the reply, @zou3519! You are correct that the above proposal does not modify the optimizer state in place. My reasoning behind this was to conform to a more "functional" style and avoid side effects; however, this certainly doesn't have to be the case.

Regarding the size of optimizer states, I believe they are usually on the same order of magnitude as the size of the model. For example, Adam maintains a running mean of the gradient values and another running mean of the squared gradient values, resulting in a state roughly two times the size of the model itself. If I understand correctly, PyTorch will need to temporarily have enough memory to store both states, even if the old one is no longer referenced, which could cause memory issues.
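
As a quick illustration of that claim, inspecting torch.optim.Adam's per-parameter state after one step shows the two buffers; the point here is their shapes, which match the parameter itself:

import torch

p = torch.nn.Parameter(torch.randn(1000, 1000))
opt = torch.optim.Adam([p], lr=3e-4)
p.grad = torch.randn_like(p)
opt.step()

state = opt.state[p]
# running mean of gradients and running mean of squared gradients,
# each the same shape as the parameter itself -> roughly 2x the model size
print(state["exp_avg"].shape)     # torch.Size([1000, 1000])
print(state["exp_avg_sq"].shape)  # torch.Size([1000, 1000])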

@waterhorse1

Hi @teddykoker, we have open-sourced TorchOpt, which can be combined with functorch to perform functional optimization. It may provide the feature you want.
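
For readers of this thread, functional optimization with TorchOpt looks roughly like the optax-style sketch below. The exact names and signatures here are assumptions paraphrased from memory of TorchOpt's README, so please check the TorchOpt documentation rather than treating this as its definitive API:

import torch
import torchopt  # assumed import name for TorchOpt

params = [torch.randn(3, requires_grad=True)]
grads = [torch.randn(3)]

optimizer = torchopt.adam(lr=3e-4)            # functional optimizer (assumed API)
opt_state = optimizer.init(params)            # pure state object, no side effects
updates, opt_state = optimizer.update(grads, opt_state)
params = torchopt.apply_updates(params, updates)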

@waterhorse1

Recently, we also incorporated vmap, one of the major features of functorch, into TorchOpt, which enables batchable optimization. We have a pull request here and provide a Colab notebook to play with it.
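
As a rough illustration of what "batchable optimization" means here (illustrative only, not TorchOpt's code): vmap can map a single functional update step over a leading batch dimension of independent parameter/gradient sets:

import torch
from functorch import vmap

def sgd_step(param, grad, lr=0.1):
    # one plain functional SGD update for a single parameter tensor
    return param - lr * grad

batched_params = torch.randn(8, 32)   # 8 independent parameter vectors
batched_grads = torch.randn(8, 32)

# vmap applies sgd_step to all 8 (param, grad) pairs in one batched call
new_params = vmap(sgd_step)(batched_params, batched_grads)
print(new_params.shape)  # torch.Size([8, 32])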
