Remove 'model_.' prefix from onnx model initializers in training #3881
Conversation
```python
model = FuseSofmaxNLLToSoftmaxCE(model)
onnx_model = onnx.load_model_from_string(f.getvalue())

# Remove 'model_.' prefix introduced by model wrapper for initializers.
```
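For context, a minimal sketch of what stripping such a prefix from ONNX initializers could look like (the helper name and the exact renaming scope are assumptions for illustration, not this PR's actual implementation):

```python
import onnx

def strip_initializer_prefix(model: onnx.ModelProto, prefix: str = "model_.") -> None:
    """Rename initializers that start with `prefix`, in place, and patch
    node inputs and graph inputs that still reference the old names."""
    renamed = {}
    for init in model.graph.initializer:
        if init.name.startswith(prefix):
            new_name = init.name[len(prefix):]
            renamed[init.name] = new_name
            init.name = new_name
    # Node inputs still carry the old names; rewrite them.
    for node in model.graph.node:
        node.input[:] = [renamed.get(name, name) for name in node.input]
    # Graph inputs may also reference renamed initializers.
    for graph_input in model.graph.input:
        if graph_input.name in renamed:
            graph_input.name = renamed[graph_input.name]
```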
Is it necessary because of class WrapModel? Is there a way to prevent this "model_." from being added in the first place?
"model_" added due to the original model being an attribute "model_" of WrapModel
. To prevent this we need to avoid creating this extra layer of module. Maybe instead directly overwrite the forward method of original model, seems more hacky as the original forward might be called somewhere else..
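To illustrate the mechanism: PyTorch builds parameter names from module attribute paths, so an attribute named model_ prefixes everything beneath it. (WrapModel's real definition lives in the training code; this body is a simplified stand-in.)

```python
import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

class WrapModel(torch.nn.Module):  # simplified stand-in for the wrapper discussed above
    def __init__(self, model):
        super().__init__()
        self.model_ = model  # this attribute name becomes the 'model_.' prefix

    def forward(self, x):
        return self.model_(x)

print([name for name, _ in Inner().named_parameters()])
# ['fc.weight', 'fc.bias']
print([name for name, _ in WrapModel(Inner()).named_parameters()])
# ['model_.fc.weight', 'model_.fc.bias']
```

The exporter reuses these parameter names for ONNX initializers, which is how the prefix leaks into the exported model.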
It is good to know that initializer names are changed according to their position in the final model. So in the case where there is a loss_fn, initializer names will be model_.model_.xxx? Shall we also handle this case? Or we could merge the two wrappers into one.
Good point. We should handle the case where both WrapModel and loss_fn are used. Currently the case is not valid due to some issues with WrapModel. Will fix it in a separate PR.
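If the nested case becomes valid, the prefix removal could strip repeated segments rather than a single one. A hypothetical variant (not part of this PR):

```python
import re

def canonical_name(name: str) -> str:
    # Hypothetical: strip one or more leading 'model_.' segments, covering
    # the nested WrapModel + loss_fn case (model_.model_.xxx) discussed above.
    return re.sub(r"^(model_\.)+", "", name)

assert canonical_name("model_.fc.weight") == "fc.weight"
assert canonical_name("model_.model_.fc.weight") == "fc.weight"
assert canonical_name("fc.weight") == "fc.weight"
```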
Force-pushed from 7243b80 to fd688ad
Force-pushed from d820c62 to b12a081
Force-pushed from ae2ef5f to 32beb88
onnxruntime/test/testdata/ckpt_mnist.pt is being added to the repo, but it is not used in the PR
onnxruntime/test/testdata/ckpt_mnist.pt is being updated, not newly added. It was used in an existing mnist checkpoint test case. After this PR the saved initializer names change, so the previous ckpt file needs to be updated.
After a previous change that introduced a wrapper module for ONNX export, the ONNX model initializer names all carry an inserted 'model_.', diverging from the parameter names in the original PyTorch model. This makes it hard for users to figure out weight names, especially when using frozen_weights. This PR removes this prefix and adds a warning message when weight names in frozen_weights are not found in the model.
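A sketch of the warning described above (function name and structure are assumptions; the actual check lives in the training frontend):

```python
import warnings
import onnx

def warn_on_unknown_frozen_weights(frozen_weights, onnx_model: onnx.ModelProto) -> None:
    # Warn for each requested frozen weight that does not match any
    # initializer name in the exported model.
    initializer_names = {init.name for init in onnx_model.graph.initializer}
    for name in frozen_weights:
        if name not in initializer_names:
            warnings.warn(
                f"'{name}' in frozen_weights is not found among model initializers."
            )
```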