-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multi-pretrained weight support - initial API + ResNet50 #4610
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good to go!
meta={ | ||
**_common_meta, | ||
"recipe": "https://github.com/pytorch/vision/issues/3995", | ||
"acc@1": 80.352, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Beautiful!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got more still in the oven 🤞
}, | ||
) | ||
ImageNet1K_RefV2 = WeightEntry( | ||
url="https://download.pytorch.org/models/resnet50-tmp.pth", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the plan here, to re-upload at some point in the future? Also, how do we plan on keeping the names for the checkpoint files manageable, just rely on the sha256 to differentiate them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the plan here, to re-upload at some point in the future?
Yes indeed, I still got models being trained so I expect that the weights will change. Just wanted to add something here so that we can see how multiple weights work.
how do we plan on keeping the names for the checkpoint files manageable
Good point. This is why I didn't add the sha256 on the temporary model. I don't want to fill the bucket with mess. I expect there will be one final set of weights added here at the end of all training. Since we are on prototype, I consider I can change it at any time.
just rely on the sha256 to differentiate them?
I don't have a preference. We could introduce more descriptive names (perhaps using the same string as the enum name?) or just rely on sha256.
"acc@5": 92.862, | ||
}, | ||
) | ||
ImageNet1K_RefV2 = WeightEntry( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Naming will be important here. ImageNet1K_RefV2
sounds good for a v1, but we should have a webpage in the doc which will break this down nicely. Maybe something to keep in mind, an easy way to gather this information automatically to facilitate generating the documentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes you are right. The plan is to change this once the recipe is finalized. We will need to update this, along with the URL of the recipe (currently pointing to the issue that I got open).
|
||
_common_meta = { | ||
"size": (224, 224), | ||
"categories": list(range(1000)), # TODO: torchvision.prototype.datasets.find("ImageNet").info.categories |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pmeier Let me know when you got the ImageNet category class so that I can replace it here.
return F.convert_image_dtype(img, self.dtype) | ||
|
||
|
||
class ImageNetEval: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
e9ff413
to
104a05c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just questions from me :)
if "pretrained" in kwargs: | ||
warnings.warn("The argument pretrained is deprecated, please use weights instead.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a future plan to allow users to get a pretrained model, without needing to manually instanciate a ResNet50Weights
weights object? E.g. something like resnet50(weights='pretrained')
would always produce the "default pretrained weights" (which could be e.g. the latest version of the weights, or something else)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To give more context to my question: this isn't just about convenience, but also regarding torchhub.
It'd be cool if we could still load models from torchhub using just torch.load('pytorch/vision', 'resnet50', pretrained=SOMETHING)
where SOMETHING doesn't have to be a custom torchvision class like ResNet50Weights
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, definitely a feature we want to add. See here for a prototype of exactly what you said. I choose not to include it in this prototype to go for the absolute minimal implementation and give time to review the other RFC as a whole.
Concerning your comment to not require access to the Enum object, I think you are hinting TorchHub here. For this use-case if you pass the string name of the enum value it will build it for you. See this. Could you have one more look and let me know if this works for starters?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems we posted a reply at the same minute (again!). You are currently able to instantiate a model as follows as well:
model = P.models.resnet50(weights="ImageNet1K_RefV2")
Thus I believe on torchhub you will do:
torch.load('pytorch/vision', 'resnet50', weights="ImageNet1K_RefV2")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, that makes sense
) Summary: * Adding lightweight API for models. * Adding resnet50. * Fix preset * Add fake categories. * Fixing mypy. * Add string=>weight conversion support on Enums. * Temporarily hardcoding imagenet categories. * Minor refactoring. Reviewed By: fmassa Differential Revision: D31649970 fbshipit-source-id: b4908da7be972c0a19949e75d61f2051e785494c
* Adding lightweight API for models. * Adding resnet50. * Fix preset * Add fake categories. * Fixing mypy. * Add string=>weight conversion support on Enums. * Temporarily hardcoding imagenet categories. * Minor refactoring.
* Adding lightweight API for models. * Adding resnet50. * Fix preset * Add fake categories. * Fixing mypy. * Add string=>weight conversion support on Enums. * Temporarily hardcoding imagenet categories. * Minor refactoring.
Resolves #4671
Adds multi-pretrained weight support on the existing model builders of TorchVision.
Example usage:
cc @datumbox @pmeier @bjuncek