Skip to content

Commit ca06442

Browse files
[SDXL Lora] Fix last ben sdxl lora (huggingface#4797)
* Fix last ben sdxl lora
* Correct typo
* make style
1 parent 9d71f4e commit ca06442

File tree

1 file changed

+27
-14
lines changed

1 file changed

+27
-14
lines changed

loaders.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,7 +1084,7 @@ def lora_state_dict(
10841084
# Map SDXL blocks correctly.
10851085
if unet_config is not None:
10861086
# use unet config to remap block numbers
1087-
state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config)
1087+
state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
10881088
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
10891089

10901090
return state_dict, network_alphas
@@ -1121,24 +1121,41 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext
11211121
return weight_name
11221122

11231123
@classmethod
1124-
def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
1125-
is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
1124+
def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
1125+
# 1. get all state_dict_keys
1126+
all_keys = state_dict.keys()
1127+
sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
1128+
1129+
# 2. check if needs remapping, if not return original dict
1130+
is_in_sgm_format = False
1131+
for key in all_keys:
1132+
if any(p in key for p in sgm_patterns):
1133+
is_in_sgm_format = True
1134+
break
1135+
1136+
if not is_in_sgm_format:
1137+
return state_dict
1138+
1139+
# 3. Else remap from SGM patterns
11261140
new_state_dict = {}
11271141
inner_block_map = ["resnets", "attentions", "upsamplers"]
11281142

11291143
# Retrieves # of down, mid and up blocks
11301144
input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
1131-
for layer in state_dict:
1132-
if "text" not in layer:
1145+
1146+
for layer in all_keys:
1147+
if "text" in layer:
1148+
new_state_dict[layer] = state_dict.pop(layer)
1149+
else:
11331150
layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
1134-
if "input_blocks" in layer:
1151+
if sgm_patterns[0] in layer:
11351152
input_block_ids.add(layer_id)
1136-
elif "middle_block" in layer:
1153+
elif sgm_patterns[1] in layer:
11371154
middle_block_ids.add(layer_id)
1138-
elif "output_blocks" in layer:
1155+
elif sgm_patterns[2] in layer:
11391156
output_block_ids.add(layer_id)
11401157
else:
1141-
raise ValueError("Checkpoint not supported")
1158+
raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
11421159

11431160
input_blocks = {
11441161
layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
@@ -1201,12 +1218,8 @@ def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", bl
12011218
)
12021219
new_state_dict[new_key] = state_dict.pop(key)
12031220

1204-
if is_all_unet and len(state_dict) > 0:
1221+
if len(state_dict) > 0:
12051222
raise ValueError("At this point all state dict entries have to be converted.")
1206-
else:
1207-
# Remaining is the text encoder state dict.
1208-
for k, v in state_dict.items():
1209-
new_state_dict.update({k: v})
12101223

12111224
return new_state_dict
12121225

0 commit comments

Comments
 (0)