Feature Request: Load pre-trained weights (.pth files etc.) into a flux model #2164
Comments
Are you aware of https://fluxml.ai/Flux.jl/stable/saving/#Flux.loadmodel!? I'm not sure it's our place to be including functionality for converting between PyTorch and Flux model structures, given the lack of uniformity on the PyTorch side (excepting specific cases like Metalhead.jl where we have a known, limited set of models to map from).
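For reference, a minimal sketch of how `Flux.loadmodel!` is typically used (the layer sizes are placeholders, and `src` stands in for whatever pretrained model you have already deserialized):

```julia
using Flux

# Destination model whose parameters will be overwritten in place.
model = Chain(Dense(784 => 128, relu), Dense(128 => 10))

# `src` must have a matching structure; in practice it would be a
# previously trained model loaded from disk rather than a fresh one.
src = Chain(Dense(784 => 128, relu), Dense(128 => 10))

Flux.loadmodel!(model, src)   # copies src's weights and biases into model
```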
Yes, I believe we discussed this on Slack yesterday haha. That is what I plan on using for my implementation, as it is only for a single model. Just thought I'd post an issue here since I also saw it in several open threads online and in the PyTorch feature parity document in this repo.
I would say it's partially covered by the "We should expose the possibility to load pretrained weights" point under "PyTorch Extras" in #1431. As for more general solutions, were someone to come up with a general Dict -> nested struct transformation which works with most PyTorch models, we could consider depending/integrating/advertising it on the Flux side.
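As a rough illustration of the simplest case such a transformation would have to handle, here is a hypothetical helper (not an existing API) that copies a flat PyTorch-style state dict into an all-`Dense` `Chain`:

```julia
using Flux

# Hypothetical helper: copy a flat PyTorch-style state dict (keys like
# "0.weight", "0.bias") into a Flux Chain. Assumes every layer is a Dense
# whose arrays already match Flux's (out, in) layout; real checkpoints may
# need transposes, reshapes, or layer-name remapping on top of this.
function load_state_dict!(m::Chain, sd::AbstractDict)
    for (i, layer) in enumerate(m.layers)
        layer isa Dense || continue
        copyto!(layer.weight, sd["$(i - 1).weight"])  # PyTorch numbers layers from 0
        copyto!(layer.bias,   sd["$(i - 1).bias"])
    end
    return m
end
```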
I'll let you know if what I come up with is general enough. Flux should have all the same layer types & hyperparameters as PyTorch, correct (just with different names)?
Not necessarily, which is another reason this is difficult to generalize. In general we try to keep close to PyTorch if there's no good reason to diverge, but that's not a hard rule.
Some scripts for porting weights can be found in the Metalhead repo https://github.com/FluxML/Metalhead.jl/tree/master/scripts |
Motivation and description
I think it would be useful to load pre-trained weights from PyTorch or TensorFlow. I'm sure this has been discussed before (e.g., in the PyTorch feature parity doc), but I could not find an open issue on this.
Possible Implementation
My current workaround takes the .pth file and opens it with Pickle.jl. I am still working on parsing the resulting dictionary to create a Flux model. It would make my life much easier if there were an associated Flux function to just load pre-trained weights and evaluate a model.
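For anyone landing here in the meantime, a sketch of that workaround under some assumptions (the checkpoint path, key names, and target architecture are placeholders, and the exact Pickle.jl entry point for Torch files may differ between versions):

```julia
using Flux, Pickle

# Read the PyTorch checkpoint into a Dict mapping parameter names to arrays.
state = Pickle.Torch.THload("checkpoint.pth")

# Hand-built Flux model matching the original PyTorch architecture.
model = Chain(Dense(784 => 128, relu), Dense(128 => 10))

# Copy each entry onto the corresponding Flux field; depending on how the
# tensors were deserialized, a transpose/permutedims may be required.
copyto!(model[1].weight, state["fc1.weight"])
copyto!(model[1].bias,   state["fc1.bias"])
copyto!(model[2].weight, state["fc2.weight"])
copyto!(model[2].bias,   state["fc2.bias"])
```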