Patch ORTTrainer's compatibility with DeepSpeed #148
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Hi, I saw that this PR is linked to issue #145, which I opened. I do not think this resolves the issue, which came from … EDIT: I apologize if this is still a work in progress and you had a fix in the works. Please ignore this comment if so. I just wanted to make sure I provided enough context for the issue.
Hi @jambayk, no worries. I opened this PR so that there would be more transparency on the progress. Sorry for any confusion it may have caused, and thanks a lot for the information on the issue.
The compatibility of …
Thanks for adding these warnings @JingyaHuang! LGTM 🚀!
optimum/onnxruntime/trainer.py
Outdated
```diff
@@ -773,6 +780,13 @@ def evaluation_loop_ort(
         )

+        logger.info("[INFO] Exporting the model to ONNX...")
+        if args.deepspeed and args.fp16:
+            warnings.warn(
```
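For readability, here is a hedged sketch of how the truncated guard above might read as a standalone helper; the warning text and the `Args` stand-in are assumptions for illustration, not the exact code merged in this PR:

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    """Stand-in for the trainer's TrainingArguments (assumption)."""

    deepspeed: Optional[str] = None
    fp16: bool = False


def warn_if_deepspeed_fp16(args: Args) -> None:
    # Mirrors the guard in the diff above; the message wording is assumed.
    if args.deepspeed and args.fp16:
        warnings.warn(
            "Exporting to ONNX while training with DeepSpeed fp16; export on "
            "CUDA is not yet supported by the pinned transformers version."
        )


warn_if_deepspeed_fp16(Args(deepspeed="ds_config.json", fp16=True))
```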
I wonder if we should check the `transformers` version and then raise a warning if the detected version doesn't match the required one for ONNX export on CUDA? For now, this warning is OK, but we'll probably want to revisit this once we bump `transformers` with your PR :)
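Concretely, such a check might look like the sketch below; `REQUIRED_TRANSFORMERS` is a placeholder, since (as noted in the reply) the release that would ship CUDA export support was not yet known:

```python
import warnings

import transformers
from packaging import version

# Placeholder minimum version; the transformers release containing ONNX
# export on CUDA was not yet known when this thread was written.
REQUIRED_TRANSFORMERS = "4.20.0"


def supports_cuda_onnx_export() -> bool:
    """Return True if the installed transformers is new enough, else warn."""
    if version.parse(transformers.__version__) < version.parse(REQUIRED_TRANSFORMERS):
        warnings.warn(
            f"ONNX export on CUDA requires transformers>={REQUIRED_TRANSFORMERS}, "
            f"but {transformers.__version__} is installed; falling back to CPU export."
        )
        return False
    return True
```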
@lewtun Yes, exactly. I put it this way since I am not sure which version of transformers will include it. I will check the transformers version once I have that information.
Sounds good!
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
… into ort-trainer-ds
PR #17183 to support ONNX export on CUDA. Will refactor the code once …
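For reference, here is a hedged sketch of what ONNX export on CUDA could look like through `transformers.onnx.export` (the public wrapper around `export_pytorch`), assuming that PR adds a `device` argument; the checkpoint and opset are illustrative choices:

```python
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.bert import BertOnnxConfig
from transformers.onnx import export

checkpoint = "bert-base-uncased"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

onnx_config = BertOnnxConfig(model.config)

# device="cuda" assumes the interface added by the PR referenced above.
export(tokenizer, model, onnx_config, opset=13, output=Path("model.onnx"), device="cuda")
```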
Merge this to enable DeepSpeed with …
What does this PR do?
Fixes #145 and #146
Contents
- `deepspeed`
- `fairscale` (simple ✅ dp2/dp3 ❌)
- `transformers.onnx.export_pytorch` (…)
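As a usage-level sketch of what this PR enables, something like the following should run under the DeepSpeed launcher; the checkpoint, dataset slice, and `ds_config.json` path are assumptions, and `ORTTrainer` accepts the standard `Trainer` arguments:

```python
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
)

from optimum.onnxruntime import ORTTrainer

checkpoint = "distilbert-base-uncased"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

# Small slice of SST-2 just to exercise the training loop (assumption).
dataset = load_dataset("glue", "sst2", split="train[:128]")
dataset = dataset.map(lambda ex: tokenizer(ex["sentence"], truncation=True), batched=True)

args = TrainingArguments(
    output_dir="ort_ds_out",
    fp16=True,
    deepspeed="ds_config.json",  # assumed path to a DeepSpeed ZeRO config
)

trainer = ORTTrainer(model=model, args=args, train_dataset=dataset, tokenizer=tokenizer)
trainer.train()
```

Launch with, e.g., `deepspeed train.py` so that DeepSpeed initializes the distributed state before training starts.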