Feature Request: make_functional for torch.optim.Optimizer #372
Comments
Thanks for the issue, @teddykoker. This seems like a reasonable API. We've had some decision paralysis in the past on what the best API for this would look like, but having something working is always better than having something that doesn't work :). One question: if I understand the proposal, it does not modify the optimizer state in place; for large optimizer states, could constructing a new state on every step become a memory concern?
Thanks for the reply @zou3519. You are correct that the above proposal does not modify the optimizer state in place. My reasoning behind this was to conform to a more "functional" style and avoid side effects; however, this certainly doesn't have to be the case. Regarding the size of optimizer states: I believe they are usually on the same order of magnitude as the model itself. For example, Adam maintains a running mean of the gradient values and another running mean of the squared gradient values, resulting in a state roughly two times the size of the model. If I understand correctly, PyTorch would need to temporarily have enough memory to store both states, even if the old one is no longer referenced, which could cause memory issues.
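For concreteness, a small snippet (mine, not from the thread) showing Adam's per-parameter state, which is where the roughly 2x memory estimate comes from:

```python
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters())

# One backward pass and step so the optimizer state gets initialized.
model(torch.randn(2, 4)).sum().backward()
opt.step()

for p, state in opt.state.items():
    # Each parameter carries two full-sized buffers ('exp_avg' and
    # 'exp_avg_sq') alongside the scalar 'step' counter.
    print(tuple(p.shape), list(state.keys()))
```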
Hi @teddykoker, we have open-sourced TorchOpt, which can be combined with functorch to perform functional optimization. It may be the feature you want.
Recently, we also incorporated vmap, one of the major features of functorch, into TorchOpt, which enables batchable optimization. We have a pull request here and provide a Colab notebook to play with it.
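For readers landing here, a rough sketch of the optax-style functional loop TorchOpt provides (`torchopt.adam`, `init`/`update`, and `torchopt.apply_updates` follow TorchOpt's documented functional API; exact signatures may differ across versions, and the toy model is illustrative):

```python
import torch
import torchopt
from functorch import grad

# Illustrative setup: a linear model stored as a tuple of tensors.
params = (torch.randn(3, 1), torch.zeros(1))
batch = (torch.randn(8, 3), torch.randn(8, 1))

def compute_loss(params, batch):
    w, b = params
    x, y = batch
    return torch.nn.functional.mse_loss(x @ w + b, y)

optimizer = torchopt.adam(lr=1e-3)  # functional optimizer: an (init, update) pair
opt_state = optimizer.init(params)  # optimizer state is explicit and caller-held

grads = grad(compute_loss)(params, batch)
updates, opt_state = optimizer.update(grads, opt_state)
params = torchopt.apply_updates(params, updates)
```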
Thank you for the great work on this library!
One common pattern I am noticing (a, b) is using the gradients from the `grad()` function to perform the optimization step.
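A minimal sketch of that pattern (illustrative code, not the exact snippet from the post; `compute_loss` and the toy linear model are assumptions):

```python
import torch
from functorch import grad

# Toy setup: a linear model stored as a tuple of parameter tensors.
params = (torch.randn(3, 1), torch.zeros(1))
batch = (torch.randn(8, 3), torch.randn(8, 1))

def compute_loss(params, batch):
    w, b = params
    x, y = batch
    return torch.nn.functional.mse_loss(x @ w + b, y)

# The pattern in question: take grads from functorch's grad() and apply a
# manual SGD update to each parameter.
grads = grad(compute_loss)(params, batch)
lr = 0.1
params = tuple(p - lr * g for p, g in zip(params, grads))
```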
While this is relatively straightforward to do if vanilla mini-batch gradient descent is desired, there seems to be no way to use other optimization methods without:

1. writing the gradients into `.grad` for each parameter and then using the class-based optimizers in `torch.optim`, or
2. using `torch.optim._functional` while manually initializing and passing the necessary elements of the state.

One possible solution to this problem would be to extend `make_functional()` or create a `make_functional_optimizer()` to support `torch.optim.Optimizer`. A potential API could look something like the sketch below.
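The concrete example from the original post isn't reproduced here, so the following is a hypothetical sketch of what the proposed API might look like, with a minimal (and deliberately naive) implementation to make the shape concrete; the name `make_functional_optimizer` comes from the proposal, everything else is my assumption:

```python
import copy
import torch

def make_functional_optimizer(opt_cls, params, **opt_kwargs):
    """Hypothetical sketch: wrap a torch.optim optimizer class as a pure step
    function plus an explicit state object, mirroring make_functional's style."""

    def step(params, grads, state):
        # Work on fresh leaf copies so the caller's params are never mutated.
        new_params = [p.detach().clone() for p in params]
        opt = opt_cls(new_params, **opt_kwargs)
        opt.load_state_dict(copy.deepcopy(state))  # restore e.g. Adam's running averages
        for p, g in zip(new_params, grads):
            p.grad = g
        opt.step()
        return new_params, opt.state_dict()

    # The initial (empty) state comes from a throwaway optimizer instance.
    init_state = opt_cls([p.detach().clone() for p in params], **opt_kwargs).state_dict()
    return step, init_state

# Usage under this assumed API:
# step_fn, opt_state = make_functional_optimizer(torch.optim.Adam, params, lr=1e-3)
# grads = grad(compute_loss)(params, batch)
# params, opt_state = step_fn(params, grads, opt_state)
```

Note the non-in-place style: each call returns a new state rather than mutating the old one, which is the design trade-off discussed in the comments above.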
I believe the above would be possible using a method similar to the one already used in `make_functional` with `nn.Module()`. Obviously there are a number of ways the API could work (e.g. JAX and optax both have slightly differently structured functional optimizer APIs), but I thought it would be good to gauge interest and/or see if such a thing would be worth implementing!