[LoRA Attn Processors] Refactor LoRA Attn Processors #4765
Conversation
The documentation is not available anymore as the PR was closed or merged.
Necessary for #4473
Did a first pass and left comments. At this point, the changes seem a little too complicated to me, but I will probably have a better idea once the PR is in a more comprehensive state.
A bit more explanation for what's going on here:
Overall, this PR should make the whole LoRA loading and functioning easier to understand and build upon (such as needed for #4473).
@@ -751,7 +753,7 @@ def test_a1111(self):
        images = images[0, -3:, -3:, -1].flatten()
        expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])

-       self.assertTrue(np.allclose(images, expected, atol=1e-4))
+       self.assertTrue(np.allclose(images, expected, atol=1e-3))
1e-3 is a bit too tight for integration tests. We should be able to run most tests on many different GPUs without having them fail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You meant 1e-4.
@@ -921,9 +923,9 @@ def test_sdxl_0_9_lora_three(self):
        ).images

        images = images[0, -3:, -3:, -1].flatten()
-       expected = np.array([0.4115, 0.4047, 0.4124, 0.3931, 0.3746, 0.3802, 0.3735, 0.3748, 0.3609])
+       expected = np.array([0.4015, 0.3761, 0.3616, 0.3745, 0.3462, 0.3337, 0.3564, 0.3649, 0.3468])
Updating the expected values here because I'm quite sure that this PR corrects a bug.
The LoRA:
lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora"
lora_filename = "kame_sdxl_v2-000020-16rank.safetensors"
is quite unique since it has different network_alpha values for the q, k, v and out LoRA layers of the attention block. Previously we always set the same network_alpha for all of q, k, v and out (see here), even though q, k, v and out may very well have different network_alpha values. As of this PR, per-projection alphas are supported, so this should be a bug fix for checkpoints like this one that ship different network_alpha values for each of q, k, v and out.
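To make the per-projection alpha concrete, here is a minimal, hedged sketch of a LoRA linear layer that carries its own network_alpha. The class name mirrors diffusers' LoRALinearLayer, but the body is a simplification and the ranks/alphas in the usage lines are made-up example values.

```python
import torch
import torch.nn as nn


class LoRALinearLayer(nn.Module):
    # Simplified sketch: the low-rank update is scaled by network_alpha / rank,
    # so each projection (q, k, v, out) can carry its own alpha value.
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.network_alpha = network_alpha
        self.rank = rank
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        update = self.up(self.down(hidden_states))
        if self.network_alpha is not None:
            update = update * (self.network_alpha / self.rank)
        return update


# Made-up example values: with per-projection alphas, q and out no longer have
# to share the same scaling, which is what this checkpoint requires.
q_lora = LoRALinearLayer(640, 640, rank=16, network_alpha=8.0)
out_lora = LoRALinearLayer(640, 640, rank=16, network_alpha=16.0)
sample = torch.randn(1, 77, 640)
print(q_lora(sample).shape, out_lora(sample).shape)
```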
diffusers/src/diffusers/loaders.py (line 1407 in 3bba44d):
for name, attn_module in text_encoder_attn_modules(text_encoder):
Ah, got it. We do that for the text encoder only, it seems, but not for the UNet. Is my understanding correct?
Yeah exactly! I'm pretty sure it's corrected now :-)
nice nice
As soon as you're good with the PR @sayakpaul, I think we can merge this one and unblock the fuse_lora PR :-)
# 2. else it is not possible that only some layers have LoRA activated
if not all(is_lora_activated.values()):
    raise ValueError(
        f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
Maybe we should just display list(is_lora_activated.keys()) to the end user.
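A quick sketch of what that suggestion could look like (illustrative only, not the merged code):

```python
# Illustrative only: surface just the layer names in the error message, as
# suggested above, instead of dumping the whole {name: bool} dict.
is_lora_activated = {"to_q": True, "to_k": True, "to_v": False, "to_out.0": True}  # example state

if not all(is_lora_activated.values()):
    raise ValueError(
        "Make sure that either all layers or no layers have LoRA activated, "
        f"but got: {list(is_lora_activated.keys())}"
    )
```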
    hasattr(self, "processor")
    and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
    and self.to_q.lora_layer is not None
):
Neat.
I am a little concerned about:
and self.to_q.lora_layer is not None
i.e., we only check whether to_q.lora_layer is not None. Is there a better alternative?
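One hedged idea for an alternative (purely illustrative, not what the PR implements): check all four projections rather than just to_q before deciding that LoRA layers are attached.

```python
# Illustrative alternative: treat LoRA as attached if any of the four
# projections carries a lora_layer, instead of keying the decision off to_q alone.
def has_lora_layers(attn_module):
    projections = (attn_module.to_q, attn_module.to_k, attn_module.to_v, attn_module.to_out[0])
    return any(getattr(proj, "lora_layer", None) is not None for proj in projections)
```

Given that the PR already errors out when only some layers have LoRA activated (the is_lora_activated check above), checking to_q alone is arguably safe; a helper like this would only make the intent more explicit.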
f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" | ||
) | ||
|
||
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor |
This bit is nasty but needed for ensuring backward compatibility.
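For readers following along, here is a rough, hedged sketch of what "merging the current LoRA layers into the corresponding LoRA attention processor" can look like; the helper name and constructor arguments are illustrative rather than the exact diffusers code.

```python
# Rough sketch of the backward-compatibility shim: rebuild a deprecated LoRA
# attention processor from the LoRA layers that now live on the projections,
# so existing serialization paths keep working. Names and arguments are
# illustrative, not the exact diffusers code.
def build_legacy_lora_processor(attn, lora_processor_cls):
    lora_processor = lora_processor_cls(
        hidden_size=attn.to_q.out_features,
        cross_attention_dim=attn.to_k.in_features,
        rank=attn.to_q.lora_layer.rank,
    )
    # Copy the live LoRA weights into the legacy processor's sub-layers.
    lora_processor.to_q_lora.load_state_dict(attn.to_q.lora_layer.state_dict())
    lora_processor.to_k_lora.load_state_dict(attn.to_k.lora_layer.state_dict())
    lora_processor.to_v_lora.load_state_dict(attn.to_v.lora_layer.state_dict())
    lora_processor.to_out_lora.load_state_dict(attn.to_out[0].lora_layer.state_dict())
    return lora_processor
```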
@@ -1678,3 +1354,287 @@ def forward(self, f, zq):
        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f
+
+
+## Deprecated
Do you want to directly throw a deprecation message when the class is initialized or would that be too brutal?
I thought about this too. I think it would be a bit too much at the moment, though, because it would mean we'd get a deprecation warning every time we call unet.attn_processors, which happens every time we save LoRAs. I'd suggest only starting to throw aggressive deprecation warnings when we do the PEFT integration.
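For reference, here is a minimal sketch (an assumption about what that option would look like, not code from this PR) of raising the deprecation warning at construction time:

```python
import warnings


class LoRAAttnProcessor:
    # Illustrative only: warn as soon as the class is instantiated. Since the
    # backward-compatibility path rebuilds these processors whenever
    # unet.attn_processors is accessed, this would also fire on every LoRA save,
    # which is why the PR holds off until the PEFT integration.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "LoRAAttnProcessor is deprecated; LoRA layers are now attached to the "
            "attention module's linear projections and the regular attention "
            "processors are used instead.",
            FutureWarning,
        )
```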
        return attn.processor(attn, hidden_states, *args, **kwargs)


class LoRAAttnProcessor2_0(nn.Module):
Likewise.
        attn.processor = AttnProcessor2_0()
        return attn.processor(attn, hidden_states, *args, **kwargs)
Likewise.
Hard work. Thanks a lot.
Left some minor design-related comments but I don't think they block this PR from getting merged. So, I am gonna merge this to be able to focus on #4473.
* [LoRA Attn] Refactor LoRA attn
* correct for network alphas
* fix more
* fix more tests
* fix more tests
* Move below
* Finish
* better version
* correct serialization format
* fix
* fix more
* fix more
* fix more
* Apply suggestions from code review
* Update src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
* deprecation
* relax atol for slow test slighly
* Finish tests
* make style
* make style
What does this PR do?
This PR deprecates all the "LoRA..." attention processors as explained here: #4473 (comment)
This is due to the following reasons:
* Now that we have LoRACompatibleLinear, there is no reason to not also leverage it in the attention classes. This way we have a single point of logic, removing overhead (see the sketch below).
* It is needed for fuse_lora, as seen here: Fuse loras #4473

TODO:
* Adapt the attn_processor method to (for now) return the deprecated attn processor if LoRA layers are activated.

The PR works now for inference.
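To illustrate the single-point-of-logic idea, here is a simplified, hedged sketch of the design direction. The class name matches diffusers' LoRACompatibleLinear, but the body below is a reduction for illustration, not the actual implementation, and the usage lines use made-up sizes.

```python
import torch
import torch.nn as nn


class LoRACompatibleLinear(nn.Linear):
    # Simplified sketch: the LoRA update lives on the linear projection itself,
    # so attention processors no longer need any LoRA-specific code paths.
    def __init__(self, in_features, out_features, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.lora_layer = None

    def set_lora_layer(self, lora_layer):
        self.lora_layer = lora_layer

    def forward(self, hidden_states):
        out = super().forward(hidden_states)
        if self.lora_layer is not None:
            # Add the low-rank update produced by the attached LoRA layer.
            out = out + self.lora_layer(hidden_states)
        return out


# Usage sketch: attach a (stand-in) LoRA layer to a projection; a regular
# attention processor then just calls the projection as usual.
to_q = LoRACompatibleLinear(320, 320)
to_q.set_lora_layer(nn.Linear(320, 320, bias=False))  # stand-in for a real LoRA layer
print(to_q(torch.randn(2, 77, 320)).shape)
```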
@sayakpaul @williamberman I'd be very happy if you could give this a first review and if ok for you I can finish the final TODOs tomorrow.