Skip to content

Commit 206dbf9

Browse files
added ms weights adaptation
1 parent baa88b3 commit 206dbf9

File tree

2 files changed

+54
-15
lines changed

2 files changed

+54
-15
lines changed

detectree2/models/train.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import time
1414
from pathlib import Path
1515
from typing import Any, Dict, List, Optional
16+
from urllib.parse import urlparse
1617

1718
import cv2
1819
import detectron2.data.transforms as T # noqa:N812
@@ -373,6 +374,50 @@ def train(self):
373374
verify_results(self.cfg, self._last_eval_results)
374375
return self._last_eval_results
375376

377+
def resume_or_load(self, resume=True):
378+
"""
379+
If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
380+
a `last_checkpoint` file), resume from the file. Resuming means loading all
381+
available states (eg. optimizer and scheduler) and update iteration counter
382+
from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
383+
384+
Otherwise, this is considered as an independent training. The method will load model
385+
weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
386+
from iteration 0.
387+
388+
Args:
389+
resume (bool): whether to do resume or not
390+
"""
391+
self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
392+
if resume and self.checkpointer.has_checkpoint():
393+
# The checkpoint stores the training iteration that just finished, thus we start
394+
# at the next iteration
395+
self.start_iter = self.iter + 1
396+
397+
if self.cfg.MODEL.WEIGHTS:
398+
checkpoint = torch.tensor(
399+
self.checkpointer._load_file(
400+
self.checkpointer.path_manager.get_local_path(
401+
urlparse(self.cfg.MODEL.WEIGHTS)._replace(
402+
query="").geturl()))['model']['backbone.bottom_up.stem.conv1.weight']).to(
403+
self.model.backbone.bottom_up.stem.conv1.weight.device)
404+
input_channels_in_checkpoint = checkpoint.shape[1]
405+
input_channels_in_model = self.model.backbone.bottom_up.stem.conv1.weight.shape[1]
406+
if input_channels_in_checkpoint != input_channels_in_model:
407+
logger = logging.getLogger("detectree2")
408+
if input_channels_in_checkpoint != 3:
409+
logger.warning(
410+
"Input channel modification only works if checkpoint was trained on RGB images (3 channels). The first three channels will be copied and then repeated in the model."
411+
)
412+
logger.warning(
413+
"Mismatch in input channels in checkpoint and model, meaning fvcommon would not have been able to automatically load them. Adjusting weights for 'backbone.bottom_up.stem.conv1.weight' manually."
414+
)
415+
with torch.no_grad():
416+
self.model.backbone.bottom_up.stem.conv1.weight[:, :
417+
input_channels_in_checkpoint] = checkpoint[:, :
418+
input_channels_in_checkpoint]
419+
multiply_conv1_weights(self.model)
420+
376421
@classmethod
377422
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
378423
"""
@@ -964,7 +1009,7 @@ def predictions_on_data(
9641009
json.dump(evaluations, dest)
9651010

9661011

967-
def modify_conv1_weights(model, num_input_channels):
1012+
def multiply_conv1_weights(model):
9681013
"""
9691014
Modify the weights of the first convolutional layer (conv1) to accommodate a different number of input channels.
9701015
@@ -974,12 +1019,12 @@ def modify_conv1_weights(model, num_input_channels):
9741019
9751020
Args:
9761021
model (torch.nn.Module): The model containing the convolutional layer to modify.
977-
num_input_channels (int): The number of input channels for the new conv1 layer.
9781022
9791023
"""
9801024
with torch.no_grad():
9811025
# Retrieve the original weights of the conv1 layer
9821026
old_weights = model.backbone.bottom_up.stem.conv1.weight
1027+
num_input_channels = model.backbone.bottom_up.stem.conv1.weight.shape[1] # The number of input channels
9831028

9841029
# Create a new weight tensor with the desired number of input channels
9851030
# The shape is (out_channels, in_channels, height, width)

docs/source/tutorial.rst

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -399,21 +399,21 @@ The number of bands can be checked with rasterio:
399399
print(f'The raster has {num_bands} bands.')
400400
401401
402-
Due to the additional bands, we must modify the weights of the first convolutional layer (conv1) to accommodate a
403-
different number of input channels. This is done with the ``modify_conv1_weights`` function. The extension of the
404-
``cfg.MODEL.PIXEL_MEAN`` and ``cfg.MODEL.PIXEL_STD`` lists to include the additional bands happens within the
405-
``setup_cfg`` function when ``num_bands`` is set to a value greater than 3. ``imgmode`` should be set to ``"ms"`` to
406-
ensure the correct training routines are called.
402+
Due to the additional bands, the weights of the first convolutional layer (conv1) are modified to accommodate a
403+
different number of input channels. This is automatically done in the case of ``imgmode`` being set to ``"ms"``.
404+
The first three input weights are multiplied across the new bands. The extension of the ``cfg.MODEL.PIXEL_MEAN``
405+
and ``cfg.MODEL.PIXEL_STD`` lists to include the additional bands happens within the ``setup_cfg`` function when
406+
``num_bands`` is set to a value greater than 3. ``imgmode`` should be set to ``"ms"`` to ensure the correct training
407+
routines are called.
407408

408409
.. code-block:: python
409410
410411
from datetime import date
411-
from detectron2.modeling import build_model
412412
import torch.nn as nn
413413
import torch.nn.init as init
414414
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
415415
import numpy as np
416-
from detectree2.models.train import modify_conv1_weights, MyTrainer, setup_cfg
416+
from detectree2.models.train import MyTrainer, setup_cfg
417417
418418
# Good idea to keep track of the date if producing multiple models
419419
today = date.today()
@@ -433,12 +433,6 @@ ensure the correct training routines are called.
433433
max_iter=500000, out_dir=out_dir, resize = "rand_fixed", imgmode="ms",
434434
num_bands= num_bands) # update_model arg can be used to load in trained model
435435
436-
# Build the model
437-
model = build_model(cfg)
438-
439-
# Adjust input layer to accept correct number of channels
440-
modify_conv1_weights(model, num_input_channels=num_bands)
441-
442436
443437
With additional bands, more data is being passed through the network per image so it may be neessary to reduce the
444438
number of images per batch. Only do this is you a getting warnings/errors about memory usage (e.g.

0 commit comments

Comments
 (0)