-
Notifications
You must be signed in to change notification settings - Fork 451
Support model in_chans not equal to pre-trained weights in_chans #2324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@microsoft-github-policy-service agree |
Why can't we use |
Because that function will only copy the weights to additional input channels if the first convolution layer of the weights has 3 input channels. Otherwise it raises a NotImplementedError exception and randomly initializes the weights. I linked to the timm implementation in my comment on issue #2289 . |
Will try to review next week, this week is quite busy. Apologies for the wait! |
I recently discovered |
I actually looked into Here's an example: from timm.models import adapt_input_conv
from torchgeo.models import ResNet18_Weights
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO.get_state_dict()
adapt_input_conv(26, weights['conv1.weight'])
There may be other methods in timm which we could use like you mentioned, or may need to add support. |
This PR addresses #2289.
I've made a good start on this and am creating this draft PR to get some feedback before continuing. I have created a load_pretrained function that will copy weights as done by timm, but which supports any number of in_chans of the weights, not just 3. So far I just implemented this new functionality in resnet50 and will expand that to other models if this is a good approach.
With this change, you can create a model like this, in this case the two channels will be copied into 4:
I created a new file torchgeo/models/utils.py where I put the load_pretrained function since I didn't see any suitable existing file. Is there a better place for it?