Handling mixed precision for dreambooth flux lora training #9565
Conversation
Thank you! Just a single comment.
@linoytsaban could you also give this a look?
- pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline = pipeline.to(accelerator.device)
Why are we doing this?
We should keep the pipeline's model-level components (such as the text encoders, VAE, etc.) in reduced precision, no?
The text encoders and VAE are already in reduced precision :)
As I described in the PR description, this line changes the dtype of the transformer.
For mixed-precision training with fp16, the transformer is upcast to fp32.
But this call casts it back to fp16, which leads to an fp16 unscale error during gradient clipping.
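For illustration, here is a minimal self-contained sketch of that failure mode; the tiny Linear model is a stand-in assumption for the upcast transformer, not the actual training code:

import torch

# Stand-in for the transformer whose trainable (LoRA) params were upcast to fp32.
model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

# Simulate log_validation casting the whole pipeline (and thus the model) to fp16.
model.to(dtype=torch.float16)

with torch.autocast("cuda", dtype=torch.float16):
    loss = model(torch.randn(2, 4, device="cuda")).float().mean()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # raises "Attempting to unscale FP16 gradients."

Keeping the trainable params in fp32, i.e. moving the pipeline with .to(accelerator.device) only, avoids this.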
Would something like this work?
#9549 (comment)
Thank you for the suggestion! But in this thread I was concerned with the unwanted switch from fp32 to fp16 after validation, not with the T5 computation :)
Ah okay. Can you provide an example command for us to verify this? Maybe @linoytsaban could give it a try?
@icsl-Jeon a friendly reminder :)
This can be reproduced with any of the launch commands in the README, e.g.:

accelerate launch ... --mixed_precision="fp16" ...

I checked the LoRA precision with:

# Print dtype and requires_grad for every LoRA parameter of the transformer.
for name, param in transformer.named_parameters():
    if 'lora' in name:
        print(f"Layer: {name}, dtype: {param.dtype}, requires_grad: {param.requires_grad}")

Hope this helps you reproduce!
We could avoid it by running inference in autocast, no? Here's an example:
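For reference, a rough sketch of what that could look like in the validation step, assuming the training script's pipeline, accelerator, weight_dtype, validation_prompt, and generator variables (an illustrative sketch, not the exact example that was linked):

import torch

# Move the pipeline without touching module dtypes, then run inference under
# autocast so reduced-precision compute is still used. The variable names here
# are assumptions borrowed from the training-script context.
pipeline = pipeline.to(accelerator.device)
with torch.autocast(accelerator.device.type, dtype=weight_dtype):
    image = pipeline(prompt=validation_prompt, generator=generator).images[0]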
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM actually!
@linoytsaban could you also review?
@linoytsaban thank you in advance
Thanks @icsl-Jeon, LGTM!
Is there any action to be done for the merge?
Will merge this once the CI run is complete. Thanks a ton!
… dreambooth as well
* add latent caching + smol updates
* update license
* replace with free_memory
* add --upcast_before_saving to allow saving transformer weights in lower precision
* fix models to accumulate
* fix mixed precision issue as proposed in #9565
* smol update to readme
* style
* fix caching latents
* style
* add tests for latent caching
* style
* fix latent caching
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
… + small bug fix (#9646)
* make lora target modules configurable and change the default
* style
* make lora target modules configurable and change the default
* fix bug when using prodigy and training te
* fix mixed precision training as proposed in #9565 for full dreambooth as well
* add test and notes
* style
* address sayaks comments
* style
* fix test
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Thanks for your contributions!
Handling mixed precision and add unwrap
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
What does this PR do?
Hello 😄 Thank you for the awesome example!
Here, I want to make a PR with the changes that helped me train DreamBooth LoRA successfully. It addresses three things:
- dtype change of the transformer after log_validation (especially for fp16). For mixed-precision training, the original code upcasts the fp16 transformer to fp32. However, after switching the pipeline dtype in log_validation, the transformer dtype returns to fp16, which can lead to an fp16 unscaling error. I actually hit this problem when using the fp16 option. (For some reason, T5 yielded NaN output in bf16, which is why I came to use fp16.) See: train_dreambooth_lora_flux validation RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same #9476
- dtype of the text encoders for validation. For some reason, the two text encoders were in fp32 in the pipeline.
- unwrap to access the config field of the transformer (see the sketch below). See: train_dreambooth_lora_flux.py distributed bugs #9161 (comment)

Fixes # (issue)
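As a rough illustration of the unwrap point above, using accelerate's public unwrap_model API and assumed variable names from the training-script context (not necessarily the script's exact helper):

import torch

# After accelerator.prepare(), the transformer may be wrapped (e.g. by DDP),
# so config attributes should be read from the unwrapped module.
unwrapped = accelerator.unwrap_model(transformer)
guidance = None
if unwrapped.config.guidance_embeds:
    guidance = torch.full([1], args.guidance_scale, device=accelerator.device)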
Who can review?
@linoytsaban @sayakpaul
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.