Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wait_for_first_optimizer_step mode #737

Merged
merged 4 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pytorch_pfn_extras/training/extensions/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class LRScheduler(extension.Extension):
stepper (callable): Function that performs the step on
the scheduler.
trigger: Frequency to call this extension.
wait_for_first_optimizer_step (bool): Wait until optimizer.step is called
before invoking scheduler.step. This can address the issue where
optimizer.step is not called from the first iteration when using GradScaler.
linshokaku marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
Expand All @@ -56,14 +59,25 @@ def __init__(
*,
stepper: Any = _default_stepper,
trigger: trigger_module.TriggerLike = (1, "epoch"),
wait_for_first_optimizer_step: bool = False,
is_async: bool = True,
) -> None:
self.scheduler = scheduler
self.trigger = trigger_module.get_trigger(trigger)
self.stepper = stepper
self.wait_for_first_optimizer_step = wait_for_first_optimizer_step
self._stepped = False
linshokaku marked this conversation as resolved.
Show resolved Hide resolved
self.is_async = is_async

def __call__(self, manager: ExtensionsManagerProtocol) -> None:
if not self._stepped:
linshokaku marked this conversation as resolved.
Show resolved Hide resolved
if (
self.wait_for_first_optimizer_step
and self.scheduler.optimizer._step_count < 1
):
return
self._stepped = True

self.stepper(manager, self.scheduler)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tempfile
from unittest.mock import MagicMock

import pytest
import pytorch_pfn_extras as ppe
Expand Down Expand Up @@ -96,3 +97,34 @@ def test_reduce_lr_on_plateau_no_report():
with pytest.raises(ValueError):
with manager.run_iteration():
pass


def test_lr_scheduler_wait_for_first_optimizer_step():
param = torch.nn.Parameter(torch.zeros(10))
optim = torch.optim.SGD([param], 1.0)
sched = torch.optim.lr_scheduler.MultiStepLR(
optim, milestones=[1, 2, 3], gamma=0.1, last_epoch=-1
)
stepper = MagicMock()
ext = ppe.training.extensions.LRScheduler(
sched,
stepper=stepper,
wait_for_first_optimizer_step=True,
trigger=(1, "iteration"),
)
manager = ppe.training.ExtensionsManager(
{}, {"main": optim}, 1, extensions=[ext], iters_per_epoch=40
)
for i in range(4):
with manager.run_iteration():
pass
assert stepper.call_count == 0
for i in range(4):
with manager.run_iteration(step_optimizers=["main"]):
pass
assert stepper.call_count == 4
for i in range(4):
with manager.run_iteration():
pass

assert stepper.call_count == 8
Loading