
Support model in_chans not equal to pre-trained weights in_chans #2324


Draft: keves1 wants to merge 2 commits into main
Conversation

keves1 (Contributor) commented on Sep 27, 2024

This PR addresses #2289.

I've made a good start on this and am creating this draft PR to get some feedback before continuing. I have created a load_pretrained function that copies weights the way timm does, but supports any number of input channels (in_chans) in the pretrained weights, not just 3. So far I have only implemented this new functionality in resnet50, and will expand it to other models if this is a good approach.

With this change, you can create a model like this; in this case the weights' two channels will be copied to fill four input channels:

model = resnet50(weights=ResNet50_Weights.SENTINEL1_ALL_MOCO, in_chans=4)

I created a new file torchgeo/models/utils.py where I put the load_pretrained function since I didn't see any suitable existing file. Is there a better place for it?
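To make the approach concrete, here is a minimal, illustrative sketch of the kind of channel-adapting copy described above. It generalizes timm's tiling strategy (which requires the source weights to have exactly 3 input channels) to any source channel count; the helper name adapt_first_conv and its signature are hypothetical and not necessarily what the PR's load_pretrained ends up doing.

import math

import torch

def adapt_first_conv(conv_weight: torch.Tensor, in_chans: int) -> torch.Tensor:
    """Adapt a pretrained first-conv weight of shape (O, I, H, W) to in_chans inputs.

    Tile the existing I channels until at least in_chans are available, slice,
    and rescale so the expected magnitude of the conv output stays roughly the same.
    """
    out_ch, src_ch, kh, kw = conv_weight.shape
    if in_chans == src_ch:
        return conv_weight
    if in_chans == 1:
        # Collapse to a single channel by summing, as timm does for grayscale inputs.
        return conv_weight.sum(dim=1, keepdim=True)
    repeat = math.ceil(in_chans / src_ch)
    weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
    return weight * (src_ch / in_chans)

For the SENTINEL1_ALL_MOCO example above (2 input channels, in_chans=4), the two source channels would be tiled to four and rescaled by 2/4.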

github-actions bot added the models (Models and pretrained weights) label on Sep 27, 2024
keves1 (Contributor, Author) commented on Sep 27, 2024

@microsoft-github-policy-service agree

adamjstewart (Collaborator) commented

> I have created a load_pretrained function that will copy weights as done by timm, but which supports any number of in_chans of the weights, not just 3.

Why can't we use timm.models.helpers.load_pretrained instead of writing our own custom code?

adamjstewart modified the milestones: 0.7.0, 0.6.1 on Sep 28, 2024
keves1 (Contributor, Author) commented on Sep 28, 2024

> Why can't we use timm.models.helpers.load_pretrained instead of writing our own custom code?

Because that function will only copy the weights to additional input channels if the first convolution layer of the weights has 3 input channels; otherwise it raises a NotImplementedError and randomly initializes the weights. I linked to the timm implementation in my comment on issue #2289.

adamjstewart (Collaborator) commented

Will try to review next week; this week is quite busy. Apologies for the wait!

adamjstewart modified the milestones: 0.6.1, 0.6.2 on Oct 10, 2024
adamjstewart modified the milestones: 0.6.2, 0.6.3 on Dec 5, 2024
adamjstewart (Collaborator) commented

I recently discovered timm.models.adapt_input_conv, which makes it easy to change the number of channels without losing pretrained weights. See #2602 for an example. There seems to be a lot of other built-in functionality in timm we might be able to make use of. If it doesn't support everything we need, I would be willing to try adding support upstream in timm.
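For reference, adapt_input_conv does handle the case where the pretrained weights have 3 input channels (its import path and signature are visible in the traceback further down). A small self-contained illustration, where the random tensor is just a stand-in for RGB-pretrained conv1 weights:

import torch
from timm.models import adapt_input_conv

# Stand-in for an RGB-pretrained first conv: 64 filters, 3 input channels, 7x7 kernels.
rgb_conv1 = torch.randn(64, 3, 7, 7)

# Tiles the 3 source channels out to 5 and rescales them; this path exists only when I == 3.
adapted = adapt_input_conv(5, rgb_conv1)
print(adapted.shape)  # torch.Size([64, 5, 7, 7])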

keves1 (Contributor, Author) commented on Feb 24, 2025

I actually looked into timm.models.adapt_input_conv (this is what is called by timm.models.helpers.load_pretrained) and it only works if the first convolution layer of the weights has 3 input channels (see my comment above). So we couldn't use adapt_input_conv with weights trained on multispectral or other non-3-channel imagery.

Here's an example:

from timm.models import adapt_input_conv
from torchgeo.models import ResNet18_Weights

weights = ResNet18_Weights.SENTINEL2_ALL_MOCO.get_state_dict()
adapt_input_conv(26, weights['conv1.weight'])
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[17], line 1
----> 1 adapt_input_conv(26, weights['conv1.weight'])

File ~/miniconda3/envs/torchgeo-env/lib/python3.11/site-packages/timm/models/_manipulate.py:270, in adapt_input_conv(in_chans, conv_weight)
    268 elif in_chans != 3:
    269     if I != 3:
--> 270         raise NotImplementedError('Weight format not supported by conversion.')
    271     else:
    272         # NOTE this strategy should be better than random init, but there could be other combinations of
    273         # the original RGB input layer weights that'd work better for specific cases.
    274         repeat = int(math.ceil(in_chans / 3))

NotImplementedError: Weight format not supported by conversion.

There may be other methods in timm we could use, as you mentioned, or we may need to add support upstream.
I'm focusing on getting a PR ready for the Autoregression trainer, so I won't be working on this more in the short term.

adamjstewart removed this from the 0.6.3 milestone on Mar 13, 2025