Is your feature request related to a problem? Please describe.
LoRA consumes extra GPU compute that is not needed at inference time. This can add a noticeable amount of overhead, especially at high resolutions. For higher LoRA ranks (64/128), eliminating it can yield up to a 30%-50% performance gain.
Describe the solution you'd like
A `bake_lora_weights` function on UNet models (or any model that supports attention processors). Users could decide whether to keep the ability to swap LoRA weights at runtime or to bake the weights into the base layers to reduce inference overhead.
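For intuition, here is a minimal, self-contained sketch of the baking identity on a single linear layer (layer sizes and variable names are illustrative, not part of the diffusers API): because the LoRA branch computes `x @ (W_up @ W_down).T`, the product `W_up @ W_down` can be added into the base weight once, after which the extra matmuls at inference disappear.

```python
import torch

# Hypothetical stand-ins for a base projection and its LoRA pair.
base = torch.nn.Linear(320, 320, bias=False)
rank = 4
down = torch.nn.Linear(320, rank, bias=False)  # LoRA "down" projection
up = torch.nn.Linear(rank, 320, bias=False)    # LoRA "up" projection

x = torch.randn(2, 320)
with torch.no_grad():
    y_lora = base(x) + up(down(x))          # LoRA applied at runtime
    base.weight += up.weight @ down.weight  # bake: W <- W + W_up @ W_down
    y_baked = base(x)                       # no extra matmuls needed anymore

assert torch.allclose(y_lora, y_baked, atol=1e-5)
```

The `expansion` helper in the proof of concept below materializes the same `W_up @ W_down` product by pushing an identity matrix through the LoRA layer.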
Proof of Concept:
```python
import torch

from diffusers.models.attention_processor import (
    Attention, AttnProcessor, AttnProcessor2_0, LoRAAttnProcessor,
    LoRAAttnProcessor2_0, LoRALinearLayer, LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
from diffusers.utils import is_xformers_available


@torch.no_grad()
def bake_lora_weights(self):
    def expansion(lora: LoRALinearLayer):
        # Feed an identity matrix (in x in) through the LoRA layer to materialize
        # its delta (in x out), then transpose to (out x in) to match nn.Linear.weight.
        eye = torch.eye(lora.down.weight.size(1))
        eye = eye.to(lora.down.weight)
        return lora(eye).T

    def traverse(name: str, module: Attention):
        if hasattr(module, "set_processor"):
            assert isinstance(
                module.processor,
                (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor),
            )
            # Fold each LoRA delta into the corresponding base projection weight.
            module.to_k.weight += expansion(module.processor.to_k_lora)
            module.to_v.weight += expansion(module.processor.to_v_lora)
            module.to_q.weight += expansion(module.processor.to_q_lora)
            module.to_out[0].weight += expansion(module.processor.to_out_lora)
        for sub_name, child in module.named_children():
            traverse(f"{name}.{sub_name}", child)

    for name, module in self.named_children():
        traverse(name, module)

    # The deltas are now baked in, so switch back to plain (non-LoRA) processors.
    if torch.__version__ >= "2.0":
        self.set_attn_processor(AttnProcessor2_0())
    elif is_xformers_available():
        self.set_attn_processor(XFormersAttnProcessor())
    else:
        self.set_attn_processor(AttnProcessor())
```
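A possible usage sketch from the pipeline side. Note that `bake_lora_weights` is the proposed method and does not exist in diffusers today; the checkpoint id and LoRA path are placeholders.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load LoRA attention processors as usual.
pipe.unet.load_attn_procs("path/to/lora_weights")  # placeholder path

# Proposed: fold the LoRA deltas into the base weights and drop the LoRA
# processors, removing the per-step LoRA matmuls at inference time.
pipe.unet.bake_lora_weights()

image = pipe("an astronaut riding a horse").images[0]
```

After this call the LoRA can no longer be swapped out at runtime, which is why the baking step should stay opt-in.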
Describe alternatives you've considered
N/A
Additional context
N/A