Skip to content

Commit 07635d0

Browse files
eladsegal, tchaton, Borda, awaelchli, and pre-commit-ci[bot]
authored
fix restoring finetune callbacks after accelerator setup on training resume (#8501)
Co-authored-by: tchaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent abbcfa1 commit 07635d0

File tree

4 files changed

+176
-32
lines changed

4 files changed

+176
-32
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
547547
- Fixed a `TypeError` when wrapping optimizers in the `HorovodPlugin` and running `Trainer.test` ([#7840](https://github.com/PyTorchLightning/pytorch-lightning/pull/7840))
548548

549549

550+
- Fixed `BackboneFinetuning` restoration ([#8501](https://github.com/PyTorchLightning/pytorch-lightning/pull/8501))
551+
552+
553+
550554
## [1.3.8] - 2021-07-01
551555

552556
### Fixed

pytorch_lightning/callbacks/finetuning.py

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -82,25 +82,33 @@ def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
8282
"""
8383

8484
def __init__(self):
85-
self._internal_state: Dict[int, List[Dict[str, Any]]] = {}
85+
self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {}
86+
self._restarting = False
8687

8788
def on_save_checkpoint(
8889
self,
8990
trainer: 'pl.Trainer',
9091
pl_module: 'pl.LightningModule',
9192
checkpoint: Dict[str, Any],
9293
) -> Dict[int, List[Dict[str, Any]]]:
93-
return self._internal_state
94+
return self._internal_optimizer_metadata
9495

9596
def on_load_checkpoint(
9697
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', callback_state: Dict[int, List[Dict[str, Any]]]
9798
) -> None:
98-
self._internal_state = callback_state
99+
self._restarting = True
100+
self._internal_optimizer_metadata = callback_state
101+
102+
def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
99103
# restore the param_groups created during the previous training.
100-
named_parameters = dict(pl_module.named_parameters())
101-
for opt_idx, optimizer in enumerate(trainer.optimizers):
102-
param_groups = self.__apply_mapping_to_param_groups(self._internal_state[opt_idx], named_parameters)
103-
optimizer.param_groups = param_groups
104+
if self._restarting:
105+
named_parameters = dict(pl_module.named_parameters())
106+
for opt_idx, optimizer in enumerate(trainer.optimizers):
107+
param_groups = self.__apply_mapping_to_param_groups(
108+
self._internal_optimizer_metadata[opt_idx], named_parameters
109+
)
110+
optimizer.param_groups = param_groups
111+
self._restarting = False
104112

105113
@staticmethod
106114
def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
@@ -278,11 +286,13 @@ def _store(
278286
current_param_groups: List[Dict[str, Any]],
279287
) -> None:
280288
mapping = {p: n for n, p in pl_module.named_parameters()}
281-
if opt_idx not in self._internal_state:
282-
self._internal_state[opt_idx] = self.__apply_mapping_to_param_groups(current_param_groups, mapping)
289+
if opt_idx not in self._internal_optimizer_metadata:
290+
self._internal_optimizer_metadata[opt_idx] = self.__apply_mapping_to_param_groups(
291+
current_param_groups, mapping
292+
)
283293
elif num_param_groups != len(current_param_groups):
284294
# save new param_groups possibly created by the users.
285-
self._internal_state[opt_idx].extend(
295+
self._internal_optimizer_metadata[opt_idx].extend(
286296
self.__apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
287297
)
288298

@@ -362,15 +372,33 @@ def __init__(
362372
):
363373
super().__init__()
364374

365-
self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch
366-
self.backbone_initial_lr = backbone_initial_lr
367-
self.lambda_func = lambda_func
368-
self.backbone_initial_ratio_lr = backbone_initial_ratio_lr
369-
self.should_align = should_align
370-
self.initial_denom_lr = initial_denom_lr
371-
self.train_bn = train_bn
372-
self.round = round
373-
self.verbose = verbose
375+
self.unfreeze_backbone_at_epoch: int = unfreeze_backbone_at_epoch
376+
self.lambda_func: Callable = lambda_func
377+
self.backbone_initial_ratio_lr: float = backbone_initial_ratio_lr
378+
self.backbone_initial_lr: Optional[float] = backbone_initial_lr
379+
self.should_align: bool = should_align
380+
self.initial_denom_lr: float = initial_denom_lr
381+
self.train_bn: bool = train_bn
382+
self.verbose: bool = verbose
383+
self.round: int = round
384+
self.previous_backbone_lr: Optional[float] = None
385+
386+
def on_save_checkpoint(
387+
self,
388+
trainer: 'pl.Trainer',
389+
pl_module: 'pl.LightningModule',
390+
checkpoint: Dict[str, Any],
391+
) -> Dict[int, Any]:
392+
return {
393+
"internal_optimizer_metadata": self._internal_optimizer_metadata,
394+
"previous_backbone_lr": self.previous_backbone_lr
395+
}
396+
397+
def on_load_checkpoint(
398+
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', callback_state: Dict[int, List[Dict[str, Any]]]
399+
) -> None:
400+
self.previous_backbone_lr = callback_state["previous_backbone_lr"]
401+
super().on_load_checkpoint(trainer, pl_module, callback_state["internal_optimizer_metadata"])
374402

375403
def on_fit_start(self, trainer, pl_module):
376404
"""
@@ -379,7 +407,7 @@ def on_fit_start(self, trainer, pl_module):
379407
If LightningModule has no nn.Module `backbone` attribute.
380408
"""
381409
if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
382-
return
410+
return super().on_fit_start(trainer, pl_module)
383411
raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute")
384412

385413
def freeze_before_training(self, pl_module: 'pl.LightningModule'):

tests/callbacks/test_finetuning_callback.py

Lines changed: 121 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: O
236236
self.unfreeze_and_add_param_group(pl_module.layer[epoch + 1], optimizer)
237237

238238

239-
def test_base_finetuning_internal_state(tmpdir):
239+
def test_base_finetuning_internal_optimizer_metadata(tmpdir):
240240
"""Test the param_groups updates are properly saved within the internal state of the BaseFinetuning Callbacks"""
241241

242242
seed_everything(42)
@@ -265,18 +265,18 @@ def configure_optimizers(self):
265265
model = FreezeModel()
266266
trainer = Trainer(default_root_dir=tmpdir, max_epochs=5, limit_train_batches=1, callbacks=[cb, chk])
267267
trainer.fit(model)
268-
assert len(cb._internal_state[0]) == 6
269-
assert cb._internal_state[0][0]["params"] == ['layer.0.weight']
270-
assert cb._internal_state[0][1]["params"] == ['layer.1.weight', 'layer.1.bias']
271-
assert cb._internal_state[0][2]["params"] == ['layer.2.weight']
272-
assert cb._internal_state[0][3]["params"] == ['layer.3.weight', 'layer.3.bias']
273-
assert cb._internal_state[0][4]["params"] == ['layer.4.weight']
274-
assert cb._internal_state[0][5]["params"] == ['layer.5.weight', 'layer.5.bias']
268+
assert len(cb._internal_optimizer_metadata[0]) == 6
269+
assert cb._internal_optimizer_metadata[0][0]["params"] == ['layer.0.weight']
270+
assert cb._internal_optimizer_metadata[0][1]["params"] == ['layer.1.weight', 'layer.1.bias']
271+
assert cb._internal_optimizer_metadata[0][2]["params"] == ['layer.2.weight']
272+
assert cb._internal_optimizer_metadata[0][3]["params"] == ['layer.3.weight', 'layer.3.bias']
273+
assert cb._internal_optimizer_metadata[0][4]["params"] == ['layer.4.weight']
274+
assert cb._internal_optimizer_metadata[0][5]["params"] == ['layer.5.weight', 'layer.5.bias']
275275

276276
model = FreezeModel()
277277
cb = OnEpochLayerFinetuning()
278278
trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb])
279-
with pytest.raises(ValueError, match="loaded state dict has a different number of parameter groups"):
279+
with pytest.raises(IndexError, match="index 6 is out of range"):
280280
trainer.fit(model)
281281

282282

@@ -365,3 +365,115 @@ def forward(self, x):
365365
# conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param
366366
# conv1.weight, conv1.bias, bn1.weight, bn1.bias
367367
assert len(encoder_params) == 9
368+
369+
370+
class TestCallbacksRestoreCallback(BaseFinetuning):
371+
372+
def freeze_before_training(self, pl_module):
373+
self.freeze(pl_module.layer[:3])
374+
375+
def finetune_function(self, pl_module, epoch, optimizer, opt_idx):
376+
if epoch >= 1:
377+
self.unfreeze_and_add_param_group(pl_module.layer[epoch - 1], optimizer)
378+
379+
380+
class FinetuningBoringModel(BoringModel):
381+
382+
def __init__(self):
383+
super().__init__()
384+
self.layer = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 2))
385+
386+
def configure_optimizers(self):
387+
parameters = filter(lambda x: x.requires_grad, self.parameters())
388+
optimizer = torch.optim.SGD(parameters, lr=0.1)
389+
return optimizer
390+
391+
392+
def test_callbacks_restore(tmpdir):
393+
"""
394+
Test callbacks restore is called after optimizers have been re-created
395+
but before optimizer states reload
396+
"""
397+
chk = ModelCheckpoint(dirpath=tmpdir, save_last=True)
398+
399+
model = FinetuningBoringModel()
400+
callback = TestCallbacksRestoreCallback()
401+
402+
trainer_kwargs = dict(
403+
default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=1, callbacks=[callback, chk], max_epochs=2
404+
)
405+
406+
trainer = Trainer(**trainer_kwargs)
407+
trainer.fit(model)
408+
409+
# only 1 optimizer
410+
assert len(callback._internal_optimizer_metadata) == 1
411+
412+
# only 2 param groups
413+
assert len(callback._internal_optimizer_metadata[0]) == 2
414+
415+
# original parameters
416+
assert callback._internal_optimizer_metadata[0][0] == {
417+
'lr': 0.1,
418+
'momentum': 0,
419+
'dampening': 0,
420+
'weight_decay': 0,
421+
'nesterov': False,
422+
'params': ['layer.3.weight', 'layer.3.bias']
423+
}
424+
425+
# new param group
426+
assert callback._internal_optimizer_metadata[0][1] == {
427+
'lr': 0.01,
428+
'momentum': 0,
429+
'dampening': 0,
430+
'weight_decay': 0,
431+
'nesterov': False,
432+
'params': ['layer.0.weight', 'layer.0.bias']
433+
}
434+
435+
trainer_kwargs["max_epochs"] = 3
436+
trainer_kwargs["resume_from_checkpoint"] = chk.last_model_path
437+
438+
trainer = Trainer(**trainer_kwargs)
439+
trainer.fit(model)
440+
441+
442+
def test_callbacks_restore_backbone(tmpdir):
443+
"""
444+
Test callbacks restore is called after optimizers have been re-created
445+
but before optimizer states reload
446+
"""
447+
448+
class BackboneBoringModel(BoringModel):
449+
450+
def __init__(self):
451+
super().__init__()
452+
self.layer = nn.Linear(32, 2)
453+
self.backbone = nn.Linear(32, 32)
454+
455+
def forward(self, x):
456+
return self.layer(self.backbone(x))
457+
458+
ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True)
459+
trainer = Trainer(
460+
default_root_dir=tmpdir,
461+
limit_train_batches=1,
462+
limit_val_batches=1,
463+
max_epochs=2,
464+
progress_bar_refresh_rate=0,
465+
callbacks=[ckpt, BackboneFinetuning(unfreeze_backbone_at_epoch=1)]
466+
)
467+
trainer.fit(BackboneBoringModel())
468+
469+
# initialize a trainer that continues the previous training
470+
trainer = Trainer(
471+
default_root_dir=tmpdir,
472+
limit_train_batches=1,
473+
limit_val_batches=1,
474+
max_epochs=3,
475+
progress_bar_refresh_rate=0,
476+
callbacks=BackboneFinetuning(unfreeze_backbone_at_epoch=1),
477+
resume_from_checkpoint=ckpt.last_model_path
478+
)
479+
trainer.fit(BackboneBoringModel())

tests/checkpointing/test_trainer_checkpoint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,15 @@ def test_accumulated_gradient_batches_with_resume_from_checkpoint(tmpdir):
9191
This test validates that accumulated gradient is properly recomputed and reset on the trainer.
9292
"""
9393

94-
cb = ModelCheckpoint(dirpath=tmpdir, save_last=True)
94+
ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True)
9595
model = BoringModel()
9696
trainer_kwargs = dict(
97-
max_epochs=1, accumulate_grad_batches={0: 2}, callbacks=cb, limit_train_batches=1, limit_val_batches=0
97+
max_epochs=1, accumulate_grad_batches={0: 2}, callbacks=ckpt, limit_train_batches=1, limit_val_batches=0
9898
)
9999
trainer = Trainer(**trainer_kwargs)
100100
trainer.fit(model)
101101

102102
trainer_kwargs['max_epochs'] = 2
103-
trainer_kwargs['resume_from_checkpoint'] = cb.last_model_path
103+
trainer_kwargs['resume_from_checkpoint'] = ckpt.last_model_path
104104
trainer = Trainer(**trainer_kwargs)
105105
trainer.fit(model)

0 commit comments

Comments (0)