Restoring SequentialLR has undocumented side-effects on Optimizer #119168

Open
ceisenach opened this issue Feb 5, 2024 · 3 comments
Labels: actionable, bug, module: LrScheduler, module: optimizer, triaged

ceisenach commented Feb 5, 2024

🐛 Describe the bug

When saving and restoring optimizer and LRScheduler states, the order in which the state_dicts are restored determines whether or not the restored optimizer behaves correctly.

Consider the following example:

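import math

import torch

# cosine_decay_linear_warmup is not defined in the original report. Hypothetical
# stand-in so the snippet runs: it returns a multiplier on the base lr (3e-5) that
# warms up linearly to 10x over 20 steps, then cosine-decays back to 1x by step 100,
# mirroring the SequentialLR schedule below. The author's actual function may differ.
def cosine_decay_linear_warmup(step, warmup_steps=20, total_steps=100, peak_factor=10.0):
    if step < warmup_steps:
        return 1.0 + (peak_factor - 1.0) * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return 1.0 + (peak_factor - 1.0) * 0.5 * (1.0 + math.cos(math.pi * progress))

model = torch.nn.Linear(10, 10)
optim = torch.optim.SGD(model.parameters(), lr=3e-5)
lr = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=cosine_decay_linear_warmup)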
lrs = []

for i in range(100):
    optim.step()
    lr.step()
    lrs.append(lr.get_last_lr())

model2 = torch.nn.Linear(10, 10)
optim2 = torch.optim.SGD(model2.parameters(), lr=3e-4)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, T_max=80, eta_min=3e-5)
scheduler1 = torch.optim.lr_scheduler.LinearLR(optim2, start_factor=0.1, end_factor=1, total_iters=20)
lr2 = torch.optim.lr_scheduler.SequentialLR(optim2, schedulers=[scheduler1, scheduler2], milestones=[20])
lrs2 = []
lrs3 = []

for i in range(25):
    optim2.step()
    lr2.step()
    lrs2.append(lr2.get_last_lr())
    lrs3.append(lr2.get_last_lr())

torch.save(lr2.state_dict(), '/home/ubuntu/save_seq2.pt')
torch.save(optim2.state_dict(), '/home/ubuntu/save_optim2.pt')
    
# Correct behavior: build the schedulers first, then restore scheduler and optimizer state
model2 = torch.nn.Linear(10, 10)
optim2 = torch.optim.SGD(model2.parameters(), lr=3e-4)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, T_max=80, eta_min=3e-5)
scheduler1 = torch.optim.lr_scheduler.LinearLR(optim2, start_factor=0.1, end_factor=1, total_iters=20)
lr2 = torch.optim.lr_scheduler.SequentialLR(optim2, schedulers=[scheduler1, scheduler2], milestones=[20])
lr2.load_state_dict(torch.load('/home/ubuntu/save_seq2.pt'))
optim2.load_state_dict(torch.load('/home/ubuntu/save_optim2.pt'))

for i in range(25, 100):
    lr2.step()
    lrs2.append(lr2.get_last_lr())
    
# Incorrect behavior: the optimizer is restored before the schedulers are constructed
model2 = torch.nn.Linear(10, 10)
optim2 = torch.optim.SGD(model2.parameters(), lr=3e-4)
optim2.load_state_dict(torch.load('/home/ubuntu/save_optim2.pt'))
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, T_max=80, eta_min=3e-5)
scheduler1 = torch.optim.lr_scheduler.LinearLR(optim2, start_factor=0.1, end_factor=1, total_iters=20)
lr2 = torch.optim.lr_scheduler.SequentialLR(optim2, schedulers=[scheduler1, scheduler2], milestones=[20])
lr2.load_state_dict(torch.load('/home/ubuntu/save_seq2.pt'))

for i in range(25, 100):
    lr2.step()
    lrs3.append(lr2.get_last_lr())

The first example (with no restore) produces the following learning-rate schedule:
[image: learning-rate curve]

In the second example, where the optimizer is restored last, the behavior is also correct:
[image: learning-rate curve]

In the third example, where the optimizer is restored first, the behavior is incorrect:
[image: learning-rate curve]

This happens because SequentialLR has side effects on the optimizer when it is initialized. Other LRSchedulers do not cause the same side effects (i.e., for them, the order in which objects are restored does not matter).
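A minimal in-memory sketch of this side effect, assuming PyTorch 2.x: it trains for 25 steps with the report's schedule, restores the optimizer into a fresh SGD instance, and only then constructs the schedulers (the "incorrect" order). `make_schedule` is a hypothetical helper name, and `copy.deepcopy` of the state dict stands in for `torch.save`/`torch.load`.

```python
import copy

import torch


def make_schedule(optimizer):
    # Hypothetical helper: the report's warmup + cosine schedule.
    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, end_factor=1, total_iters=20)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=80, eta_min=3e-5)
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, cosine], milestones=[20])


# Train for 25 steps, then snapshot the optimizer state (deep copy stands in for torch.save).
model = torch.nn.Linear(10, 10)
optim = torch.optim.SGD(model.parameters(), lr=3e-4)
sched = make_schedule(optim)
for _ in range(25):
    optim.step()
    sched.step()
optim_state = copy.deepcopy(optim.state_dict())
saved_lr = optim.param_groups[0]["lr"]

# "Incorrect" order: restore the optimizer first, then construct the schedulers.
optim2 = torch.optim.SGD(torch.nn.Linear(10, 10).parameters(), lr=3e-4)
optim2.load_state_dict(optim_state)
lr_after_load = optim2.param_groups[0]["lr"]
make_schedule(optim2)  # constructing the schedulers (incl. SequentialLR) rewrites param_groups[0]["lr"]
lr_after_init = optim2.param_groups[0]["lr"]

print(f"saved: {saved_lr:.6g}, after load: {lr_after_load:.6g}, after scheduler init: {lr_after_init:.6g}")
```

The scheduler's own `load_state_dict` only restores scheduler attributes (such as `last_epoch`) and does not write `param_groups`, so loading it afterwards cannot undo this.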

Versions

PyTorch version: 2.2.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.18 | packaged by conda-forge | (default, Dec 23 2023, 17:21:28) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1026-aws-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 3000.000
BogoMIPS: 6000.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] torch==2.2.0+cu118
[pip3] torchaudio==2.2.0+cu118
[pip3] torchvision==0.17.0+cu118
[pip3] triton==2.2.0
[conda] numpy 1.24.1 pypi_0 pypi
[conda] torch 2.2.0+cu118 pypi_0 pypi
[conda] torchaudio 2.2.0+cu118 pypi_0 pypi
[conda] torchvision 0.17.0+cu118 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi

cc @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar

ceisenach changed the title from "Restoring SequentialLR has undocumented side-effects on Optmizer" to "Restoring SequentialLR has undocumented side-effects on Optimizer" on Feb 5, 2024.
zou3519 added the module: optimizer and triaged labels on Feb 8, 2024.
janeyx99 added the bug label on Feb 8, 2024.
janeyx99 (Contributor) commented Feb 8, 2024

This does not look good; we would accept a fix for this!

lancerts (Contributor) commented Feb 16, 2024

@ceisenach actually, most of the LRSchedulers have the same side effect. It is hidden in super().__init__(optimizer, last_epoch, verbose), which calls self._initial_step(). That method resets the optimizer's step counter (self.optimizer._step_count = 0) and then performs an initial scheduler step, which can change each group's lr. The constructor also calls group.setdefault('initial_lr', group['lr']) to record the base lr. Together, these change the state of optim and cause the observed issue (detailed demo below).

In particular, this applies to every lr_scheduler that subclasses LRScheduler (not just SequentialLR).
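A minimal sketch of the constructor side effect described above, using LinearLR alone (no SequentialLR) on a fresh optimizer with the report's settings. It assumes PyTorch 2.x and reads only the public `param_groups`, not private counters such as `_step_count`.

```python
import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4)
group = optimizer.param_groups[0]  # the scheduler mutates this same dict in place
lr_before = group["lr"]
had_initial_lr = "initial_lr" in group

# LinearLR.__init__ -> LRScheduler.__init__: records 'initial_lr' via setdefault,
# then performs an initial step, which for LinearLR multiplies lr by start_factor.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, end_factor=1, total_iters=20)

lr_after = group["lr"]
print(f"before: {lr_before}, after: {lr_after}, initial_lr: {group['initial_lr']}")
```

If `optimizer.load_state_dict` had been called before this constructor, the restored lr would be the value that gets scaled.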

I think that, instead of changing the code, we should add a note to the docs stating that load_state_dict should be called after the schedulers are initialized. I believe the current behavior is intended.
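A sketch of the ordering such a docs note would describe, written as a hypothetical resume helper. `build`, `make_checkpoint`, and `resume` are illustrative names, not PyTorch APIs, and deep copies stand in for `torch.save`/`torch.load`; PyTorch 2.x is assumed.

```python
import copy

import torch


def build(model):
    # Hypothetical helper: the report's optimizer and warmup + cosine schedule.
    optimizer = torch.optim.SGD(model.parameters(), lr=3e-4)
    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, end_factor=1, total_iters=20)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=80, eta_min=3e-5)
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, cosine], milestones=[20])
    return optimizer, scheduler


def make_checkpoint(optimizer, scheduler):
    # Hypothetical helper; a deep copy stands in for torch.save/torch.load here.
    return copy.deepcopy({"optimizer": optimizer.state_dict(),
                          "scheduler": scheduler.state_dict()})


def resume(model, checkpoint):
    # Hypothetical helper encoding the proposed ordering.
    optimizer, scheduler = build(model)                 # constructors may rewrite param_groups
    scheduler.load_state_dict(checkpoint["scheduler"])  # restores scheduler attributes only
    optimizer.load_state_dict(checkpoint["optimizer"])  # restores lr last, undoing the rewrite
    return optimizer, scheduler


model = torch.nn.Linear(10, 10)
optimizer, scheduler = build(model)
for _ in range(25):
    optimizer.step()
    scheduler.step()
checkpoint = make_checkpoint(optimizer, scheduler)

optimizer2, scheduler2 = resume(torch.nn.Linear(10, 10), checkpoint)
print(optimizer.param_groups[0]["lr"], optimizer2.param_groups[0]["lr"])
```

Keeping construction inside one helper makes it hard to accidentally call `optimizer.load_state_dict` before a scheduler constructor runs.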

@janeyx99 let me know what you think.

lancerts (Contributor) commented Feb 16, 2024

import torch

path = '/home/ubuntu/'  # directory used by torch.save in the original report

model2 = torch.nn.Linear(10, 10)
optim2 = torch.optim.SGD(model2.parameters(), lr=3e-4)
optim2.load_state_dict(torch.load(path + 'save_optim2.pt'))

print("1")
print(optim2.state_dict())
scheduler1 = torch.optim.lr_scheduler.LinearLR(optim2, start_factor=0.1, end_factor=1, total_iters=20)

print("2")
print(optim2.state_dict())
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, T_max=80, eta_min=3e-5)

print("3")
print(optim2.state_dict())
lr2 = torch.optim.lr_scheduler.SequentialLR(optim2, schedulers=[scheduler1, scheduler2], milestones=[20])

print("4")
print(optim2.state_dict())
lr2.load_state_dict(torch.load(path + 'save_seq2.pt'))

optim2.load_state_dict(torch.load(path + 'save_optim2.pt')) # This load_state_dict is needed to recover to the correct lr

print("5")
print(optim2.state_dict())

yields

1
{'state': {}, 'param_groups': [{'lr': 0.00029740601285443605, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'fused': None, 'initial_lr': 0.0003, 'params': [0, 1]}]}
2
{'state': {}, 'param_groups': [{'lr': 2.9740601285443605e-05, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'fused': None, 'initial_lr': 0.0003, 'params': [0, 1]}]}
3
{'state': {}, 'param_groups': [{'lr': 2.9740601285443605e-05, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'fused': None, 'initial_lr': 0.0003, 'params': [0, 1]}]}
4
{'state': {}, 'param_groups': [{'lr': 2.9999999999999997e-05, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'fused': None, 'initial_lr': 0.0003, 'params': [0, 1]}]}
5
{'state': {}, 'param_groups': [{'lr': 0.00029740601285443605, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'fused': None, 'initial_lr': 0.0003, 'params': [0, 1]}]}

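Between 1 and 2, the lr drops from the restored value (about 2.974e-4) to about 2.974e-5: constructing LinearLR performs an initial step that multiplies the current lr by start_factor (0.1). Constructing CosineAnnealingLR (2 to 3) leaves the lr unchanged. Between 3 and 4, SequentialLR resets the lr to initial_lr (3e-4) and re-runs the first scheduler's initial step, which gives 3e-5. This initialization logic (starting at the "# Initialize epoch and base learning rates" block of LRScheduler.__init__) is intended.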
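The printed values can be checked with plain float arithmetic. This sketch does not call PyTorch; it assumes the mechanism described above (LinearLR's constructor step scales lr by start_factor, CosineAnnealingLR's leaves it unchanged, and SequentialLR resets lr to initial_lr before re-running the first scheduler's initial step).

```python
# Values copied from the printed param_groups in the demo above.
restored_lr = 0.00029740601285443605  # "1": lr right after optim2.load_state_dict
initial_lr = 0.0003                   # 'initial_lr' stored in the same param group
start_factor = 0.1                    # LinearLR(start_factor=0.1, ...)

# 1 -> 2: LinearLR's constructor step scales the current lr by start_factor.
lr_at_2 = restored_lr * start_factor

# 2 -> 3: CosineAnnealingLR's constructor step returns the current lr unchanged.
lr_at_3 = lr_at_2

# 3 -> 4: SequentialLR sets lr back to initial_lr, then re-runs LinearLR's first step.
lr_at_4 = initial_lr * start_factor

print(lr_at_2, lr_at_3, lr_at_4)
```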

Therefore, optimizer.load_state_dict needs to be called after every lr_scheduler has been initialized.
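As a sanity check of this conclusion, a minimal sketch (assuming PyTorch 2.x, with deep copies standing in for `torch.save`/`torch.load`) compares an uninterrupted 100-step run against runs that are checkpointed at step 25 and resumed in each order. `make_optimizer`, `make_scheduler`, and `run` are hypothetical helper names.

```python
import copy

import torch


def make_optimizer(model):
    return torch.optim.SGD(model.parameters(), lr=3e-4)


def make_scheduler(optimizer):
    # The report's schedule: linear warmup for 20 steps, then cosine annealing.
    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, end_factor=1, total_iters=20)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=80, eta_min=3e-5)
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, cosine], milestones=[20])


def run(optimizer, scheduler, steps):
    # Step both objects and record the lr reported after each scheduler step.
    lrs = []
    for _ in range(steps):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    return lrs


# Reference: 100 uninterrupted steps.
ref_opt = make_optimizer(torch.nn.Linear(10, 10))
reference = run(ref_opt, make_scheduler(ref_opt), 100)

# Train 25 steps, then snapshot both state dicts.
opt = make_optimizer(torch.nn.Linear(10, 10))
sched = make_scheduler(opt)
prefix = run(opt, sched, 25)
opt_state = copy.deepcopy(opt.state_dict())
sched_state = copy.deepcopy(sched.state_dict())

# Correct order: build schedulers, restore scheduler, restore optimizer last.
opt_ok = make_optimizer(torch.nn.Linear(10, 10))
sched_ok = make_scheduler(opt_ok)
sched_ok.load_state_dict(copy.deepcopy(sched_state))
opt_ok.load_state_dict(copy.deepcopy(opt_state))
resumed_ok = prefix + run(opt_ok, sched_ok, 75)

# Incorrect order: restore optimizer, then build schedulers (which overwrite lr).
opt_bad = make_optimizer(torch.nn.Linear(10, 10))
opt_bad.load_state_dict(copy.deepcopy(opt_state))
sched_bad = make_scheduler(opt_bad)
sched_bad.load_state_dict(copy.deepcopy(sched_state))
resumed_bad = prefix + run(opt_bad, sched_bad, 75)

err_ok = max(abs(a - b) for a, b in zip(reference, resumed_ok))
err_bad = max(abs(a - b) for a, b in zip(reference, resumed_bad))
print(f"max |lr diff| correct order: {err_ok:.3g}, incorrect order: {err_bad:.3g}")
```

In this sketch the correct order reproduces the uninterrupted lr sequence, while the incorrect order diverges from step 26 onward.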

Projects
None yet
Development

No branches or pull requests

5 participants