[FEATURE] Swin remove avgpooling #1029

Closed
germayneng opened this issue Dec 11, 2021 · 4 comments
Labels
enhancement New feature or request

Comments

@germayneng

germayneng commented Dec 11, 2021

Thanks for this awesome library.

I do have a question: I realized there is no way for me to remove the avg pooling layer. I have trained my model and wish to perform some form of analysis, so I want to remove the avg pooling as stated in the tutorial. I tried setting the global/avg pool argument to '' and it doesn't seem to work.
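For reference, this is roughly what I tried (a minimal sketch; the model name and keyword arguments are just how I understood the tutorial):

```python
import timm
import torch

# attempt to drop the classifier and the global avg pooling (assumed kwargs: num_classes / global_pool)
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True,
                          num_classes=0, global_pool='')

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))

# expected the unpooled feature map here, but the avg pooling still seems to be applied
print(out.shape)
```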

Also found a similar thread w.r.t. ViT:
#657

Is this behavior expected?

@germayneng germayneng added the enhancement New feature or request label Dec 11, 2021
@germayneng
Author

germayneng commented Dec 11, 2021

I am able to resolve this for now by using forward hooks: iterating through the children() modules and registering a hook on the module I am interested in.
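Roughly what the workaround looks like (a minimal sketch, assuming a timm Swin model whose final norm layer is named 'norm'):

```python
import timm
import torch

model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
model.eval()

features = {}

def save_output(module, inputs, output):
    # stash the module's output so it can be inspected after the forward pass
    features['unpooled'] = output

# iterate through the children and register the hook on the module of interest
for name, child in model.named_children():
    if name == 'norm':
        child.register_forward_hook(save_output)

with torch.no_grad():
    _ = model(torch.randn(1, 3, 224, 224))

print(features['unpooled'].shape)  # pre-pooling features
```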

@rwightman
Collaborator

@germayneng I'm working on getting this improved for an end-of-year release w/ the long-awaited PyPI update. It touches all transformer and MLP based models, so it's a bit of a task and requires lots of testing (it may break some downstream usage, but I feel it's worth it).

Aside from your approach, create_feature_extractor from recent torchvision can be used too; just pass in ['norm'] as the return_nodes arg... you'd want to register the leaf modules as per the fx_features usage, so follow the use here: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/fx_features.py#L68-L70 ... it skips some leaf modules that cannot be traced (this method only works with traceable models, but I have managed to make almost all timm models work).
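A minimal sketch of that route (assuming the final norm layer is named 'norm' and the model traces cleanly; the leaf-module registration from fx_features.py applies if tracing fails):

```python
import timm
import torch
from torchvision.models.feature_extraction import create_feature_extractor

model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
model.eval()

# return the output of the final norm layer, i.e. the features before pooling
extractor = create_feature_extractor(model, return_nodes=['norm'])

with torch.no_grad():
    out = extractor(torch.randn(1, 3, 224, 224))

print(out['norm'].shape)
```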

rwightman added a commit that referenced this issue Jan 27, 2022
, make pooling interface for transformers and mlp closer to convnets. Still working through some details...
@rwightman
Collaborator

@germayneng Took a while, but I've cleaned up forward_features for all ViT / MLP models so that the unpooled / non-token-selected output is consistently returned from forward_features. Pooling or token selection is done in a new forward_head method, and most models have an attribute that can also enable/disable/change the pooling like CNNs...
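A minimal sketch of how that looks from the user side (assuming a timm version that includes this change):

```python
import timm
import torch

model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
model.eval()

x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    feats = model.forward_features(x)   # unpooled / non-token-selected output
    logits = model.forward_head(feats)  # pooling + classifier applied here

print(feats.shape, logits.shape)

# pooling / classifier can also be disabled at creation time, like the CNNs:
backbone = timm.create_model('swin_base_patch4_window7_224', pretrained=True,
                             num_classes=0, global_pool='')
```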

@rwightman
Collaborator

I may tweak things a bit if any problems arise with the new approach, but it seems okay so far; I've been training / testing from the branch for a bit now, before the recent merge.
