-
Notifications
You must be signed in to change notification settings - Fork 6.1k
[LoRA] Support original format loras for HunyuanVideo #10376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
66fc85e
893b9c0
904e3a4
8be9180
a040c5d
f682d76
63d5e9f
4ac0c12
5fbc59c
95a7e0f
738f50d
23854f2
2cc3683
3b64bd5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ | |
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa | ||
from .lora_conversion_utils import ( | ||
_convert_bfl_flux_control_lora_to_diffusers, | ||
_convert_hunyuan_video_lora_to_diffusers, | ||
_convert_kohya_flux_lora_to_diffusers, | ||
_convert_non_diffusers_lora_to_diffusers, | ||
_convert_xlabs_flux_lora_to_diffusers, | ||
|
@@ -4007,7 +4008,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): | |
|
||
@classmethod | ||
@validate_hf_hub_args | ||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict | ||
def lora_state_dict( | ||
cls, | ||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], | ||
|
@@ -4018,7 +4018,7 @@ def lora_state_dict( | |
|
||
<Tip warning={true}> | ||
|
||
We support loading A1111 formatted LoRA checkpoints in a limited capacity. | ||
We support loading original format HunyuanVideo LoRA checkpoints. | ||
|
||
This function is experimental and might change in the future. | ||
|
||
|
@@ -4101,6 +4101,10 @@ def lora_state_dict( | |
logger.warning(warn_msg) | ||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} | ||
|
||
is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict) | ||
if is_original_hunyuan_video: | ||
state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) | ||
|
||
return state_dict | ||
|
||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights | ||
|
@@ -4239,10 +4243,9 @@ def save_lora_weights( | |
safe_serialization=safe_serialization, | ||
) | ||
|
||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could leverage the CogVideoX fuse_lora for the "Copy" statement, no? If so, I'd prefer that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not really because we have a hunyuan specific example here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, we follow "Copied from ..." with the same example to play it to our advantage (of maintenance) for the other classes, too. So, let's perhaps maintain that consistency. @stevhliu WDYT about that? |
||
def fuse_lora( | ||
self, | ||
components: List[str] = ["transformer", "text_encoder"], | ||
components: List[str] = ["transformer"], | ||
lora_scale: float = 1.0, | ||
safe_fusing: bool = False, | ||
adapter_names: Optional[List[str]] = None, | ||
|
@@ -4283,8 +4286,7 @@ def fuse_lora( | |
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names | ||
) | ||
|
||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer | ||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): | ||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): | ||
r""" | ||
Reverses the effect of | ||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). | ||
|
Uh oh!
There was an error while loading. Please reload this page.