Commit 72ab1d0

Fixed torch LRScheduler issue and fixed CI (#2780)
* Fixed torch LRScheduler issue and fixed CI (Fixes #2773)
* Fixed mypy issues
1 parent 6b8ebca commit 72ab1d0
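
The core change, repeated in each touched module below, is a version-compatible import: newer PyTorch releases expose the scheduler base class publicly as torch.optim.lr_scheduler.LRScheduler, while older releases only provide the private _LRScheduler. A minimal sketch of the pattern as it appears in the diffs (PyTorchLRScheduler is the alias used throughout this commit):

    # Prefer the public base class name (newer PyTorch); fall back to the
    # private name on older releases.
    try:
        from torch.optim.lr_scheduler import LRScheduler as PyTorchLRScheduler
    except ImportError:
        from torch.optim.lr_scheduler import _LRScheduler as PyTorchLRScheduler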

6 files changed: +50 additions, -33 deletions

.github/workflows/unit-tests.yml

Lines changed: 3 additions & 1 deletion
@@ -96,7 +96,9 @@ jobs:
           bash ./tests/run_code_style.sh lint

       - name: Run Mypy
-        if: ${{ matrix.os == 'ubuntu-latest'}}
+        # https://github.com/pytorch/ignite/pull/2780
+        #
+        if: ${{ matrix.os == 'ubuntu-latest' && matrix.pytorch-channel == 'pytorch-nightly'}}
         run: |
           bash ./tests/run_code_style.sh mypy

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
@@ -342,6 +342,7 @@ def run(self):
     ("py:class", "torch.utils.data.sampler.BatchSampler"),
     ("py:class", "torch.cuda.amp.grad_scaler.GradScaler"),
     ("py:class", "torch.optim.lr_scheduler._LRScheduler"),
+    ("py:class", "torch.optim.lr_scheduler.LRScheduler"),
     ("py:class", "torch.utils.data.dataloader.DataLoader"),
 ]

ignite/contrib/engines/common.py

Lines changed: 11 additions & 6 deletions
@@ -5,10 +5,15 @@

 import torch
 import torch.nn as nn
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 from torch.utils.data.distributed import DistributedSampler

+# https://github.com/pytorch/ignite/issues/2773
+try:
+    from torch.optim.lr_scheduler import LRScheduler as PyTorchLRScheduler
+except ImportError:
+    from torch.optim.lr_scheduler import _LRScheduler as PyTorchLRScheduler
+
 import ignite.distributed as idist
 from ignite.contrib.handlers import (
     ClearMLLogger,
@@ -37,7 +42,7 @@ def setup_common_training_handlers(
     to_save: Optional[Mapping] = None,
     save_every_iters: int = 1000,
     output_path: Optional[str] = None,
-    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
+    lr_scheduler: Optional[Union[ParamScheduler, PyTorchLRScheduler]] = None,
     with_gpu_stats: bool = False,
     output_names: Optional[Iterable[str]] = None,
     with_pbars: bool = True,
@@ -140,7 +145,7 @@ def _setup_common_training_handlers(
     to_save: Optional[Mapping] = None,
     save_every_iters: int = 1000,
     output_path: Optional[str] = None,
-    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
+    lr_scheduler: Optional[Union[ParamScheduler, PyTorchLRScheduler]] = None,
     with_gpu_stats: bool = False,
     output_names: Optional[Iterable[str]] = None,
     with_pbars: bool = True,
@@ -160,9 +165,9 @@ def _setup_common_training_handlers(
     trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

     if lr_scheduler is not None:
-        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
+        if isinstance(lr_scheduler, PyTorchLRScheduler):
             trainer.add_event_handler(
-                Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
+                Events.ITERATION_COMPLETED, lambda engine: cast(PyTorchLRScheduler, lr_scheduler).step()
             )
         else:
             trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
@@ -226,7 +231,7 @@ def _setup_common_distrib_training_handlers(
     to_save: Optional[Mapping] = None,
     save_every_iters: int = 1000,
     output_path: Optional[str] = None,
-    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
+    lr_scheduler: Optional[Union[ParamScheduler, PyTorchLRScheduler]] = None,
     with_gpu_stats: bool = False,
     output_names: Optional[Iterable[str]] = None,
     with_pbars: bool = True,
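
A hedged usage sketch (not part of the diff; helper names like update and model are illustrative, the rest is public ignite API): with this change, setup_common_training_handlers accepts a plain torch scheduler on any supported PyTorch version and steps it once per iteration via Events.ITERATION_COMPLETED, as the isinstance branch above shows.

    import torch
    import torch.nn as nn
    from torch.optim.lr_scheduler import StepLR

    from ignite.contrib.engines.common import setup_common_training_handlers
    from ignite.engine import Engine

    model = nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.5)  # any native torch LR scheduler

    def update(engine, batch):
        optimizer.zero_grad()
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(update)
    # lr_scheduler may be an ignite ParamScheduler or, with this fix, a torch scheduler
    # regardless of whether the install names the base class LRScheduler or _LRScheduler.
    setup_common_training_handlers(trainer, lr_scheduler=scheduler, with_pbars=False)
    trainer.run([torch.rand(4, 2) for _ in range(8)], max_epochs=2)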

ignite/handlers/param_scheduler.py

Lines changed: 29 additions & 18 deletions
@@ -10,9 +10,15 @@
 from typing import Any, cast, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

 import torch
-from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim.optimizer import Optimizer

+# https://github.com/pytorch/ignite/issues/2773
+try:
+    from torch.optim.lr_scheduler import LRScheduler as PyTorchLRScheduler
+except ImportError:
+    from torch.optim.lr_scheduler import _LRScheduler as PyTorchLRScheduler
+
 from ignite.engine import Engine


@@ -838,14 +844,15 @@ def print_lr():

     def __init__(
         self,
-        lr_scheduler: _LRScheduler,
+        lr_scheduler: PyTorchLRScheduler,
         save_history: bool = False,
         use_legacy: bool = False,
     ):

-        if not isinstance(lr_scheduler, _LRScheduler):
+        if not isinstance(lr_scheduler, PyTorchLRScheduler):
             raise TypeError(
-                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
+                "Argument lr_scheduler should be a subclass of "
+                f"torch.optim.lr_scheduler.{PyTorchLRScheduler.__name__}, "
                 f"but given {type(lr_scheduler)}"
             )

@@ -882,7 +889,7 @@ def get_param(self) -> Union[float, List[float]]:

     @classmethod
     def simulate_values(  # type: ignore[override]
-        cls, num_events: int, lr_scheduler: _LRScheduler, **kwargs: Any
+        cls, num_events: int, lr_scheduler: PyTorchLRScheduler, **kwargs: Any
     ) -> List[List[int]]:
         """Method to simulate scheduled values during num_events events.

@@ -894,13 +901,14 @@ def simulate_values(  # type: ignore[override]
             event_index, value
         """

-        if not isinstance(lr_scheduler, _LRScheduler):
+        if not isinstance(lr_scheduler, PyTorchLRScheduler):
             raise TypeError(
-                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
+                "Argument lr_scheduler should be a subclass of "
+                f"torch.optim.lr_scheduler.{PyTorchLRScheduler.__name__}, "
                 f"but given {type(lr_scheduler)}"
             )

-        # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
+        # This scheduler uses `torch.optim.lr_scheduler.LRScheduler` which
         # should be replicated in order to simulate LR values and
         # not perturb original scheduler.
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -926,7 +934,7 @@ def simulate_values(  # type: ignore[override]


 def create_lr_scheduler_with_warmup(
-    lr_scheduler: Union[ParamScheduler, _LRScheduler],
+    lr_scheduler: Union[ParamScheduler, PyTorchLRScheduler],
     warmup_start_value: float,
     warmup_duration: int,
     warmup_end_value: Optional[float] = None,
@@ -995,10 +1003,11 @@ def print_lr():

     .. versionadded:: 0.4.5
     """
-    if not isinstance(lr_scheduler, (ParamScheduler, _LRScheduler)):
+    if not isinstance(lr_scheduler, (ParamScheduler, PyTorchLRScheduler)):
         raise TypeError(
-            "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler or "
-            f"ParamScheduler, but given {type(lr_scheduler)}"
+            "Argument lr_scheduler should be a subclass of "
+            f"torch.optim.lr_scheduler.{PyTorchLRScheduler.__name__} or ParamScheduler, "
+            f"but given {type(lr_scheduler)}"
         )

     if not isinstance(warmup_duration, numbers.Integral):
@@ -1018,7 +1027,7 @@ def print_lr():

     milestones_values = [(0, warmup_start_value), (warmup_duration - 1, param_group_warmup_end_value)]

-    if isinstance(lr_scheduler, _LRScheduler):
+    if isinstance(lr_scheduler, PyTorchLRScheduler):
         init_lr = param_group["lr"]
         if init_lr != param_group_warmup_end_value:
             milestones_values.append((warmup_duration, init_lr))
@@ -1054,7 +1063,7 @@ def print_lr():
     schedulers = [
         warmup_scheduler,
         lr_scheduler,
-    ]  # type: List[Union[ParamScheduler, ParamGroupScheduler, _LRScheduler]]
+    ]  # type: List[Union[ParamScheduler, ParamGroupScheduler, PyTorchLRScheduler]]
     durations = [milestones_values[-1][0] + 1]
     combined_scheduler = ConcatScheduler(schedulers, durations=durations, save_history=save_history)

@@ -1381,7 +1390,9 @@ def load_state_dict(self, state_dict: Mapping) -> None:
             s.load_state_dict(sd)

     @classmethod
-    def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwargs: Any) -> List[List[int]]:
+    def simulate_values(
+        cls, num_events: int, schedulers: List[ParamScheduler], **kwargs: Any
+    ) -> List[List[Union[List[float], float, int]]]:
         """Method to simulate scheduled values during num_events events.

         Args:
@@ -1396,7 +1407,7 @@ def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwar
                 corresponds to the simulated param of scheduler i at 'event_index'th event.
         """

-        # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
+        # This scheduler uses `torch.optim.lr_scheduler.LRScheduler` which
         # should be replicated in order to simulate LR values and
         # not perturb original scheduler.
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -1408,9 +1419,9 @@ def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwar
             torch.save(objs, cache_filepath.as_posix())

             values = []
-            scheduler = cls(schedulers=schedulers, **kwargs)  # type: ignore[arg-type]
+            scheduler = cls(schedulers=schedulers, **kwargs)
             for i in range(num_events):
-                params = [scheduler.get_param() for scheduler in schedulers]  # type: ignore[attr-defined]
+                params = [scheduler.get_param() for scheduler in schedulers]
                 values.append([i] + params)
                 scheduler(engine=None)
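
Another hedged sketch (illustrative, not from the commit; param is a placeholder tensor, the rest is public ignite API): the LRScheduler wrapper and create_lr_scheduler_with_warmup now accept whichever base class the installed PyTorch provides, and the TypeError message is built from the detected class name.

    import torch
    from torch.optim.lr_scheduler import StepLR

    from ignite.handlers.param_scheduler import LRScheduler, create_lr_scheduler_with_warmup

    param = torch.zeros(2, requires_grad=True)
    optimizer = torch.optim.SGD([param], lr=0.1)

    # Wrap a native torch scheduler so ignite can drive it as an event handler.
    wrapped = LRScheduler(StepLR(optimizer, step_size=10, gamma=0.5))

    # Prepend a linear warmup phase; lr_scheduler may be a ParamScheduler or a torch scheduler.
    with_warmup = create_lr_scheduler_with_warmup(
        StepLR(optimizer, step_size=10, gamma=0.5),
        warmup_start_value=0.0,
        warmup_duration=5,
    )

    # Anything else still raises the (now version-agnostic) TypeError checked in the tests below.
    try:
        LRScheduler(123)
    except TypeError as err:
        print(err)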

tests/ignite/contrib/engines/test_common.py

Lines changed: 3 additions & 2 deletions
@@ -148,8 +148,9 @@ def test_asserts_setup_common_training_handlers():
     train_sampler = MagicMock(spec=DistributedSampler)
     setup_common_training_handlers(trainer, train_sampler=train_sampler)

-    with pytest.raises(RuntimeError, match=r"This contrib module requires available GPU"):
-        setup_common_training_handlers(trainer, with_gpu_stats=True)
+    if not torch.cuda.is_available():
+        with pytest.raises(RuntimeError, match=r"This contrib module requires available GPU"):
+            setup_common_training_handlers(trainer, with_gpu_stats=True)

     with pytest.raises(TypeError, match=r"Unhandled type of update_function's output."):
         trainer = Engine(lambda e, b: None)

tests/ignite/handlers/test_param_scheduler.py

Lines changed: 3 additions & 6 deletions
@@ -621,14 +621,11 @@ def save_lr(engine):

 def test_lr_scheduler_asserts():

-    with pytest.raises(
-        TypeError, match=r"Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler"
-    ):
+    err_msg = r"Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler.(_LRScheduler|LRScheduler)"
+    with pytest.raises(TypeError, match=err_msg):
         LRScheduler(123)

-    with pytest.raises(
-        TypeError, match=r"Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler"
-    ):
+    with pytest.raises(TypeError, match=err_msg):
         LRScheduler.simulate_values(1, None)
