Commit ca5130d

update lora
JunnYu committed Feb 20, 2023
1 parent 50cd4cf commit ca5130d
Showing 3 changed files with 28 additions and 5 deletions.
8 changes: 7 additions & 1 deletion ppdiffusers/examples/dreambooth/train_dreambooth_lora.py
@@ -281,6 +281,12 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--lora_rank",
+        type=int,
+        default=128,
+        help="The rank of lora linear.",
+    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -650,7 +656,7 @@ def main():
             hidden_size = unet.config.block_out_channels[block_id]
 
         lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.lora_rank
         )
 
     unet.set_attn_processor(lora_attn_procs)
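
The one-line change in main() is where the new flag takes effect: every attention processor installed on the UNet now receives rank=args.lora_rank instead of the hard-coded default. The sketch below reconstructs the surrounding loop from the upstream diffusers DreamBooth LoRA script that this example ports; the mid_block/up_blocks branches that are not visible in the hunk are assumptions about that surrounding code, not part of this diff.

    # Sketch: one LoRACrossAttnProcessor per attention layer, all sharing the CLI-chosen rank.
    from ppdiffusers.models.cross_attention import LoRACrossAttnProcessor

    def build_lora_attn_procs(unet, lora_rank):
        lora_attn_procs = {}
        for name in unet.attn_processors.keys():
            # attn1 is self-attention (no encoder states); attn2 attends to the text encoder.
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            else:  # down_blocks
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            lora_attn_procs[name] = LoRACrossAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
            )
        return lora_attn_procs

    # Usage inside the training script (unet already loaded, args already parsed):
    # unet.set_attn_processor(build_lora_attn_procs(unet, args.lora_rank))
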
12 changes: 8 additions & 4 deletions ppdiffusers/ppdiffusers/models/cross_attention.py
@@ -214,10 +214,14 @@ class LoRACrossAttnProcessor(nn.Layer):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
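
For context, rank is the inner width of each LoRALinearLayer: the bottleneck of the low-rank update added on top of the frozen attention projections, so the trainable parameter count per layer is rank * (in_features + out_features) rather than in_features * out_features. The Paddle snippet below is a minimal sketch of that standard LoRA layout, assuming the usual down/up projection pair with the update starting at zero; the class name LoRALinearLayerSketch and the initialization choice are illustrative assumptions, not the exact ppdiffusers implementation.

    import paddle
    import paddle.nn as nn

    class LoRALinearLayerSketch(nn.Layer):
        """Minimal LoRA linear: models a weight update as up(down(x))."""

        def __init__(self, in_features, out_features, rank=4):
            super().__init__()
            # The rank sets the bottleneck width, so parameters grow linearly with it.
            self.down = nn.Linear(in_features, rank, bias_attr=False)
            self.up = nn.Linear(rank, out_features, bias_attr=False)
            # Start the update at zero so the adapted model initially matches the frozen base.
            self.up.weight.set_value(paddle.zeros_like(self.up.weight))

        def forward(self, hidden_states):
            return self.up(self.down(hidden_states))

Because the adapter size is linear in rank, raising --lora_rank from the layer default of 4 to the script default of 128 makes each adapter roughly 32x larger, trading checkpoint size for adapter capacity.
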
13 changes: 13 additions & 0 deletions (third changed file; filename not shown in this view)
@@ -620,6 +620,19 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
                 if len(attentions) == 2:
                     attentions = []
 
+            if ["conv.bias", "conv.weight"] in output_block_list.values():
+                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+                    f"output_blocks.{i}.{index}.conv.weight"
+                ]
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+                    f"output_blocks.{i}.{index}.conv.bias"
+                ]
+
+                # Clear attentions as they have been attributed above.
+                if len(attentions) == 2:
+                    attentions = []
+
             if len(attentions):
                 paths = renew_attention_paths(attentions)
                 meta_path = {
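
The conversion fix adds a branch that recognizes an output block whose only parameters are an upsampler convolution, matching the alphabetically ordered suffix pair ["conv.bias", "conv.weight"], and copies those tensors to up_blocks.{block_id}.upsamplers.0. The toy example below uses made-up, abbreviated key lists purely to show how the membership test and index lookup behave; it is not data from a real checkpoint.

    # Hypothetical output_block_list for one output_blocks.{i} group: it maps a
    # sub-layer index to the sorted suffixes of that sub-layer's parameter names.
    output_block_list = {
        "0": ["in_layers.0.weight", "out_layers.3.weight"],  # resnet (abbreviated)
        "1": ["norm.weight", "proj_in.weight"],              # attention (abbreviated)
        "2": ["conv.bias", "conv.weight"],                   # upsampler
    }

    if ["conv.bias", "conv.weight"] in output_block_list.values():
        index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
        print(index)  # 2 -> the tensors live under output_blocks.{i}.2.conv.* in the LDM state dict
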
