@@ -236,7 +236,7 @@ def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: O
        self.unfreeze_and_add_param_group(pl_module.layer[epoch + 1], optimizer)


-def test_base_finetuning_internal_state(tmpdir):
+def test_base_finetuning_internal_optimizer_metadata(tmpdir):
    """Test the param_groups updates are properly saved within the internal state of the BaseFinetuning Callbacks"""

    seed_everything(42)
@@ -265,18 +265,18 @@ def configure_optimizers(self):
    model = FreezeModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=5, limit_train_batches=1, callbacks=[cb, chk])
    trainer.fit(model)
-    assert len(cb._internal_state[0]) == 6
-    assert cb._internal_state[0][0]["params"] == ['layer.0.weight']
-    assert cb._internal_state[0][1]["params"] == ['layer.1.weight', 'layer.1.bias']
-    assert cb._internal_state[0][2]["params"] == ['layer.2.weight']
-    assert cb._internal_state[0][3]["params"] == ['layer.3.weight', 'layer.3.bias']
-    assert cb._internal_state[0][4]["params"] == ['layer.4.weight']
-    assert cb._internal_state[0][5]["params"] == ['layer.5.weight', 'layer.5.bias']
+    assert len(cb._internal_optimizer_metadata[0]) == 6
+    assert cb._internal_optimizer_metadata[0][0]["params"] == ['layer.0.weight']
+    assert cb._internal_optimizer_metadata[0][1]["params"] == ['layer.1.weight', 'layer.1.bias']
+    assert cb._internal_optimizer_metadata[0][2]["params"] == ['layer.2.weight']
+    assert cb._internal_optimizer_metadata[0][3]["params"] == ['layer.3.weight', 'layer.3.bias']
+    assert cb._internal_optimizer_metadata[0][4]["params"] == ['layer.4.weight']
+    assert cb._internal_optimizer_metadata[0][5]["params"] == ['layer.5.weight', 'layer.5.bias']

    model = FreezeModel()
    cb = OnEpochLayerFinetuning()
    trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb])
-    with pytest.raises(ValueError, match="loaded state dict has a different number of parameter groups"):
+    with pytest.raises(IndexError, match="index 6 is out of range"):
        trainer.fit(model)
@@ -365,3 +365,115 @@ def forward(self, x):
    # conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param
    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
    assert len(encoder_params) == 9
+
+
+class TestCallbacksRestoreCallback(BaseFinetuning):
+
+    def freeze_before_training(self, pl_module):
+        self.freeze(pl_module.layer[:3])
+
+    def finetune_function(self, pl_module, epoch, optimizer, opt_idx):
+        if epoch >= 1:
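+            # from epoch 1 on, unfreeze one previously frozen layer and add it to the optimizer as a new param group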
+            self.unfreeze_and_add_param_group(pl_module.layer[epoch - 1], optimizer)
+
+
+class FinetuningBoringModel(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 2))
+
+    def configure_optimizers(self):
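+        # freeze_before_training has already frozen layer[:3], so only layer.3 still requires grad here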
+        parameters = filter(lambda x: x.requires_grad, self.parameters())
+        optimizer = torch.optim.SGD(parameters, lr=0.1)
+        return optimizer
+
+
+def test_callbacks_restore(tmpdir):
+    """
+    Test that callbacks are restored after the optimizers have been re-created,
+    but before the optimizer states are reloaded.
+    """
+    chk = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+
+    model = FinetuningBoringModel()
+    callback = TestCallbacksRestoreCallback()
+
+    trainer_kwargs = dict(
+        default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=1, callbacks=[callback, chk], max_epochs=2
+    )
+
+    trainer = Trainer(**trainer_kwargs)
+    trainer.fit(model)
+
+    # only 1 optimizer
+    assert len(callback._internal_optimizer_metadata) == 1
+
+    # only 2 param groups
+    assert len(callback._internal_optimizer_metadata[0]) == 2
+
+    # original parameters
+    assert callback._internal_optimizer_metadata[0][0] == {
+        'lr': 0.1,
+        'momentum': 0,
+        'dampening': 0,
+        'weight_decay': 0,
+        'nesterov': False,
+        'params': ['layer.3.weight', 'layer.3.bias']
+    }
+
+    # new param group
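+    # lr is 0.1 / 10: unfreeze_and_add_param_group divides the base lr by initial_denom_lr (default 10)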
+    assert callback._internal_optimizer_metadata[0][1] == {
+        'lr': 0.01,
+        'momentum': 0,
+        'dampening': 0,
+        'weight_decay': 0,
+        'nesterov': False,
+        'params': ['layer.0.weight', 'layer.0.bias']
+    }
+
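+    # resuming from the checkpoint must re-create the extra param group before the optimizer state is reloaded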
+    trainer_kwargs["max_epochs"] = 3
+    trainer_kwargs["resume_from_checkpoint"] = chk.last_model_path
+
+    trainer = Trainer(**trainer_kwargs)
+    trainer.fit(model)
+
+
+def test_callbacks_restore_backbone(tmpdir):
+    """
+    Test that callbacks are restored after the optimizers have been re-created,
+    but before the optimizer states are reloaded.
+    """
+
+    class BackboneBoringModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.layer = nn.Linear(32, 2)
+            self.backbone = nn.Linear(32, 32)
+
+        def forward(self, x):
+            return self.layer(self.backbone(x))
+
+    ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        progress_bar_refresh_rate=0,
+        callbacks=[ckpt, BackboneFinetuning(unfreeze_backbone_at_epoch=1)]
+    )
+    trainer.fit(BackboneBoringModel())
+
+    # initialize a trainer that continues the previous training
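+    # BackboneFinetuning must restore its backbone param group before the checkpointed optimizer state is loaded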
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=3,
+        progress_bar_refresh_rate=0,
+        callbacks=BackboneFinetuning(unfreeze_backbone_at_epoch=1),
+        resume_from_checkpoint=ckpt.last_model_path
+    )
+    trainer.fit(BackboneBoringModel())