Summary
If a user specifies in_chans and weights, and weights.meta['in_chans'] differs from in_chans, the user-specified argument should take precedence and weights should be repeated, similar to how timm handles pre-trained weights.
Rationale
When working on change detection, it is common to take two images and stack them along the channel dimension. However, the stacked input then has twice as many channels as the pre-trained weights expect, so our pre-trained weights cannot currently be used. Ideally, I would like to support something like:
from torchgeo.models import ResNet50_Weights, resnet50
model = resnet50(in_chans=4, weights=ResNet50_Weights.SENTINEL1_ALL_MOCO)
Here, the weights have 2 channels (HH and HV), while the dataset and model will have 4 channels (HH, HV, HH, HV).
Implementation
https://timm.fast.ai/models#Case-2:-When-the-number-of-input-channels-is-not-1 describes the implementation that timm uses. This can be imported as:
from timm.models.helpers import load_pretrained
We should make use of this in all of our model definitions instead of model.load_state_dict.
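To make the requested behavior concrete, here is a rough sketch of repeating a first-layer convolution weight along the channel dimension, in the spirit of the approach described at the link above. It is not timm's code: the helper name repeat_input_channels is hypothetical, the two nn.Conv2d layers are stand-ins for a real pre-trained model, and the rescaling by pretrained_chans / in_chans is an assumption about what we would want here.

```python
import math

import torch
from torch import nn


def repeat_input_channels(weight: torch.Tensor, in_chans: int) -> torch.Tensor:
    """Tile a conv weight of shape (out, C, kH, kW) to (out, in_chans, kH, kW).

    Hypothetical helper for illustration only.
    """
    pretrained_chans = weight.shape[1]
    if pretrained_chans == in_chans:
        return weight.clone()
    repeats = math.ceil(in_chans / pretrained_chans)
    # Repeat along the channel axis, then trim to the requested number of channels.
    tiled = weight.repeat(1, repeats, 1, 1)[:, :in_chans]
    # Assumed rescaling so the summed response keeps a similar magnitude.
    return tiled * (pretrained_chans / in_chans)


# Stand-in for the first convolution of a model pre-trained on 2 channels.
pretrained_conv = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3, bias=False)
# First convolution of the model the user asked for with in_chans=4.
conv = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    conv.weight.copy_(repeat_input_channels(pretrained_conv.weight, in_chans=4))
```

With this scaling, an input built by stacking the same 2-channel image twice produces the same first-layer response as the original 2-channel model, which seems like a reasonable starting point for fine-tuning on stacked image pairs.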
Alternatives
There is some ongoing work to add a ChangeDetectionTask that may split each image into a separate sample key. However, there will always be models that require images stacked along the channel dimension, so I don't think we can avoid supporting this use case.
Additional information
No response