
[bug] learning rate schedulers behave unexpectedly with pytorch 2.0.0 #3202

Closed
@sjfleming

Issue Description

The learning rate scheduler does not behave as expected, and, importantly, the learning rate schedule differs depending on whether you use pytorch 1.13.0 or 2.0.0.

Environment

  • macOS Monterey
  • Python 3.8
  • pyro dev version pyro-ppl 1.8.4+dd4e0f81
  • pytorch versions specified below

Code Snippet

import pyro
import torch

# dummy dataloader
dataset = torch.randn((100))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=20)

epochs = 10


def model(x):
    loc = pyro.param('loc', torch.tensor(1.))
    with pyro.plate('plate', x.shape[0]):
        pyro.sample('obs', pyro.distributions.Normal(loc, 1.0), obs=x)


def guide(x):
    pass


optimizer = torch.optim.Adam
scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.1}, 'gamma': 0.9})
svi = pyro.infer.SVI(model, guide, scheduler, loss=pyro.infer.Trace_ELBO())

for i in range(epochs):
    for minibatch in dataloader:
        svi.step(minibatch)
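    # pull the underlying torch scheduler out of the Pyro wrapper and read its current lr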
    lr = list(scheduler.optim_objs.values())[0].get_last_lr()[0]
    print(f'[{i + 1:03d}]  lr = {lr:.3e}')
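    # step the LR scheduler once per epoch via the Pyro optimizer wrapper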
    svi.optim.step()

For torch 2.0.0, I get

[001]  lr = 5.905e-02
[002]  lr = 3.138e-02
[003]  lr = 1.668e-02
[004]  lr = 8.863e-03
[005]  lr = 4.710e-03
[006]  lr = 2.503e-03
[007]  lr = 1.330e-03
[008]  lr = 7.070e-04
[009]  lr = 3.757e-04
[010]  lr = 1.997e-04

accompanied by the warning

UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "

while for torch 1.13.0, I get

[001]  lr = 1.000e-01
[002]  lr = 9.000e-02
[003]  lr = 8.100e-02
[004]  lr = 7.290e-02
[005]  lr = 6.561e-02
[006]  lr = 5.905e-02
[007]  lr = 5.314e-02
[008]  lr = 4.783e-02
[009]  lr = 4.305e-02
[010]  lr = 3.874e-02

The latter is a lot closer to what I'd expect, given that

0.1 * 0.9**10 ≈ 0.0349

I also don't get the warning with pytorch 1.13.0.
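
For reference, the full schedule I'd expect, with one decay per epoch and the lr printed just before each decay (as in the loops above), can be tabulated directly:

for i in range(1, 11):
    # value printed at epoch i: lr = 0.1 after i - 1 decays of gamma = 0.9
    print(f'[{i:03d}]  lr = {0.1 * 0.9 ** (i - 1):.3e}')

which matches the pytorch 1.13.0 output above exactly.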

If I do what I think should be the same thing in plain pytorch, without any pyro code, I see a learning rate schedule that agrees with the pyro + pytorch 1.13.0 results above:

import torch

# dummy dataloader
dataset = torch.randn((100))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=20)

epochs = 10

x = torch.nn.Parameter(torch.tensor(1.))
optimizer = torch.optim.Adam([x], lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
loss_fn = torch.nn.MSELoss()

for i in range(epochs):
    for minibatch in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(minibatch, x)
        loss.backward()
        optimizer.step()
    lr = scheduler.get_last_lr()[0]
    print(f'[{i + 1:03d}]  lr = {lr:.3e}')
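    # decay the lr once per epoch, after the optimizer has already stepped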
    scheduler.step()

gives me, using pytorch 1.13.0,

[001]  lr = 1.000e-01
[002]  lr = 9.000e-02
[003]  lr = 8.100e-02
[004]  lr = 7.290e-02
[005]  lr = 6.561e-02
[006]  lr = 5.905e-02
[007]  lr = 5.314e-02
[008]  lr = 4.783e-02
[009]  lr = 4.305e-02
[010]  lr = 3.874e-02

and the above with pytorch 2.0.0 gives me the same thing:

[001]  lr = 1.000e-01
[002]  lr = 9.000e-02
[003]  lr = 8.100e-02
[004]  lr = 7.290e-02
[005]  lr = 6.561e-02
[006]  lr = 5.905e-02
[007]  lr = 5.314e-02
[008]  lr = 4.783e-02
[009]  lr = 4.305e-02
[010]  lr = 3.874e-02

So I think this is somehow related to the pyro + pytorch interface.
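
As a quick check on that hypothesis, here is a sketch that reuses the model, guide, and dataloader from the first snippet and inspects the underlying torch scheduler after a single epoch (last_epoch tracks how many times the scheduler has been stepped):

import pyro
import torch

pyro.clear_param_store()
scheduler = pyro.optim.ExponentialLR({'optimizer': torch.optim.Adam,
                                      'optim_args': {'lr': 0.1},
                                      'gamma': 0.9})
svi = pyro.infer.SVI(model, guide, scheduler, loss=pyro.infer.Trace_ELBO())

# run one epoch of svi.step() calls, with no explicit svi.optim.step()
for minibatch in dataloader:
    svi.step(minibatch)

# inspect the underlying torch ExponentialLR created by the Pyro wrapper
torch_sched = list(scheduler.optim_objs.values())[0]
print(torch_sched.last_epoch, torch_sched.get_last_lr()[0])

If the scheduler is only stepped by the explicit svi.optim.step() call, this should print 0 and 1.000e-01; if it is being stepped on every svi.step(), it should print 5 and 5.905e-02 (0.1 * 0.9**5), which is exactly the first value in the torch 2.0.0 output above.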
