Skip to content

Commit

Permalink
Remove save_state_warning in LambdaLR (pytorch#46813)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch#46405, pytorch#43352

I updated the docstring in the local file (function level comments). Do I also need to edit somewhere else or recompile docstrings?

Also, though I didn't change any types here, how is typing (for IDE type checking) documentation generated / used)?

Pull Request resolved: pytorch#46813

Reviewed By: ezyang

Differential Revision: D24923112

Pulled By: vincentqb

fbshipit-source-id: be7818e0d4593bfc5d74023b9c361ac2a538589a
  • Loading branch information
jsrozner authored and facebook-github-bot committed Dec 4, 2020
1 parent 714c702 commit 42e6951
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torch/optim/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
"https://github.com/pytorch/pytorch/issues/new/choose."
)

SAVE_STATE_WARNING = "Please also save or load the state of the optimizer when saving or loading the scheduler."

class _LRScheduler(object):

def __init__(self, optimizer, last_epoch=-1, verbose=False):
Expand Down Expand Up @@ -211,9 +209,10 @@ def state_dict(self):
is not the optimizer.
The learning rate lambda functions will only be saved if they are callable objects
and not if they are functions or lambdas.
When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
"""

warnings.warn(SAVE_STATE_WARNING, UserWarning)
state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)

Expand All @@ -226,12 +225,13 @@ def state_dict(self):
def load_state_dict(self, state_dict):
"""Loads the schedulers state.
When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
Arguments:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""

warnings.warn(SAVE_STATE_WARNING, UserWarning)
lr_lambdas = state_dict.pop('lr_lambdas')
self.__dict__.update(state_dict)
# Restore state_dict keys in order to prevent side effects
Expand Down

0 comments on commit 42e6951

Please sign in to comment.