From d79f5913b35a872a94f8ffbd147b78dd62e7275c Mon Sep 17 00:00:00 2001 From: Soila Kavulya Date: Thu, 19 Dec 2024 02:11:07 -0800 Subject: [PATCH] Update save lora weights for diffusers with text_encoder_2 layers (#1626) --- .../diffusers/pipelines/pipeline_utils.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/pipeline_utils.py b/optimum/habana/diffusers/pipelines/pipeline_utils.py index 6dda26f796..02e635eaf6 100644 --- a/optimum/habana/diffusers/pipelines/pipeline_utils.py +++ b/optimum/habana/diffusers/pipelines/pipeline_utils.py @@ -394,13 +394,26 @@ def save_lora_weights( text_encoder_lora_layers = to_device_dtype(text_encoder_lora_layers, target_device=torch.device("cpu")) if text_encoder_2_lora_layers: text_encoder_2_lora_layers = to_device_dtype(text_encoder_2_lora_layers, target_device=torch.device("cpu")) - return super().save_lora_weights( - save_directory, - unet_lora_layers, - text_encoder_lora_layers, - text_encoder_2_lora_layers, - is_main_process, - weight_name, - save_function, - safe_serialization, - ) + + # text_encoder_2_lora_layers is only supported by some diffuser pipelines + if text_encoder_2_lora_layers: + return super().save_lora_weights( + save_directory, + unet_lora_layers, + text_encoder_lora_layers, + text_encoder_2_lora_layers, + is_main_process, + weight_name, + save_function, + safe_serialization, + ) + else: + return super().save_lora_weights( + save_directory, + unet_lora_layers, + text_encoder_lora_layers, + is_main_process, + weight_name, + save_function, + safe_serialization, + )