
[HiDream LoRA] optimizations + small updates #11381


Merged: 31 commits merged into huggingface:main on Apr 24, 2025

Conversation

@linoytsaban (Collaborator) commented Apr 22, 2025

some memory optimizations:

  • add pre-computation of prompt embeddings when custom prompts are used as well (see the sketch after this list)
  • add pre-computation of the validation prompt embeddings as well
  • add --skip_final_inference - allows running validation while skipping the final loading of the pipeline with the LoRA weights, to reduce memory requirements
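
Conceptually, the pre-computation works like the sketch below. This is a generic, self-contained illustration using a single CLIP text encoder; the actual script uses HiDream's own text encoders and helper functions.

```python
# Generic sketch of the pre-computation idea (not the script's exact code):
# encode the prompts once with the frozen text encoder, keep only the
# embeddings, and free the encoder before training starts.
import gc

import torch
from transformers import CLIPTextModel, CLIPTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)

prompts = ["a dog, yarn art style", "yoda, yarn art style"]  # instance + validation prompts
with torch.no_grad():
    tokens = tokenizer(prompts, padding="max_length", truncation=True, return_tensors="pt").to(device)
    prompt_embeds = text_encoder(**tokens).last_hidden_state.cpu()

# The text encoder is frozen, so it isn't needed once the embeddings are cached.
del text_encoder
gc.collect()
torch.cuda.empty_cache()
```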

other changes:

  • update the default trained layers
  • save the model card even if the model is not pushed to the Hub
  • remove the scheduler initialization from the code example - it's not necessary anymore (it's now in the base model's config); see the sketch after this list
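
For reference, the simplified inference example looks roughly like the sketch below: no scheduler is created explicitly because it's read from the base model's config. The LoRA repo id is a placeholder, and the separately loaded Llama encoder plus the tokenizer_4/text_encoder_4 argument names are assumptions about the HiDream pipeline, not taken from this PR.

```python
# Rough sketch only - the marked repo ids and argument names are assumptions.
import torch
from transformers import AutoTokenizer, LlamaForCausalLM
from diffusers import HiDreamImagePipeline

llama_repo = "meta-llama/Llama-3.1-8B-Instruct"  # assumed fourth text encoder
tokenizer_4 = AutoTokenizer.from_pretrained(llama_repo)
text_encoder_4 = LlamaForCausalLM.from_pretrained(llama_repo, torch_dtype=torch.bfloat16)

# No explicit scheduler initialization: the scheduler comes from the base model's config.
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)
pipe.load_lora_weights("your-username/hidream-yarn-art-lora")  # placeholder repo id
pipe.enable_model_cpu_offload()

image = pipe("yoda, yarn art style", num_inference_steps=50, guidance_scale=5.0).images[0]
image.save("yoda_yarn_art.png")
```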

todo:

  • update readme with better defaults

Yarn Art LoRA
(screenshot of sample outputs, 2025-04-23)

training config
import os
os.environ['MODEL_NAME'] = "HiDream-ai/HiDream-I1-Full"
os.environ['DATASET_NAME'] ="Norod78/Yarn-art-style"
os.environ['OUTPUT_DIR'] = "hidream-yarn-art-lora-v2-trainer"

!accelerate launch train_dreambooth_lora_hidream.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --dataset_name=$DATASET_NAME \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --lora_layers="to_k,to_q,to_v,to_out" \
  --instance_prompt="a dog, yarn art style" \
  --validation_prompt="yoda, yarn art style" \
  --caption_column="text" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --use_8bit_adam \
  --rank=16 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=200 \
  --max_train_steps=1000 \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

…sed as well

2. save model card even if model is not pushed to hub
2. remove scheduler initialization from code example - not necessary anymore (it's now in the base model's config)
3. add skip_final_inference - allows running with validation while skipping the final loading of the pipeline with the LoRA weights, to reduce memory requirements
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@linoytsaban (Collaborator, Author)

@sayakpaul now that I'm thinking about it, even when validation is enabled, since we're not optimizing the text encoders, can't we just pre-encode the validation prompt embeddings as well? and then we don't need to keep or load text encoders for validation at all and simply pass the embeddings to log_validation

@sayakpaul (Member)

> @sayakpaul now that I'm thinking about it, even when validation is enabled, since we're not optimizing the text encoders, can't we just pre-encode the validation prompt embeddings as well? and then we don't need to keep or load text encoders for validation at all and simply pass the embeddings to log_validation

We should. Let's do that.
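
The general pattern being discussed looks roughly like the snippet below. It uses a Stable Diffusion pipeline purely as a self-contained illustration of passing cached prompt_embeds instead of a prompt string; the HiDream script's log_validation and embedding helpers differ from this.

```python
# Illustration only: pre-encode the validation prompt, drop the text encoder,
# and run validation from the cached embeddings.
import gc

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

with torch.no_grad():
    prompt_embeds, negative_embeds = pipe.encode_prompt(
        "yoda, yarn art style",
        device="cuda",
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )

# The text encoder is frozen, so it can be dropped once the embeddings exist.
pipe.text_encoder = None
gc.collect()
torch.cuda.empty_cache()

image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds).images[0]
```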

@sayakpaul (Member) left a comment

Thanks for the optims. LMK if the comments make sense or if anything is unclear.

@linoytsaban (Collaborator, Author)

@bot /style

Contributor

Style fixes have been applied. View the workflow run here.

@sayakpaul (Member) left a comment

Thank you! Left some more comments. I think we can merge this today. Also would be good to see if we test this on a 40GB GPU.

@linoytsaban (Collaborator, Author) commented Apr 23, 2025

Thanks @sayakpaul!

I added prints using your snippet (here) - right before the caching & pre-computation of the prompt embeddings and right after deleting them.
Freeing the memory seems to work as expected, getting us to ~33 GB, but we hit ~60 GB earlier, when the text encoding pipeline is moved to the GPU (this is with resolution=1024).
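
The prints themselves are just torch's CUDA memory counters, roughly like this (a sketch, not necessarily the exact linked snippet):

```python
import torch

def print_cuda_memory(stage: str) -> None:
    # Same four numbers as in the logs below.
    gib = 1024 ** 3
    print(f"\n=== CUDA Memory Stats {stage} ===")
    print(f"Current allocated: {torch.cuda.memory_allocated() / gib:.2f} GB")
    print(f"Max allocated: {torch.cuda.max_memory_allocated() / gib:.2f} GB")
    print(f"Current reserved: {torch.cuda.memory_reserved() / gib:.2f} GB")
    print(f"Max reserved: {torch.cuda.max_memory_reserved() / gib:.2f} GB")

print_cuda_memory("before caching")
# ... cache the prompt embeddings, delete the text encoders ...
print_cuda_memory("after freeing")
```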
with offloading -

=== CUDA Memory Stats before caching ===
Current allocated: 32.24 GB
Max allocated: 32.24 GB
Current reserved: 58.29 GB
Max reserved: 58.29 GB

=== CUDA Memory Stats after caching ===
Current allocated: 58.41 GB
Max allocated: 58.41 GB
Current reserved: 60.84 GB
Max reserved: 60.84 GB

=== CUDA Memory Stats after freeing ===
Current allocated: 32.87 GB
Max allocated: 32.87 GB
Current reserved: 33.82 GB
Max reserved: 33.82 GB

without offloading & caching -

=== CUDA Memory Stats before caching ===
Current allocated: 57.77 GB
Max allocated: 57.77 GB
Current reserved: 58.50 GB
Max reserved: 58.50 GB

=== CUDA Memory Stats after caching ===
Current allocated: 58.35 GB
Max allocated: 58.35 GB
Current reserved: 59.11 GB
Max reserved: 59.11 GB

=== CUDA Memory Stats after freeing ===
Current allocated: 32.98 GB
Max allocated: 32.98 GB
Current reserved: 33.39 GB
Max reserved: 33.39 GB

…s only pre-encoded if custom prompts are provided, but should be pre-encoded either way)
@linoytsaban (Collaborator, Author)

@bot /style

Contributor

Style fixes have been applied. View the workflow run here.

@linoytsaban (Collaborator, Author)

@bot /style

Contributor

Style fixes have been applied. View the workflow run here.

@sayakpaul (Member) left a comment

Looks great! Thanks Linoy!

@@ -1140,7 +1131,7 @@ def main(args):
     if args.lora_layers is not None:
         target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
     else:
-        target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+        target_modules = ["to_k", "to_q", "to_v", "to_out"]
@sayakpaul (Member)

Perhaps we can add a comment explaining that including to_out will target all the expert layers.
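
For illustration, the suggested comment could be as simple as the following (wording is mine, not from the PR):

```python
# "to_out" (rather than "to_out.0") also targets all the expert layers.
target_modules = ["to_k", "to_q", "to_v", "to_out"]
```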

@linoytsaban (Collaborator, Author)

@bot /style

@linoytsaban merged commit edd7880 into huggingface:main on Apr 24, 2025
9 checks passed
@linoytsaban deleted the hidream-followup branch on April 28, 2025, 12:14