Remove 'model_.' prefix from onnx model initializers in training #3881

Merged — 8 commits merged into master on May 20, 2020

Conversation

BowenBao (Contributor) commented on May 8, 2020

After a previous change that introduced a wrapper module for ONNX export, the ONNX model initializer names are all prefixed with 'model_.', diverging from the parameter names in the original PyTorch model. This makes it hard for users to figure out weight names, especially when using frozen_weights.

This PR removes the prefix and adds a warning message when weight names in frozen_weights are not found in the model.
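For illustration, a minimal sketch of both changes, assuming hypothetical helper names and a plain onnx.ModelProto (this is not the actual ort_trainer.py implementation):

```python
# Sketch only: strip a wrapper-introduced prefix from ONNX initializer names
# and warn about frozen_weights entries that match no initializer.
# Helper names and call sites are assumptions, not the PR's code.
import warnings
import onnx


def strip_initializer_prefix(model: onnx.ModelProto, prefix: str = "model_.") -> onnx.ModelProto:
    renamed = {}
    for initializer in model.graph.initializer:
        if initializer.name.startswith(prefix):
            new_name = initializer.name[len(prefix):]
            renamed[initializer.name] = new_name
            initializer.name = new_name
    # Graph inputs can mirror initializer names, so rename them too.
    for graph_input in model.graph.input:
        if graph_input.name in renamed:
            graph_input.name = renamed[graph_input.name]
    # Patch every node input that referenced an old name.
    for node in model.graph.node:
        for i, input_name in enumerate(node.input):
            if input_name in renamed:
                node.input[i] = renamed[input_name]
    return model


def warn_missing_frozen_weights(model: onnx.ModelProto, frozen_weights) -> None:
    initializer_names = {init.name for init in model.graph.initializer}
    for name in frozen_weights:
        if name not in initializer_names:
            warnings.warn(f"Weight '{name}' in frozen_weights not found among model initializers.")
```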

BowenBao added the training label (issues related to ONNX Runtime training; typically submitted using template) on May 8, 2020
BowenBao requested a review from a team as a code owner on May 8, 2020 23:07
BowenBao changed the title from "Remove 'model_.' prefix for onnx model initializers in training" to "Remove 'model_.' prefix from onnx model initializers in training" on May 8, 2020
Review comments on orttraining/orttraining/python/ort_trainer.py (outdated, resolved) and onnxruntime/test/python/onnxruntime_test_ort_trainer.py (outdated, resolved)
model = FuseSofmaxNLLToSoftmaxCE(model)
onnx_model = onnx.load_model_from_string(f.getvalue())

# Remove 'model_.' prefix introduced by model wrapper for initializers.
Reviewer (Contributor):

Is this necessary because of the WrapModel class? Is there a way to prevent this 'model_.' from being added in the first place?

BowenBao (Author):

'model_' is added because the original model is stored as the 'model_' attribute of WrapModel. Preventing it would mean avoiding this extra layer of module, e.g. by directly overwriting the forward method of the original model, but that seems more hacky since the original forward might be called elsewhere.

Reviewer (Contributor):

Good to know that initializer names change according to their position in the final model. So when a loss_fn is present, will the initializer names become model_.model_.xxx? Should we also handle that case, or merge the two wrappers into one?

BowenBao (Author):

Good point. We should handle the case where both WrapModel and loss_fn are used. Currently that case is not valid due to some issues with WrapModel; it will be fixed in a separate PR.

BowenBao force-pushed the bowbao/remove_initializer_prefix branch from 7243b80 to fd688ad on May 12, 2020 20:38
BowenBao force-pushed the bowbao/remove_initializer_prefix branch from d820c62 to b12a081 on May 14, 2020 22:21
Review comment on orttraining/orttraining/python/pt_patch.py (outdated, resolved)
BowenBao force-pushed the bowbao/remove_initializer_prefix branch from ae2ef5f to 32beb88 on May 18, 2020 21:06
thiagocrepaldi (Contributor) left a comment:

onnxruntime/test/testdata/ckpt_mnist.pt is being added to the repo, but it is not used in the PR.

Review comments on orttraining/orttraining/python/pt_patch.py (outdated, resolved) and orttraining/orttraining/python/ort_trainer.py (resolved)
BowenBao (Author):

onnxruntime/test/testdata/ckpt_mnist.pt is being updated, not added. It was used in an existing MNIST checkpoint test case; since this PR changes the names under which initializers are saved, the previous ckpt file needs to be updated.
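As a hedged sketch of why the checkpoint has to be regenerated (the exact layout of ckpt_mnist.pt is an assumption here), keys saved under the old prefixed names would no longer match the renamed initializers and could be migrated like this:

```python
# Sketch only: migrate a checkpoint whose keys still carry the 'model_.' prefix.
# Assumes the checkpoint is a flat name-to-tensor dict, which may not match
# the real ckpt_mnist.pt format.
import torch

old_state = torch.load("ckpt_mnist.pt", map_location="cpu")
new_state = {
    (key[len("model_."):] if key.startswith("model_.") else key): value
    for key, value in old_state.items()
}
torch.save(new_state, "ckpt_mnist.pt")
```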

BowenBao merged commit 0a5395b into master on May 20, 2020
BowenBao deleted the bowbao/remove_initializer_prefix branch on May 20, 2020 17:06
Labels: training (issues related to ONNX Runtime training; typically submitted using template)
4 participants