 from typing import Any, cast, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 import torch
-from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim.optimizer import Optimizer
 
+# https://github.com/pytorch/ignite/issues/2773
+try:
+    from torch.optim.lr_scheduler import LRScheduler as PyTorchLRScheduler
+except ImportError:
+    from torch.optim.lr_scheduler import _LRScheduler as PyTorchLRScheduler
+
 from ignite.engine import Engine
 
 
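For readers skimming the patch, here is a minimal, self-contained sketch (not part of the diff) of how the compatibility shim above plays out in user code: on PyTorch 2.0+ the public LRScheduler base class is resolved, on older releases the private _LRScheduler is used, and either one satisfies the isinstance checks updated in the hunks below. The toy model, optimizer settings, and the ignite.handlers.param_scheduler.LRScheduler import path are assumptions for illustration.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

try:  # PyTorch >= 2.0 exposes the public LRScheduler base class
    from torch.optim.lr_scheduler import LRScheduler as PyTorchLRScheduler
except ImportError:  # older PyTorch only ships the private _LRScheduler
    from torch.optim.lr_scheduler import _LRScheduler as PyTorchLRScheduler

from ignite.handlers.param_scheduler import LRScheduler  # ignite's wrapper handler

model = torch.nn.Linear(2, 1)  # illustrative toy model
optimizer = SGD(model.parameters(), lr=0.1)
step_lr = StepLR(optimizer, step_size=2, gamma=0.5)

# Any stock PyTorch scheduler passes this check on both old and new releases,
# which is what the isinstance() changes in the hunks below rely on.
assert isinstance(step_lr, PyTorchLRScheduler)

# Wrap the PyTorch scheduler so it can be attached to an ignite Engine,
# and preview the learning-rate trajectory with the patched classmethod.
handler = LRScheduler(step_lr)
print(LRScheduler.simulate_values(num_events=5, lr_scheduler=step_lr))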
@@ -838,14 +844,15 @@ def print_lr():
 
     def __init__(
         self,
-        lr_scheduler: _LRScheduler,
+        lr_scheduler: PyTorchLRScheduler,
         save_history: bool = False,
         use_legacy: bool = False,
     ):
 
-        if not isinstance(lr_scheduler, _LRScheduler):
+        if not isinstance(lr_scheduler, PyTorchLRScheduler):
             raise TypeError(
-                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
+                "Argument lr_scheduler should be a subclass of "
+                f"torch.optim.lr_scheduler.{PyTorchLRScheduler.__name__}, "
                 f"but given {type(lr_scheduler)}"
             )
 
@@ -882,7 +889,7 @@ def get_param(self) -> Union[float, List[float]]:
 
     @classmethod
     def simulate_values(  # type: ignore[override]
-        cls, num_events: int, lr_scheduler: _LRScheduler, **kwargs: Any
+        cls, num_events: int, lr_scheduler: PyTorchLRScheduler, **kwargs: Any
     ) -> List[List[int]]:
         """Method to simulate scheduled values during num_events events.
 
@@ -894,13 +901,14 @@ def simulate_values( # type: ignore[override]
             event_index, value
         """
 
-        if not isinstance(lr_scheduler, _LRScheduler):
+        if not isinstance(lr_scheduler, PyTorchLRScheduler):
             raise TypeError(
-                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
+                "Argument lr_scheduler should be a subclass of "
+                f"torch.optim.lr_scheduler.{PyTorchLRScheduler.__name__}, "
                 f"but given {type(lr_scheduler)}"
             )
 
-        # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
+        # This scheduler uses `torch.optim.lr_scheduler.LRScheduler` which
         # should be replicated in order to simulate LR values and
         # not perturb original scheduler.
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -926,7 +934,7 @@ def simulate_values( # type: ignore[override]
 
 
 def create_lr_scheduler_with_warmup(
-    lr_scheduler: Union[ParamScheduler, _LRScheduler],
+    lr_scheduler: Union[ParamScheduler, PyTorchLRScheduler],
     warmup_start_value: float,
     warmup_duration: int,
     warmup_end_value: Optional[float] = None,
@@ -995,10 +1003,11 @@ def print_lr():
 
     .. versionadded:: 0.4.5
     """
-    if not isinstance(lr_scheduler, (ParamScheduler, _LRScheduler)):
+    if not isinstance(lr_scheduler, (ParamScheduler, PyTorchLRScheduler)):
         raise TypeError(
-            "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler or "
-            f"ParamScheduler, but given {type(lr_scheduler)}"
+            "Argument lr_scheduler should be a subclass of "
+            f"torch.optim.lr_scheduler.{PyTorchLRScheduler.__name__} or ParamScheduler, "
+            f"but given {type(lr_scheduler)}"
         )
 
     if not isinstance(warmup_duration, numbers.Integral):
@@ -1018,7 +1027,7 @@ def print_lr():
 
         milestones_values = [(0, warmup_start_value), (warmup_duration - 1, param_group_warmup_end_value)]
 
-        if isinstance(lr_scheduler, _LRScheduler):
+        if isinstance(lr_scheduler, PyTorchLRScheduler):
             init_lr = param_group["lr"]
             if init_lr != param_group_warmup_end_value:
                 milestones_values.append((warmup_duration, init_lr))
@@ -1054,7 +1063,7 @@ def print_lr():
     schedulers = [
         warmup_scheduler,
         lr_scheduler,
-    ]  # type: List[Union[ParamScheduler, ParamGroupScheduler, _LRScheduler]]
+    ]  # type: List[Union[ParamScheduler, ParamGroupScheduler, PyTorchLRScheduler]]
     durations = [milestones_values[-1][0] + 1]
     combined_scheduler = ConcatScheduler(schedulers, durations=durations, save_history=save_history)
 
@@ -1381,7 +1390,9 @@ def load_state_dict(self, state_dict: Mapping) -> None:
             s.load_state_dict(sd)
 
     @classmethod
-    def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwargs: Any) -> List[List[int]]:
+    def simulate_values(
+        cls, num_events: int, schedulers: List[ParamScheduler], **kwargs: Any
+    ) -> List[List[Union[List[float], float, int]]]:
         """Method to simulate scheduled values during num_events events.
 
         Args:
@@ -1396,7 +1407,7 @@ def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwar
             corresponds to the simulated param of scheduler i at 'event_index'th event.
         """
 
-        # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
+        # This scheduler uses `torch.optim.lr_scheduler.LRScheduler` which
         # should be replicated in order to simulate LR values and
         # not perturb original scheduler.
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -1408,9 +1419,9 @@ def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwar
             torch.save(objs, cache_filepath.as_posix())
 
             values = []
-            scheduler = cls(schedulers=schedulers, **kwargs)  # type: ignore[arg-type]
+            scheduler = cls(schedulers=schedulers, **kwargs)
             for i in range(num_events):
-                params = [scheduler.get_param() for scheduler in schedulers]  # type: ignore[attr-defined]
+                params = [scheduler.get_param() for scheduler in schedulers]
                 values.append([i] + params)
                 scheduler(engine=None)
 
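Similarly, a short usage sketch (again not from the patch) for create_lr_scheduler_with_warmup, whose accepted scheduler types are broadened above; the model, learning rates, and warmup values are illustrative assumptions.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup

model = torch.nn.Linear(2, 1)  # illustrative toy model
optimizer = SGD(model.parameters(), lr=0.1)
torch_scheduler = ExponentialLR(optimizer, gamma=0.98)

# After the patch, any scheduler deriving from the public LRScheduler base
# class (or the legacy private one) passes the type check on this argument.
scheduler = create_lr_scheduler_with_warmup(
    torch_scheduler,
    warmup_start_value=0.0,
    warmup_end_value=0.1,
    warmup_duration=3,
)
# The helper returns a ConcatScheduler (warmup phase followed by the wrapped
# scheduler); attach it to an ignite Engine, for example:
#     trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)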