
[LoRA] Support original format loras for HunyuanVideo #10376

Merged

14 commits merged into main from original-lora-hunyuan-video on Jan 7, 2025

Conversation

a-r-r-o-w
Member

Fixes #10106 (comment)

Some LoRAs for testing:

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.load_lora_weights("svjack/Genshin_Impact_XiangLing_Low_Res_HunyuanVideo_lora_early", weight_name="xiangling_ep1_lora.safetensors", adapter_name="hunyuan-lora")
pipe.set_adapters("hunyuan-lora", 0.8)
pipe.vae.enable_tiling()
pipe.to("cuda")

output = pipe(
    prompt=".....",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
    generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_video(output, "output.mp4", fps=15)

cc @svjack Would you be able to give this a try?

@a-r-r-o-w a-r-r-o-w requested review from DN6 and sayakpaul December 25, 2024 10:02
@a-r-r-o-w
Member Author

a-r-r-o-w commented Dec 25, 2024

@sayakpaul @DN6 I'm not comfortable adding integration tests with any of the existing original-format LoRAs on the Hub/CivitAI for HunyuanVideo. They are either anime/NSFW/person-related LoRAs, which we usually don't use for testing. I'll work on training a style LoRA over the next few days and convert it to the original format for integration tests. Does that sound good?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@sayakpaul sayakpaul left a comment


Perfect! Okay to merge for me after an integration test!

@nitinmukesh

Subscribing to this PR

@yardenfren1996

Hi, does Diffusers have a training script available for training HunyuanVideo LoRAs?

@sayakpaul
Member

We do, thanks to @a-r-r-o-w 🐐
Check out https://github.com/a-r-r-o-w/finetrainers

@vladmandic
Contributor

Is this PR stalled just because there are no LoRAs that are (very subjectively) deemed suitable for an integration test?

@a-r-r-o-w
Member Author

@vladmandic No, sorry, that's not it. I haven't found the time to get back to this yet. We have LoRAs suitable for integration tests now, so I will complete the PR tomorrow.

@a-r-r-o-w
Member Author

Okay, so there's a problem with the LoRAs we trained 🤣 We didn't train the proj_mlp layer in the single transformer blocks, which means we can't concatenate to_q, to_k, to_v, and proj_mlp into a single LoRA layer (which is what the original model would use). The workaround is to just append zeroed-out weights to the diffusers-format checkpoint, but I don't think we should spend much time on this at the moment, because it might require additional debugging in case it doesn't work as expected.
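
For illustration, a minimal sketch of that zero-padding workaround (not the route taken in this PR): it assumes the diffusers LoRA key layout used in the script below, a shared lora_A across the fused projections, and an mlp_ratio of 4.0; the function name and shapes are hypothetical.

import torch

def add_zero_proj_mlp(diffusers_lora_state_dict, mlp_ratio=4.0):
    # Append zeroed proj_mlp LoRA weights for every single transformer block so that
    # to_q/to_k/to_v/proj_mlp can later be concatenated into the fused original layer.
    # A zero up-projection keeps the padded LoRA numerically identical to the original.
    sd = diffusers_lora_state_dict
    for key in list(sd.keys()):
        if "single_transformer_blocks" not in key or "attn.to_q.lora_A" not in key:
            continue
        prefix = key.split("attn.to_q")[0]
        lora_a = sd[key]                              # (rank, hidden_dim)
        lora_b = sd[key.replace("lora_A", "lora_B")]  # (hidden_dim, rank)
        rank, hidden_dim = lora_a.shape
        mlp_dim = int(hidden_dim * mlp_ratio)
        sd[f"{prefix}proj_mlp.lora_A.weight"] = lora_a.clone()
        sd[f"{prefix}proj_mlp.lora_B.weight"] = torch.zeros(mlp_dim, rank, dtype=lora_b.dtype)
    return sd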

In the interest of time, I'm just going to use this checkpoint for the integration tests: https://huggingface.co/Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1/

partial conversion script
import argparse
import os
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file


def convert_diffusers_to_hunyuan_video_lora(diffusers_state_dict):
    converted_state_dict = {k: diffusers_state_dict.pop(k) for k in list(diffusers_state_dict.keys())}

    TRANSFORMER_KEYS_RENAME_DICT = {
        "img_in": "x_embedder",
        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
        ".double_blocks": ".transformer_blocks",
        ".single_blocks": ".single_transformer_blocks",
        "img_attn_q_norm": "attn.norm_q",
        "img_attn_k_norm": "attn.norm_k",
        "img_attn_proj": "attn.to_out.0",
        "txt_attn_q_norm": "attn.norm_added_q",
        "txt_attn_k_norm": "attn.norm_added_k",
        "txt_attn_proj": "attn.to_add_out",
        "img_mod.linear": "norm1.linear",
        "img_norm1": "norm1.norm",
        "img_norm2": "norm2",
        "img_mlp": "ff",
        "txt_mod.linear": "norm1_context.linear",
        "txt_norm1": "norm1.norm",
        "txt_norm2": "norm2_context",
        "txt_mlp": "ff_context",
        "self_attn_proj": "attn.to_out.0",
        "modulation.linear": "norm.linear",
        "pre_norm": "norm.norm",
        "final_layer.norm_final": "norm_out.norm",
        "final_layer.linear": "proj_out",
        "fc1": "net.0.proj",
        "fc2": "net.2",
        "input_embedder": "proj_in",
        # txt_in
        "individual_token_refiner.blocks": "token_refiner.refiner_blocks",
        "adaLN_modulation.1": "norm_out.linear",
        "txt_in": "context_embedder",
        "t_embedder.mlp.0": "time_text_embed.timestep_embedder.linear_1",
        "t_embedder.mlp.2": "time_text_embed.timestep_embedder.linear_2",
        "c_embedder": "time_text_embed.text_embedder",
        "mlp": "ff",
    }
    
    TRANSFORMER_KEYS_RENAME_DICT_REVERSE = {v: k for k, v in TRANSFORMER_KEYS_RENAME_DICT.items()}

    # TRANSFORMER_SPECIAL_KEYS_REMAP_BACK = {
    #     "norm_out.linear": remap_norm_scale_shift_back_,
    #     "context_embedder": remap_txt_in_back_,
    #     "attn.to_q": remap_img_attn_qkv_back_,
    #     "attn.add_q_proj": remap_txt_attn_qkv_back_,
    #     "single_transformer_blocks": remap_single_transformer_blocks_back_,
    # }
    
    for key in list(converted_state_dict.keys()):
        if "norm_out.linear" in key:
            weight = converted_state_dict.pop(key)
            scale, shift = weight.chunk(2, dim=0)
            new_weight = torch.cat([shift, scale], dim=0)
            converted_state_dict[key] = new_weight
        
        if "to_q" in key:
            if "single_transformer_blocks" in key:
                to_q = converted_state_dict.pop(key)
                to_k = converted_state_dict.pop(key.replace("to_q", "to_k"))
                to_v = converted_state_dict.pop(key.replace("to_q", "to_v"))
                to_out = converted_state_dict.pop(key.replace("attn.to_q", "proj_mlp"))
                rename_attn_key = "img_attn_qkv"
                if "lora_A" in key:
                    converted_state_dict[key.replace("attn.to_q", rename_attn_key)] = to_q
                else:
                    qkv_mlp = torch.cat([to_q, to_k, to_v, to_out], dim=0)
                    converted_state_dict[key.replace("attn.to_q", rename_attn_key)] = qkv_mlp
            else:
                to_q = converted_state_dict.pop(key)
                to_k = converted_state_dict.pop(key.replace("to_q", "to_k"))
                to_v = converted_state_dict.pop(key.replace("to_q", "to_v"))
                if "token_refiner" in key:
                    rename_attn_key = "self_attn_qkv"
                    if "lora_A" in key:
                        converted_state_dict[key.replace("attn.to_q", rename_attn_key)] = to_q
                    else:
                        qkv = torch.cat([to_q, to_k, to_v], dim=0)
                        converted_state_dict[key.replace("attn.to_q", rename_attn_key)] = qkv
                else:
                    rename_attn_key = "img_attn_qkv"
                    if "lora_A" in key:
                        converted_state_dict[key.replace("attn.to_q", rename_attn_key)] = to_q
                    else:
                        qkv = torch.cat([to_q, to_k, to_v], dim=0)
                        converted_state_dict[key.replace("attn.to_q", rename_attn_key)] = qkv
        
        if "add_q_proj" in key:
            to_q = converted_state_dict.pop(key)
            to_k = converted_state_dict.pop(key.replace("add_q_proj", "add_k_proj"))
            to_v = converted_state_dict.pop(key.replace("add_q_proj", "add_v_proj"))
            rename_attn_key = "txt_attn_qkv"
            if "lora_A" in key:
                converted_state_dict[key.replace("attn.add_q_proj", rename_attn_key)] = to_q
            else:
                qkv = torch.cat([to_q, to_k, to_v], dim=0)
                converted_state_dict[key.replace("attn.add_q_proj", rename_attn_key)] = qkv
    
    for key in list(converted_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT_REVERSE.items():
            new_key = new_key.replace(replace_key, rename_key)
        converted_state_dict[new_key] = converted_state_dict.pop(key)

    # Remove "transformer." prefix
    for key in list(converted_state_dict.keys()):
        if key.startswith("transformer."):
            converted_state_dict[key[len("transformer."):]] = converted_state_dict.pop(key)
    
    # Add back "diffusion_model." prefix
    for key in list(converted_state_dict.keys()):
        converted_state_dict[f"diffusion_model.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--output_path_or_name", type=str, required=True)
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    
    if args.ckpt_path.endswith(".pt"):
        diffusers_state_dict = torch.load(args.ckpt_path, map_location="cpu", weights_only=True)
    elif args.ckpt_path.endswith(".safetensors"):
        diffusers_state_dict = load_file(args.ckpt_path)
    else:
        raise ValueError(f"Unsupported checkpoint format: {args.ckpt_path}")

    original_format_state_dict = convert_diffusers_to_hunyuan_video_lora(diffusers_state_dict)

    output_path_or_name = Path(args.output_path_or_name)
    if output_path_or_name.as_posix().endswith(".safetensors"):
        os.makedirs(output_path_or_name.parent, exist_ok=True)
        save_file(original_format_state_dict, output_path_or_name)
    else:
        os.makedirs(output_path_or_name, exist_ok=True)
        output_path_or_name = output_path_or_name / "pytorch_lora_weights.safetensors"
        save_file(original_format_state_dict, output_path_or_name)
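
Assuming the above were saved as convert_hunyuan_video_lora_to_original.py (the filename and paths are placeholders), usage would look like:

python convert_hunyuan_video_lora_to_original.py \
  --ckpt_path pytorch_lora_weights.safetensors \
  --output_path_or_name hunyuan_lora_original_format.safetensors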

@a-r-r-o-w a-r-r-o-w force-pushed the original-lora-hunyuan-video branch from 35d62aa to 8be9180 on January 6, 2025 at 23:51
@a-r-r-o-w a-r-r-o-w requested a review from sayakpaul January 6, 2025 23:51
@a-r-r-o-w
Copy link
Member Author

It seems like we don't have anything yet to test all kinds of conversion paths. Not really a problem IMO. The LoRAs available so far seem to only train the double_transformer_blocks attention layers and the single_transformer_blocks linear projections. We can add to the tests eventually as more LoRAs drop.

@a-r-r-o-w
Copy link
Member Author

a-r-r-o-w commented Jan 7, 2025

Fix for styling error: #10478

@sayakpaul
Member

It seems like we don't have anything yet to test all kinds of conversion paths. Not really a problem IMO. The LoRAs available so far seem to only train the double_transformer_blocks attention layers and the single_transformer_blocks linear projections. We can add to the tests eventually as more LoRAs drop.

This is perfectly fine and should be the way to go. We cannot test all kinds of community LoRAs from the get-go, and this is an approach we have been following for a while. Totally okay with me!

@@ -4239,10 +4243,9 @@ def save_lora_weights(
safe_serialization=safe_serialization,
)

# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
Member

We could leverage the CogVideoX fuse_lora for the "Copied from" statement, no? If so, I'd prefer that.

Member Author

Not really, because we have a Hunyuan-specific example here.

Member

Well, we follow "Copied from ..." with the same example in the other classes too, to our advantage for maintenance. So let's perhaps maintain that consistency.

@stevhliu WDYT about that?

Comment on lines 197 to 198
# @unittest.skip("We cannot run inference on this model with the current CI hardware")
# TODO (DN6, sayakpaul): move these tests to a beefier GPU
Member

Can happily go after we add the following two markers:

  • @require_big_gpu_with_torch_cuda
  • @pytest.mark.big_gpu_with_torch_cuda

Member Author

Will put this on my todo list, because it seems Flux is also using the same marker (that's where I copied from). For the next few days I have a few other things I'd like to PoC or work on, so I'll take up the test refactor for this soon; otherwise this PR might get stalled further.

Member

Not really. We already have it in the Flux Control tests:

@require_big_gpu_with_torch_cuda

Flux LoRA ones will be in after #9845 is merged. Since we already have a test suite for LoRA that uses the big model marker, I think it's fine to utilize that here.
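
For reference, a minimal sketch of how those markers would be applied to the HunyuanVideo LoRA integration test class (the class name is illustrative and the import path is assumed from how the Flux tests use the decorator):

import unittest

import pytest
from diffusers.utils.testing_utils import require_big_gpu_with_torch_cuda

@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
    # Slice-based integration tests that only run on the big-GPU CI runners.
    ...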

Member Author

@a-r-r-o-w a-r-r-o-w Jan 7, 2025

Would it be a blocker to do it in a separate PR? If not, I'll revert the "Copied from" changes and proceed to merge, as this seems like something folks want without more delay, and I don't really have the bandwidth at the moment.

Member

Fine by me.

Member

@sayakpaul sayakpaul left a comment

Some minor comments but nothing too critical.

@sayakpaul
Member

I can push the changes I'm requesting if you're busy with other things. Let me know and I'll push them if that's okay.

@a-r-r-o-w
Member Author

Okay, thanks; fine by me if you want to take it up 😄 The slices were obtained on a DGX as documented (similar to how the Flux LoRA slices are from audace, because we OOM on CPU for Hunyuan), so I'm not sure if they will be compatible with our CI.

@sayakpaul
Member

Let me complete the TODOs and merge. Thanks, Aryan!

@sayakpaul
Member

Will merge after updating the test slices from the big GPU of our CI. Running it currently.

def fuse_lora(
    self,
-   components: List[str] = ["transformer"],
+   components: List[str] = ["transformer", "text_encoder"],
Member Author

text_encoder LoRA is not supported, which is why I'd removed it. So just an FYI that I haven't tested the inference code with this.

Member

Yes, I have made a TODO for myself to deal with it in a follow-up PR :) Thanks!

@sayakpaul sayakpaul merged commit 811560b into main Jan 7, 2025
15 checks passed
@sayakpaul sayakpaul deleted the original-lora-hunyuan-video branch January 7, 2025 07:49
@sayakpaul
Member

@a-r-r-o-w the slices worked on the big CI GPU, too. All good! Thanks for working on this!

@vladmandic
Contributor

A minor one: using fuse_lora() followed by unload_lora_weights() as usual will make the HunyuanVideo LoRA silently go away.
If fuse_lora is not supported for HunyuanVideo, that's fine; just note it?

@sayakpaul
Member

using fuse_lora() followed by unload_lora_weights() as usual will make the HunyuanVideo LoRA silently go away.

Not sure what you mean. fuse_lora() will fuse the LoRA params into the corresponding base model weights. unload_lora_weights() will completely unload the LoRA params, if any are loaded, to preserve memory.

If fuse_lora is not supported for HunyuanVideo, that's fine; just note it?

It's supported. This test confirms that:

self.pipeline.fuse_lora()

Let us know if there's any problem.
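
In code, the intended behavior described above looks roughly like this (a sketch; it assumes pipe is the HunyuanVideoPipeline with the LoRA loaded as in the example at the top of the thread):

# Fuse the LoRA params into the base transformer weights.
pipe.fuse_lora()
# Remove the now-redundant LoRA layers to free memory; the fused effect should remain.
pipe.unload_lora_weights()

# To return to the base model instead, unfuse before unloading:
# pipe.unfuse_lora()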

DN6 pushed a commit that referenced this pull request Jan 15, 2025
* update

* fix make copies

* update

* add relevant markers to the integration test suite.

* add copied.

* fix-copies

* temporarily add print.

* directly place on CUDA as CPU isn't that big on the CI.

* fixes to fuse_lora, aryan was right.

* fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Successfully merging this pull request may close these issues: Add HunyuanVideo