
Commit 03577ac

Type fix (#2616)
* Fixed all the type errors reported against the nightly version of PyTorch
* Skipped the mypy check for the stable PyTorch version
* Renamed a variable to avoid an assignment error in mypy
1 parent 7931dd9 commit 03577ac

File tree

* .github/workflows/unit-tests.yml
* ignite/handlers/lr_finder.py
* ignite/handlers/param_scheduler.py

3 files changed: +15, -15 lines

.github/workflows/unit-tests.yml (1 addition, 1 deletion)

@@ -100,7 +100,7 @@ jobs:
           bash ./tests/run_code_style.sh lint

       - name: Run Mypy
-        if: ${{ matrix.os == 'ubuntu-latest' }}
+        if: ${{ matrix.os == 'ubuntu-latest' && matrix.pytorch-channel == 'pytorch-nightly'}}
         run: |
           bash ./tests/run_code_style.sh mypy
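With this gate, mypy runs only on the nightly PyTorch channel, presumably because the nightly type annotations declare scheduler attributes such as `optimizer` and `last_epoch` that the stable stubs did not, which is what lets the `# type: ignore[attr-defined]` comments in the diffs below be removed. A minimal standalone sketch of that kind of attribute access (not part of the commit; assumes torch is installed):

```python
# Illustration only: attributes on a torch LR scheduler that older stubs did
# not declare, which forced "# type: ignore[attr-defined]" comments in ignite.
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(2, 1)
scheduler = StepLR(SGD(model.parameters(), lr=0.1), step_size=10)

# With nightly PyTorch these accesses type-check cleanly, so the ignore
# comments in lr_finder.py and param_scheduler.py could be dropped.
print(scheduler.optimizer)   # the wrapped optimizer
print(scheduler.last_epoch)  # step counter maintained by the scheduler
```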

ignite/handlers/lr_finder.py (5 additions, 5 deletions)

@@ -122,8 +122,8 @@ def _run(
         start_lr = optimizer.param_groups[0]["lr"]
         # Initialize the proper learning rate policy
         if step_mode.lower() == "exp":
-            start_lr = [start_lr] * len(optimizer.param_groups)  # type: ignore
-            self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, start_lr, end_lr, num_iter))
+            start_lr_list = [start_lr] * len(optimizer.param_groups)
+            self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, start_lr_list, end_lr, num_iter))
         else:
             self._lr_schedule = PiecewiseLinear(
                 optimizer, param_name="lr", milestones_values=[(0, start_lr), (num_iter, end_lr)]
@@ -487,7 +487,7 @@ class _ExponentialLR(_LRScheduler):

     """

-    def __init__(self, optimizer: Optimizer, start_lr: float, end_lr: float, num_iter: int, last_epoch: int = -1):
+    def __init__(self, optimizer: Optimizer, start_lr: List[float], end_lr: float, num_iter: int, last_epoch: int = -1):
         self.end_lr = end_lr
         self.num_iter = num_iter
         super(_ExponentialLR, self).__init__(optimizer, last_epoch)
@@ -496,6 +496,6 @@ def __init__(self, optimizer: Optimizer, start_lr: float, end_lr: float, num_ite
         self.base_lrs = start_lr

     def get_lr(self) -> List[float]:  # type: ignore
-        curr_iter = self.last_epoch + 1  # type: ignore[attr-defined]
+        curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
-        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]  # type: ignore[attr-defined]
+        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
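Two things happen in this diff: the list of per-group start values gets its own name, `start_lr_list`, since reassigning `start_lr` (inferred as a float) to a list is what triggered mypy's assignment error, and `_ExponentialLR` now annotates `start_lr` as `List[float]`, matching the geometric interpolation in `get_lr`. A standalone sketch of the rename pattern and of that schedule, with illustrative values not taken from the commit:

```python
from typing import List

# Rename pattern from the hunk above: giving the list a new name avoids
# reassigning a float-typed variable to a list, which mypy rejects.
start_lr: float = 1e-5
start_lr_list = [start_lr] * 2  # one entry per parameter group (2 groups assumed here)

def exponential_lrs(base_lrs: List[float], end_lr: float, num_iter: int, curr_iter: int) -> List[float]:
    # Same formula as _ExponentialLR.get_lr: each lr moves geometrically
    # from its base value toward end_lr as curr_iter approaches num_iter.
    r = curr_iter / num_iter
    return [base_lr * (end_lr / base_lr) ** r for base_lr in base_lrs]

print(exponential_lrs(start_lr_list, end_lr=1.0, num_iter=100, curr_iter=50))
# -> roughly [3.16e-03, 3.16e-03], halfway on a log scale between 1e-5 and 1.0
```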

ignite/handlers/param_scheduler.py (9 additions, 9 deletions)

@@ -851,7 +851,7 @@ def __init__(

         self.lr_scheduler = lr_scheduler
         super(LRScheduler, self).__init__(
-            optimizer=self.lr_scheduler.optimizer,  # type: ignore[attr-defined]
+            optimizer=self.lr_scheduler.optimizer,
             param_name="lr",
             save_history=save_history,
         )
@@ -861,13 +861,13 @@ def __init__(
                 "instead of Events.ITERATION_STARTED to make sure to use "
                 "the first lr value from the optimizer, otherwise it is will be skipped"
             )
-            self.lr_scheduler.last_epoch += 1  # type: ignore[attr-defined]
+            self.lr_scheduler.last_epoch += 1

         self._state_attrs += ["lr_scheduler"]

     def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:
         super(LRScheduler, self).__call__(engine, name)
-        self.lr_scheduler.last_epoch += 1  # type: ignore[attr-defined]
+        self.lr_scheduler.last_epoch += 1

     def get_param(self) -> Union[float, List[float]]:
         """Method to get current optimizer's parameter value"""
@@ -908,7 +908,7 @@ def simulate_values(  # type: ignore[override]
             cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
             obj = {
                 "lr_scheduler": lr_scheduler.state_dict(),
-                "optimizer": lr_scheduler.optimizer.state_dict(),  # type: ignore[attr-defined]
+                "optimizer": lr_scheduler.optimizer.state_dict(),
             }
             torch.save(obj, cache_filepath.as_posix())

@@ -921,7 +921,7 @@ def simulate_values(  # type: ignore[override]

             obj = torch.load(cache_filepath.as_posix())
             lr_scheduler.load_state_dict(obj["lr_scheduler"])
-            lr_scheduler.optimizer.load_state_dict(obj["optimizer"])  # type: ignore[attr-defined]
+            lr_scheduler.optimizer.load_state_dict(obj["optimizer"])

             return values

@@ -1403,7 +1403,7 @@ def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwar
             cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
             objs = {f"lr_scheduler_{i}": s.state_dict() for i, s in enumerate(schedulers)}
             # all schedulers should be related to the same optimizer
-            objs["optimizer"] = schedulers[0].optimizer.state_dict()  # type: ignore[attr-defined]
+            objs["optimizer"] = schedulers[0].optimizer.state_dict()

             torch.save(objs, cache_filepath.as_posix())

@@ -1417,7 +1417,7 @@ def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwar
             objs = torch.load(cache_filepath.as_posix())
             for i, s in enumerate(schedulers):
                 s.load_state_dict(objs[f"lr_scheduler_{i}"])
-                s.optimizer.load_state_dict(objs["optimizer"])  # type: ignore[attr-defined]
+                s.optimizer.load_state_dict(objs["optimizer"])

             return values
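Both `simulate_values` hunks follow the same pattern: scheduler and optimizer states are checkpointed to a temporary file, the schedule is run forward to collect values, and the states are restored so the simulation has no side effects. A rough standalone sketch of that save/simulate/restore round trip (the dummy model, optimizer, and scheduler here are assumptions for illustration, not ignite code):

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

# Dummy objects purely for illustration.
model = nn.Linear(2, 1)
optimizer = SGD(model.parameters(), lr=0.1)
lr_scheduler = ExponentialLR(optimizer, gamma=0.9)

with tempfile.TemporaryDirectory() as tmpdirname:
    cache_filepath = Path(tmpdirname) / "lr_scheduler_cache.pt"

    # Checkpoint both states before simulating, as the hunks above do.
    obj = {"lr_scheduler": lr_scheduler.state_dict(), "optimizer": optimizer.state_dict()}
    torch.save(obj, cache_filepath.as_posix())

    # Step the real scheduler forward to collect simulated lr values.
    values = []
    for i in range(5):
        values.append([i, optimizer.param_groups[0]["lr"]])
        optimizer.step()
        lr_scheduler.step()

    # Restore, so the caller's scheduler and optimizer are left untouched.
    obj = torch.load(cache_filepath.as_posix())
    lr_scheduler.load_state_dict(obj["lr_scheduler"])
    optimizer.load_state_dict(obj["optimizer"])

print(values)  # lr decays by 0.9 per step: 0.1, 0.09, 0.081, ...
```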

@@ -1561,8 +1561,8 @@ def get_param(self) -> Union[float, List[float]]:
     def _reduce_lr(self, epoch: int) -> None:
         for i, param_group in enumerate(self.optimizer_param_groups):
             old_lr = float(param_group["lr"])
-            new_lr = max(old_lr * self.scheduler.factor, self.scheduler.min_lrs[i])  # type: ignore[attr-defined]
-            if old_lr - new_lr > self.scheduler.eps:  # type: ignore[attr-defined]
+            new_lr = max(old_lr * self.scheduler.factor, self.scheduler.min_lrs[i])
+            if old_lr - new_lr > self.scheduler.eps:
                 param_group["lr"] = new_lr

     @classmethod
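The `_reduce_lr` hunk keeps the ReduceLROnPlateau-style update: scale each group's lr by `factor`, clamp it at the per-group minimum, and only write the new value back when the drop exceeds `eps`. A standalone sketch of that logic with illustrative values (not ignite defaults):

```python
from typing import Dict, List

def reduce_lr(param_groups: List[Dict[str, float]], factor: float, min_lrs: List[float], eps: float) -> None:
    # Same logic as _reduce_lr above: scale, clamp to the per-group floor,
    # and only apply the change when it is larger than eps.
    for i, param_group in enumerate(param_groups):
        old_lr = float(param_group["lr"])
        new_lr = max(old_lr * factor, min_lrs[i])
        if old_lr - new_lr > eps:
            param_group["lr"] = new_lr

# Illustrative values, not ignite defaults.
groups = [{"lr": 0.1}, {"lr": 1e-9}]
reduce_lr(groups, factor=0.5, min_lrs=[1e-6, 1e-6], eps=1e-8)
print(groups)  # [{'lr': 0.05}, {'lr': 1e-09}]  (second group already below its floor, left as is)
```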
