[bug] learning rate schedulers behave unexpectedly with pytorch 2.0.0 #3202
Closed
Issue Description
The learning rate scheduler does not behave as expected, and importantly, the learning rate schedule differs depending on whether you use pytorch 1.13.0 or 2.0.0.
Environment
- macOS Monterey
- Python 3.8
- pyro dev version: pyro-ppl 1.8.4+dd4e0f81
- pytorch versions as specified below
Code Snippet
```python
import pyro
import torch

# dummy dataloader
dataset = torch.randn((100))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=20)
epochs = 10


def model(x):
    loc = pyro.param('loc', torch.tensor(1.))
    with pyro.plate('plate', x.shape[0]):
        pyro.sample('obs', pyro.distributions.Normal(loc, 1.0), obs=x)


def guide(x):
    pass


optimizer = torch.optim.Adam
scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.1}, 'gamma': 0.9})
svi = pyro.infer.SVI(model, guide, scheduler, loss=pyro.infer.Trace_ELBO())

for i in range(epochs):
    for minibatch in dataloader:
        svi.step(minibatch)
    # report the learning rate of the (single) wrapped torch scheduler
    lr = list(scheduler.optim_objs.values())[0].get_last_lr()[0]
    print(f'[{i + 1:03d}] lr = {lr:.3e}')
    # advance the scheduler once per epoch
    svi.optim.step()
```
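As a side note before the outputs: the snippet above reads the learning rate via `get_last_lr()`. The value the optimizer actually applies could also be read off the wrapped torch optimizer's param groups; a minimal sketch, meant as an addition to the snippet above (not part of the runs reported below):

```python
# read the lr the underlying torch optimizer will actually use,
# alongside the scheduler's get_last_lr() reported above
torch_scheduler = list(scheduler.optim_objs.values())[0]
applied_lr = torch_scheduler.optimizer.param_groups[0]['lr']
print(f'applied lr = {applied_lr:.3e}')
```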
For torch 2.0.0, I get

```
[001] lr = 5.905e-02
[002] lr = 3.138e-02
[003] lr = 1.668e-02
[004] lr = 8.863e-03
[005] lr = 4.710e-03
[006] lr = 2.503e-03
[007] lr = 1.330e-03
[008] lr = 7.070e-04
[009] lr = 3.757e-04
[010] lr = 1.997e-04
```
accompanied by the warning

```
UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
```
while for torch 1.13.0, I get

```
[001] lr = 1.000e-01
[002] lr = 9.000e-02
[003] lr = 8.100e-02
[004] lr = 7.290e-02
[005] lr = 6.561e-02
[006] lr = 5.905e-02
[007] lr = 5.314e-02
[008] lr = 4.783e-02
[009] lr = 4.305e-02
[010] lr = 3.874e-02
```
The latter is a lot closer to what I'd expect, given that 0.1 * 0.9**10 ≈ 0.0349. I also don't get the warning with pytorch 1.13.0.
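As a quick sanity check of the arithmetic (plain Python, not part of the runs above): the pytorch 1.13.0 values match a scheduler that advances once per epoch, while the pytorch 2.0.0 values are consistent with the scheduler also advancing on every `svi.step()` call, i.e. 6 times per epoch here (5 minibatches plus the explicit `svi.optim.step()`):

```python
# if the lr decays once per epoch, the value printed at the end of epoch n
# is 0.1 * 0.9 ** (n - 1)  -> reproduces the pytorch 1.13.0 output above
for n in range(1, 11):
    print(f'[{n:03d}] lr = {0.1 * 0.9 ** (n - 1):.3e}')

# if the lr instead decays 6 times per epoch (5 svi.step calls plus the
# explicit svi.optim.step()), the value printed at the end of epoch n is
# 0.1 * 0.9 ** (6 * n - 1)  -> reproduces the pytorch 2.0.0 output above
for n in range(1, 11):
    print(f'[{n:03d}] lr = {0.1 * 0.9 ** (6 * n - 1):.3e}')
```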
If I do what I think should be the same thing in pytorch 1.13.0 without any pyro code, I see the following learning rate schedule that agrees with the pyro + pytorch 1.13.0 version above:
```python
import torch

# dummy dataloader
dataset = torch.randn((100))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=20)
epochs = 10

x = torch.nn.Parameter(torch.tensor(1.))
optimizer = torch.optim.Adam([x], lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
loss_fn = torch.nn.MSELoss()

for i in range(epochs):
    # step the optimizer once per minibatch ...
    for minibatch in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(minibatch, x)
        loss.backward()
        optimizer.step()
    lr = scheduler.get_last_lr()[0]
    print(f'[{i + 1:03d}] lr = {lr:.3e}')
    # ... and the scheduler once per epoch
    scheduler.step()
```
gives me, using pytorch 1.13.0,
```
[001] lr = 1.000e-01
[002] lr = 9.000e-02
[003] lr = 8.100e-02
[004] lr = 7.290e-02
[005] lr = 6.561e-02
[006] lr = 5.905e-02
[007] lr = 5.314e-02
[008] lr = 4.783e-02
[009] lr = 4.305e-02
[010] lr = 3.874e-02
```
and the above with pytorch 2.0.0 gives me the same thing:
```
[001] lr = 1.000e-01
[002] lr = 9.000e-02
[003] lr = 8.100e-02
[004] lr = 7.290e-02
[005] lr = 6.561e-02
[006] lr = 5.905e-02
[007] lr = 5.314e-02
[008] lr = 4.783e-02
[009] lr = 4.305e-02
[010] lr = 3.874e-02
```
So I think this is somehow related to the pyro + pytorch interface.
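In case it helps narrow things down, one way to poke at the wrapper in isolation might be to call the Pyro-wrapped scheduler directly on a single parameter (roughly the way `SVI.step` uses it) and watch the underlying torch scheduler's `last_epoch` counter. This is a hypothetical probe sketched under that assumption, not something I ran above:

```python
import pyro
import torch

# hypothetical probe: drive the Pyro-wrapped scheduler directly with one
# parameter and check whether each call advances the underlying torch scheduler
p = torch.nn.Parameter(torch.tensor(1.))
p.grad = torch.tensor(0.5)
scheduler = pyro.optim.ExponentialLR(
    {'optimizer': torch.optim.Adam, 'optim_args': {'lr': 0.1}, 'gamma': 0.9})

for call in range(3):
    scheduler([p])  # PyroOptim instances are called with an iterable of params
    torch_scheduler = list(scheduler.optim_objs.values())[0]
    print(f'call {call}: last_epoch = {torch_scheduler.last_epoch}, '
          f'lr = {torch_scheduler.get_last_lr()[0]:.3e}')
```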