
Support model in_chans not equal to pre-trained weights in_chans #2289

@adamjstewart

Description

Summary

If a user specifies both in_chans and weights, and weights.meta['in_chans'] differs from in_chans, the user-specified in_chans should take precedence and the pre-trained weights of the first convolution should be repeated across the channel dimension, similar to how timm handles pre-trained weights whose input channel count does not match the model's.
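
A minimal sketch of the intended precedence logic inside a model factory, assuming a hypothetical adapt_first_conv helper (sketched under Implementation below) and the conv1.weight key used by timm's ResNet; actual torchgeo code may differ:

# Hypothetical sketch, not existing torchgeo code: the user-specified in_chans wins,
# and the first-conv weights are adapted before loading.
state_dict = weights.get_state_dict(progress=True)
if in_chans != weights.meta['in_chans']:
    # adapt_first_conv is a hypothetical helper, sketched under Implementation below
    state_dict['conv1.weight'] = adapt_first_conv(state_dict['conv1.weight'], in_chans)
model.load_state_dict(state_dict)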

Rationale

When working on change detection, it is common to take two images and stack them along the channel dimension. However, the stacked input then has twice as many channels as our pre-trained weights expect, which makes it impossible to load them directly. Ideally, I would like to support something like:

from torchgeo.models import ResNet50_Weights, resnet50

model = resnet50(in_chans=4, weights=ResNet50_Weights.SENTINEL1_ALL_MOCO)

Here, the weights have 2 channels (HH and HV), while the dataset and model will have 4 channels (HH, HV, HH, HV).
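
For concreteness, stacking the two acquisitions channel-wise looks like this (dummy tensors, batch dimension omitted):

import torch

img_t1 = torch.rand(2, 224, 224)  # HH, HV at time 1
img_t2 = torch.rand(2, 224, 224)  # HH, HV at time 2
x = torch.cat([img_t1, img_t2], dim=0)  # shape (4, 224, 224): HH, HV, HH, HV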

Implementation

https://timm.fast.ai/models#Case-2:-When-the-number-of-input-channels-is-not-1 describes the implementation that timm uses. This can be imported as:

from timm.models.helpers import load_pretrained

We should make use of this helper in all of our model definitions instead of calling model.load_state_dict directly.
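
For reference, timm's approach boils down to tiling the pre-trained first-conv kernel across the requested number of input channels and rescaling it so activation magnitudes stay comparable. A self-contained sketch of that idea (not timm's actual code; the helper name is illustrative):

import math
import torch

def adapt_first_conv(weight: torch.Tensor, in_chans: int) -> torch.Tensor:
    # weight has shape (out_channels, pretrained_in_chans, kh, kw)
    out_channels, pretrained_in_chans, kh, kw = weight.shape
    if in_chans == pretrained_in_chans:
        return weight
    # Tile along the input-channel dimension and truncate to the requested size,
    # e.g. 2 pre-trained channels (HH, HV) -> 4 channels (HH, HV, HH, HV).
    repeats = math.ceil(in_chans / pretrained_in_chans)
    weight = weight.repeat(1, repeats, 1, 1)[:, :in_chans]
    # Rescale so the expected sum over input channels stays roughly the same.
    return weight * (pretrained_in_chans / in_chans)

Only the first convolution's key (conv1.weight for ResNet) needs to be adapted before the state dict is loaded.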

Alternatives

There is some ongoing work to add a ChangeDetectionTask that may split each image into a separate sample key. However, there will always be models that require images stacked along the channel dimension, so I don't think we can avoid supporting this use case.

Additional information

No response


Labels

good first issue (a good issue for a new contributor to work on) · models (models and pretrained weights) · trainers (PyTorch Lightning trainers)
