13
13
import time
14
14
from pathlib import Path
15
15
from typing import Any , Dict , List , Optional
16
+ from urllib .parse import urlparse
16
17
17
18
import cv2
18
19
import detectron2 .data .transforms as T # noqa:N812
@@ -373,6 +374,50 @@ def train(self):
373
374
verify_results (self .cfg , self ._last_eval_results )
374
375
return self ._last_eval_results
375
376
377
    def resume_or_load(self, resume=True):
        """Resume training from the last checkpoint, or start fresh from ``cfg.MODEL.WEIGHTS``.

        If ``resume == True`` and ``cfg.OUTPUT_DIR`` contains the last checkpoint (defined by
        a ``last_checkpoint`` file), resume from that file. Resuming means loading all
        available states (e.g. optimizer and scheduler) and updating the iteration counter
        from the checkpoint; ``cfg.MODEL.WEIGHTS`` will not be used.

        Otherwise, this is considered an independent training run. The method will load model
        weights from the file ``cfg.MODEL.WEIGHTS`` (but will not load other states) and start
        from iteration 0.

        Additionally, when the checkpoint's first conv layer
        (``backbone.bottom_up.stem.conv1``) has a different number of input channels than the
        model's (e.g. a 3-channel RGB checkpoint loaded into a multispectral model), its
        weights — which fvcore's checkpointer skips on shape mismatch — are copied in
        manually and then expanded via ``multiply_conv1_weights``.

        Args:
            resume (bool): whether to do resume or not
        """
        # Standard detectron2/fvcore behaviour: loads the last checkpoint when resuming,
        # otherwise loads cfg.MODEL.WEIGHTS (model weights only).
        self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        if resume and self.checkpointer.has_checkpoint():
            # The checkpoint stores the training iteration that just finished, thus we start
            # at the next iteration.
            self.start_iter = self.iter + 1

        if self.cfg.MODEL.WEIGHTS:
            # Re-read the raw checkpoint file to inspect the stem conv1 weights directly.
            # urlparse(...)._replace(query="") strips any URL query string (e.g. auth
            # tokens) before resolving to a local path.
            # NOTE(review): torch.tensor(...) on an existing tensor makes a copy and may
            # emit a UserWarning; presumably intentional here to detach from the loaded
            # state dict — confirm.
            checkpoint = torch.tensor(
                self.checkpointer._load_file(
                    self.checkpointer.path_manager.get_local_path(
                        urlparse(self.cfg.MODEL.WEIGHTS)._replace(
                            query="").geturl()))['model']['backbone.bottom_up.stem.conv1.weight']).to(
                self.model.backbone.bottom_up.stem.conv1.weight.device)
            # Conv weight layout is (out_channels, in_channels, h, w); index 1 is the
            # number of input channels.
            input_channels_in_checkpoint = checkpoint.shape[1]
            input_channels_in_model = self.model.backbone.bottom_up.stem.conv1.weight.shape[1]
            if input_channels_in_checkpoint != input_channels_in_model:
                logger = logging.getLogger("detectree2")
                if input_channels_in_checkpoint != 3:
                    logger.warning(
                        "Input channel modification only works if checkpoint was trained on RGB images (3 channels). The first three channels will be copied and then repeated in the model."
                    )
                logger.warning(
                    "Mismatch in input channels in checkpoint and model, meaning fvcommon would not have been able to automatically load them. Adjusting weights for 'backbone.bottom_up.stem.conv1.weight' manually."
                )
                # Copy the checkpoint's channels into the leading slots of the model's
                # conv1 weight; no_grad so this manual surgery is not tracked by autograd.
                with torch.no_grad():
                    self.model.backbone.bottom_up.stem.conv1.weight[:, :
                                                                    input_channels_in_checkpoint] = checkpoint[:, :
                                                                                                               input_channels_in_checkpoint]
                # Fill the remaining input channels by repeating/scaling the copied ones
                # (see multiply_conv1_weights elsewhere in this file).
                multiply_conv1_weights(self.model)
376
421
@classmethod
377
422
def build_evaluator (cls , cfg , dataset_name , output_folder = None ):
378
423
"""
@@ -964,7 +1009,7 @@ def predictions_on_data(
964
1009
json .dump (evaluations , dest )
965
1010
966
1011
967
- def modify_conv1_weights (model , num_input_channels ):
1012
+ def multiply_conv1_weights (model ):
968
1013
"""
969
1014
Modify the weights of the first convolutional layer (conv1) to accommodate a different number of input channels.
970
1015
@@ -974,12 +1019,12 @@ def modify_conv1_weights(model, num_input_channels):
974
1019
975
1020
Args:
976
1021
model (torch.nn.Module): The model containing the convolutional layer to modify.
977
- num_input_channels (int): The number of input channels for the new conv1 layer.
978
1022
979
1023
"""
980
1024
with torch .no_grad ():
981
1025
# Retrieve the original weights of the conv1 layer
982
1026
old_weights = model .backbone .bottom_up .stem .conv1 .weight
1027
+ num_input_channels = model .backbone .bottom_up .stem .conv1 .weight .shape [1 ] # The number of input channels
983
1028
984
1029
# Create a new weight tensor with the desired number of input channels
985
1030
# The shape is (out_channels, in_channels, height, width)
0 commit comments