[LoRA] Support original format loras for HunyuanVideo #10376
Conversation
@sayakpaul @DN6 Not comfortable adding integration tests with any of the existing original-format loras on the Hub/CivitAI for HunyuanVideo. They are either Anime/NSFW/person-related loras, which we usually don't use for testing. I'll work on training a style lora over the next few days and convert it to the original format for integration tests. Does that sound good?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Perfect! Okay to merge for me after an integration test!
Subscribing to this PR.
Hi, does Diffusers have a training script available for training HunyuanVideo LoRAs?
We do, thanks to @a-r-r-o-w 🐐
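For context, a minimal, hedged sketch of loading such a LoRA for inference once trained. The model id and prompt here are assumptions for illustration, not the PR's test setup; the LoRA repo is the Cseti checkpoint linked later in this thread:

```python
import torch
from diffusers import HunyuanVideoPipeline

# Assumed model id for illustration; use whichever HunyuanVideo weights you have.
pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
).to("cuda")

# Original-format LoRAs (what this PR adds support for) load through the
# same entry point as diffusers-format ones.
pipe.load_lora_weights("Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1")

video = pipe(prompt="a character walking through a neon city", num_frames=49).frames[0]
```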
Is this PR stalled just because there are no loras that are (very subjectively) deemed suitable for an integration test?
@vladmandic No, sorry, that's not it. I didn't find the time to get back to this yet. We have loras suitable for integration tests now, so I will complete the PR tomorrow.
Okay, so there's a problem with the loras we trained 🤣 We didn't train the … In the interest of time, I'm just going to use this checkpoint for the integration tests: https://huggingface.co/Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1/

Partial conversion script:

```python
import argparse
import os
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file


def convert_diffusers_to_hunyuan_video_lora(diffusers_state_dict):
    # Drain the input state dict into a working copy.
    converted_state_dict = {k: diffusers_state_dict.pop(k) for k in list(diffusers_state_dict.keys())}

    # Original-format key -> diffusers key; inverted below for the reverse conversion.
    TRANSFORMER_KEYS_RENAME_DICT = {
        "img_in": "x_embedder",
        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
        ".double_blocks": ".transformer_blocks",
        ".single_blocks": ".single_transformer_blocks",
        "img_attn_q_norm": "attn.norm_q",
        "img_attn_k_norm": "attn.norm_k",
        "img_attn_proj": "attn.to_out.0",
        "txt_attn_q_norm": "attn.norm_added_q",
        "txt_attn_k_norm": "attn.norm_added_k",
        "txt_attn_proj": "attn.to_add_out",
        "img_mod.linear": "norm1.linear",
        "img_norm1": "norm1.norm",
        "img_norm2": "norm2",
        "img_mlp": "ff",
        "txt_mod.linear": "norm1_context.linear",
        "txt_norm1": "norm1.norm",
        "txt_norm2": "norm2_context",
        "txt_mlp": "ff_context",
        "self_attn_proj": "attn.to_out.0",
        "modulation.linear": "norm.linear",
        "pre_norm": "norm.norm",
        "final_layer.norm_final": "norm_out.norm",
        "final_layer.linear": "proj_out",
        "fc1": "net.0.proj",
        "fc2": "net.2",
        "input_embedder": "proj_in",
        # txt_in
        "individual_token_refiner.blocks": "token_refiner.refiner_blocks",
        "adaLN_modulation.1": "norm_out.linear",
        "txt_in": "context_embedder",
        "t_embedder.mlp.0": "time_text_embed.timestep_embedder.linear_1",
        "t_embedder.mlp.2": "time_text_embed.timestep_embedder.linear_2",
        "c_embedder": "time_text_embed.text_embedder",
        "mlp": "ff",
    }
    TRANSFORMER_KEYS_RENAME_DICT_REVERSE = {v: k for k, v in TRANSFORMER_KEYS_RENAME_DICT.items()}

    # TRANSFORMER_SPECIAL_KEYS_REMAP_BACK = {
    #     "norm_out.linear": remap_norm_scale_shift_back_,
    #     "context_embedder": remap_txt_in_back_,
    #     "attn.to_q": remap_img_attn_qkv_back_,
    #     "attn.add_q_proj": remap_txt_attn_qkv_back_,
    #     "single_transformer_blocks": remap_single_transformer_blocks_back_,
    # }

    for key in list(converted_state_dict.keys()):
        if "norm_out.linear" in key:
            # diffusers stores (scale, shift); the original format stores (shift, scale).
            weight = converted_state_dict.pop(key)
            scale, shift = weight.chunk(2, dim=0)
            new_weight = torch.cat([shift, scale], dim=0)
            converted_state_dict[key] = new_weight
        if "to_q" in key:
            if "single_transformer_blocks" in key:
                # Single blocks fuse q/k/v and the MLP input projection into one layer.
                to_q = converted_state_dict.pop(key)
                to_k = converted_state_dict.pop(key.replace("to_q", "to_k"))
                to_v = converted_state_dict.pop(key.replace("to_q", "to_v"))
                to_out = converted_state_dict.pop(key.replace("attn.to_q", "proj_mlp"))
                rename_attn_key = "img_attn_qkv"
                if "lora_A" in key:
                    # The lora_A (down) matrix is assumed shared across q/k/v/mlp,
                    # so only to_q's copy is kept for the fused layer.
                    converted_state_dict[key.replace("attn.to_q", rename_attn_key)] = to_q
                else:
                    qkv_mlp = torch.cat([to_q, to_k, to_v, to_out], dim=0)
                    converted_state_dict[key.replace("attn.to_q", rename_attn_key)] = qkv_mlp
            else:
                to_q = converted_state_dict.pop(key)
                to_k = converted_state_dict.pop(key.replace("to_q", "to_k"))
                to_v = converted_state_dict.pop(key.replace("to_q", "to_v"))
                rename_attn_key = "self_attn_qkv" if "token_refiner" in key else "img_attn_qkv"
                if "lora_A" in key:
                    converted_state_dict[key.replace("attn.to_q", rename_attn_key)] = to_q
                else:
                    qkv = torch.cat([to_q, to_k, to_v], dim=0)
                    converted_state_dict[key.replace("attn.to_q", rename_attn_key)] = qkv
        if "add_q_proj" in key:
            to_q = converted_state_dict.pop(key)
            to_k = converted_state_dict.pop(key.replace("add_q_proj", "add_k_proj"))
            to_v = converted_state_dict.pop(key.replace("add_q_proj", "add_v_proj"))
            rename_attn_key = "txt_attn_qkv"
            if "lora_A" in key:
                converted_state_dict[key.replace("attn.add_q_proj", rename_attn_key)] = to_q
            else:
                qkv = torch.cat([to_q, to_k, to_v], dim=0)
                converted_state_dict[key.replace("attn.add_q_proj", rename_attn_key)] = qkv

    # Apply the reverse renames (diffusers names -> original names).
    for key in list(converted_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT_REVERSE.items():
            new_key = new_key.replace(replace_key, rename_key)
        converted_state_dict[new_key] = converted_state_dict.pop(key)

    # Remove the "transformer." prefix.
    for key in list(converted_state_dict.keys()):
        if key.startswith("transformer."):
            converted_state_dict[key[len("transformer."):]] = converted_state_dict.pop(key)

    # Add back the "diffusion_model." prefix.
    for key in list(converted_state_dict.keys()):
        converted_state_dict[f"diffusion_model.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--output_path_or_name", type=str, required=True)
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    if args.ckpt_path.endswith(".pt"):
        diffusers_state_dict = torch.load(args.ckpt_path, map_location="cpu", weights_only=True)
    elif args.ckpt_path.endswith(".safetensors"):
        diffusers_state_dict = load_file(args.ckpt_path)
    else:
        raise ValueError("Expected `--ckpt_path` to end with `.pt` or `.safetensors`.")

    original_format_state_dict = convert_diffusers_to_hunyuan_video_lora(diffusers_state_dict)

    output_path_or_name = Path(args.output_path_or_name)
    if output_path_or_name.as_posix().endswith(".safetensors"):
        os.makedirs(output_path_or_name.parent, exist_ok=True)
        save_file(original_format_state_dict, output_path_or_name)
    else:
        os.makedirs(output_path_or_name, exist_ok=True)
        output_path_or_name = output_path_or_name / "pytorch_lora_weights.safetensors"
        save_file(original_format_state_dict, output_path_or_name)
```
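A hedged sketch of using the helper above directly from Python instead of the CLI (file names are hypothetical, and it assumes `convert_diffusers_to_hunyuan_video_lora` from the script is importable in scope):

```python
from safetensors.torch import load_file, save_file

# Hypothetical input/output paths for illustration.
state_dict = load_file("pytorch_lora_weights.safetensors")
converted = convert_diffusers_to_hunyuan_video_lora(state_dict)
save_file(converted, "hunyuan_lora_original_format.safetensors")
```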
Force-pushed from 35d62aa to 8be9180.
It seems like we don't have anything yet to test all kinds of conversion paths. Not really a problem IMO. The loras available so far seem to only train the …
Fix for the styling error: #10478
This is perfectly fine and should be the way to go. We cannot test all kinds of community LoRAs from the get-go, and that is something we have been following for a while. Totally okay with me!
```
@@ -4239,10 +4243,9 @@ def save_lora_weights(
                safe_serialization=safe_serialization,
            )

    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
```
Could this leverage the CogVideoX fuse_lora for the "Copied from" statement? If so, I'd prefer that.
Not really, because we have a HunyuanVideo-specific example here.
Well, we follow "Copied from ..." with the same example to use it to our advantage (for maintenance) in the other classes, too. So let's perhaps maintain that consistency.
@stevhliu WDYT about that?
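(For context: the `# Copied from ...` annotation declares a method's canonical source so that `make fix-copies` can keep the copies in sync. A hedged illustration of the pattern only; the source method named here is for illustration, not what the PR ships:)

```python
class HunyuanVideoLoraLoaderMixin:
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(self, components=["transformer"], **kwargs):
        # `make fix-copies` regenerates/validates this body against the
        # declared source, so the two implementations cannot silently drift.
        ...
```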
```python
# @unittest.skip("We cannot run inference on this model with the current CI hardware")
# TODO (DN6, sayakpaul): move these tests to a beefier GPU
```
Can happily go after we add the following two markers:

```python
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
```
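As a sketch, those markers would sit on the integration test class roughly like this (the class name is hypothetical, and the import path is an assumption based on how diffusers exposes its test utilities):

```python
import unittest

import pytest
from diffusers.utils.testing_utils import require_big_gpu_with_torch_cuda


@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
    # Hypothetical test class; the decorators gate these tests to CI runners
    # that have a large-memory CUDA GPU.
    ...
```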
Will put it on my todos, because it seems Flux is also using the same marker (which is where I copied from). For the next few days I have a few other things I'd like to PoC or work on, so I'll take up the test refactor for this soon; otherwise this PR might get stalled further.
Not really. We have it in the Flux Control tests already: `diffusers/tests/lora/test_lora_layers_flux.py`, line 942 (at b94cfd7), uses `@require_big_gpu_with_torch_cuda`.
The Flux LoRA ones will go in after #9845 is merged. Since we already have a test suite for LoRA that uses the big-model marker, I think it's fine to utilize that here.
Would it be a blocker to do it in a separate PR? If not, I will revert the copied-from changes and proceed to merge, as this seems like something folks want without more delay, and I don't really have the bandwidth at the moment.
Fine by me.
Some minor comments, but nothing too critical.
I can push the changes I am requesting if you're busy with other things. LMK and I will push them if that's okay.
Ok thanks, fine by me if you want to take it up 😄 The slices were obtained on DGX as documented (similar to how the Flux lora slices are from audace, because we OOM on CPU for Hunyuan), so I'm not sure if they will be compatible with our CI.
Let me complete the TODOs and merge. Thanks, Aryan!
Will merge after updating the test slices from the big GPU of our CI. Running it currently.
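(For readers unfamiliar with the term: "slices" are small reference tensors that the tests compare against fresh pipeline output. A hedged sketch of the pattern with placeholder values, not the PR's actual slices:)

```python
import torch

# Placeholder tensors; in the real tests, `output` comes from a pipeline run
# and `expected_slice` is recorded from a reference run on the big CI GPU.
output = torch.linspace(0, 1, 16)
expected_slice = torch.tensor([0.0, 0.0667, 0.1333, 0.2])
assert torch.allclose(output.flatten()[:4], expected_slice, atol=1e-3)
```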
```diff
 def fuse_lora(
     self,
-    components: List[str] = ["transformer"],
+    components: List[str] = ["transformer", "text_encoder"],
```
`text_encoder` lora is not supported, which is why I'd removed it. So just an FYI that I haven't tested the inference code with this.
Yes, I have made a TODO for myself to deal with it in a follow-up PR :) Thanks!
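For reference, a hedged usage sketch of the entry point under discussion, restricted to the transformer since text_encoder LoRA isn't supported for HunyuanVideo yet (`pipe` is assumed to be a HunyuanVideoPipeline with a LoRA already loaded):

```python
# Fuse only the transformer LoRA weights into the base weights...
pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
# ...run inference without LoRA runtime overhead, then optionally undo it.
pipe.unfuse_lora(components=["transformer"])
```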
@a-r-r-o-w the slices worked on the big CI GPU, too. All good! Thanks for working on this!
Minor one - using …
Not sure what you mean.
It's supported. This test confirms that: …
Let us know if there's any problem.
* update
* fix make copies
* update
* add relevant markers to the integration test suite.
* add copied.
* fox-copies
* temporarily add print.
* directly place on CUDA as CPU isn't that big on the CIO.
* fixes to fuse_lora, aryan was right.
* fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Fixes #10106 (comment)
Some LoRAs for testing:
cc @svjack Would you be able to give this a try?