
[LoRA Attn Processors] Refactor LoRA Attn Processors #4765


Merged: 26 commits merged into main from refactor_lora_attn on Aug 28, 2023
Conversation

patrickvonplaten (Contributor) commented Aug 24, 2023

What does this PR do?

This PR deprecates all the "LoRA..." attention processors as explained here: #4473 (comment)

This is due to the following reasons:

  • Now that we have the concept of LoRACompatibleLinear, there is no reason not to leverage it in the attention classes as well. This way we have a single point of logic and remove overhead (see the sketch after this list).
  • It significantly helps when applying global LoRA functionality such as fuse_lora as seen here:
    Fuse loras #4473
  • We reduce the number of attention processor classes significantly, making the file easier to read.
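
For context, a minimal sketch of the idea (the class and helper names here are illustrative, not the actual diffusers implementation): every projection becomes a LoRA-compatible linear layer that optionally carries a small LoRA module, so the LoRA computation lives in one place instead of in every attention processor.

import torch.nn as nn


class LoRALinearLayerSketch(nn.Module):
    # hypothetical low-rank update: up(down(x)) scaled by network_alpha / rank
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.scale = network_alpha / rank if network_alpha is not None else 1.0
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states)) * self.scale


class LoRACompatibleLinearSketch(nn.Linear):
    # a plain linear layer whose forward adds the LoRA update when a lora_layer is set
    def __init__(self, *args, lora_layer=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer):
        self.lora_layer = lora_layer

    def forward(self, hidden_states):
        out = super().forward(hidden_states)
        if self.lora_layer is not None:
            out = out + self.lora_layer(hidden_states)
        return out

With this, the attention processors no longer need LoRA-specific variants: they call to_q, to_k, to_v and to_out as usual, and the LoRA update is applied (or skipped) inside the layer itself.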

TODO:

  • Refactor the attn_processors method to (for now) return deprecated attn processors if LoRA layers are activated
  • Make sure the saving format stays the same
  • Make sure all LoRA training examples work
  • Fix failing test

The PR now works for inference:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
# load the SDXL offset example LoRA on top of the base model
pipe.load_lora_weights("stabilityai/stable-diffusion-xl-base-1.0", weight_name="sd_xl_offset_example-lora_1.0.safetensors")

pipe.to(torch_dtype=torch.float16)
pipe.to("cuda")

torch.manual_seed(0)

prompt = "beautiful scenery nature glass bottle landscape, , purple galaxy bottle"
negative_prompt = "text, watermark"

image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=25).images[0]

@sayakpaul @williamberman I'd be very happy if you could give this a first review, and if that's ok with you I can finish the final TODOs tomorrow.

HuggingFaceDocBuilderDev commented Aug 24, 2023

The documentation is not available anymore as the PR was closed or merged.

patrickvonplaten (Contributor, Author)

Necessary for #4473

patrickvonplaten changed the title from [LoRA Attn] Refactor LoRA attn to [WIP][LoRA Attn Processors] Refactor LoRA Attn Processors on Aug 24, 2023
sayakpaul (Member)

Did a first pass and left comments. At this point the changes seem a little too complicated to me, but I will probably have a better idea once the PR is in a more comprehensive state.

patrickvonplaten (Contributor, Author)

A bit more explanation for what's going on here:

  • a) Eventually we want to fully get rid of the "lora" logic in diffusers and offload all of this responsibility to PEFT. The only thing left in diffusers will then be the {AnyLoraFormat}->PEFT_Format conversion logic, which probably only amounts to renaming keys and other inputs.
  • b) In this PR, as a first step, we want to unify all the LoRA logic in LoRACompatibleLinear and essentially deprecate all attention processor LoRAs. After this PR, all the computation logic flows through LoRACompatibleLinear, and we keep the LoRA attention processors only as a serialization format so as not to break backwards compatibility. Hence we need to ensure that calling unet.attn_processors still returns LoRA attention processors so that the serialization format stays intact (see the sketch after this list).
  • c) Running the forward pass of any LoRA attention processor is now deprecated. The LoRA attention processor serialization format will be deprecated bit by bit, but this should be done as part of the PEFT integration.
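
To make b) more concrete, here is a rough sketch of the backwards-compatible serialization path (the function name, the attn1/attn2 filter, and the constructor arguments are illustrative; the actual implementation in this PR differs in its details):

def lora_attn_processors_for_saving(unet, lora_processor_cls):
    # rebuild deprecated LoRA attention processors from the lora_layer modules
    # attached to the attention projections, purely so the saved state dict
    # keeps the old key layout
    processors = {}
    for name, module in unet.named_modules():
        if not (name.endswith("attn1") or name.endswith("attn2")):
            continue
        if getattr(module.to_q, "lora_layer", None) is None:
            continue  # this attention block has no LoRA attached
        # hypothetical constructor arguments; the real class signature may differ
        processor = lora_processor_cls(
            hidden_size=module.to_q.out_features,
            cross_attention_dim=module.to_k.in_features,
        )
        # point the processor at the live LoRA modules so state_dict() matches
        processor.to_q_lora = module.to_q.lora_layer
        processor.to_k_lora = module.to_k.lora_layer
        processor.to_v_lora = module.to_v.lora_layer
        processor.to_out_lora = module.to_out[0].lora_layer
        processors[f"{name}.processor"] = processor
    return processors

The forward pass never touches these rebuilt processors; they only exist so that saving LoRA weights produces the same files as before.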

Overall this PR should make the whole LoRA loading and execution logic easier to understand and build upon (as needed for #4473).

@@ -751,7 +753,7 @@ def test_a1111(self):
         images = images[0, -3:, -3:, -1].flatten()
         expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])

-        self.assertTrue(np.allclose(images, expected, atol=1e-4))
+        self.assertTrue(np.allclose(images, expected, atol=1e-3))
Contributor Author

1e-3 is a bit too tight for integration tests. We should be able to run most tests on many different GPUs without having them fail.

Member

You meant 1e-4.

@@ -921,9 +923,9 @@ def test_sdxl_0_9_lora_three(self):
         ).images

         images = images[0, -3:, -3:, -1].flatten()
-        expected = np.array([0.4115, 0.4047, 0.4124, 0.3931, 0.3746, 0.3802, 0.3735, 0.3748, 0.3609])
+        expected = np.array([0.4015, 0.3761, 0.3616, 0.3745, 0.3462, 0.3337, 0.3564, 0.3649, 0.3468])
Contributor Author

Updating the expected values here because I'm quite sure that this PR corrects a bug.
The LoRA:

        lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora"
        lora_filename = "kame_sdxl_v2-000020-16rank.safetensors"

is quite unique since it has different network_alphas for the q, k, v and out LoRA layers of the attention block. Previously we always set the same network_alpha for all of q, k, v and out (see here).

However, q, k, v and out may very well have different network_alpha values. This is now possible as of this PR, which makes the change a bug fix, since this checkpoint has a different network_alpha for each of q, k, v and out (see the sketch below).
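
To illustrate, a small sketch of how a per-projection network_alpha enters the LoRA computation (the helper name and the alpha values are made up for illustration, not read from the checkpoint):

from typing import Optional

import torch


def lora_weight_delta(up: torch.Tensor, down: torch.Tensor, network_alpha: Optional[float]) -> torch.Tensor:
    # weight update of one LoRA layer, scaled by its own alpha / rank
    rank = down.shape[0]
    scale = network_alpha / rank if network_alpha is not None else 1.0
    return scale * (up @ down)


# hypothetical per-projection alphas such a checkpoint may define
alphas = {"to_q": 16.0, "to_k": 16.0, "to_v": 8.0, "to_out.0": 4.0}
# before this PR: one alpha was reused for all four projections of a block
# after this PR: each projection's lora_layer carries its own alpha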

Member

for name, attn_module in text_encoder_attn_modules(text_encoder):
suggests we assign different alphas to the modules you mentioned. What am I missing out on?

Member

Ah, got it. It seems we do that only for the text encoder but not for the UNet. Is my understanding correct?

Contributor Author

Yeah exactly! I'm pretty sure it's corrected now :-)

williamberman (Contributor) left a comment

nice nice

patrickvonplaten (Contributor, Author)

As soon as you're good with the PR @sayakpaul, I think we can merge this one and unblock the fuse_lora PR :-)

# 2. else it is not possible that only some layers have LoRA activated
if not all(is_lora_activated.values()):
    raise ValueError(
        f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
    )
Member

Maybe we should just display list(is_lora_activated.keys()) to the end user.

Comment on lines +308 to +311
hasattr(self, "processor")
and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
and self.to_q.lora_layer is not None
):
Member

Neat.

I am a little concerned about:

and self.to_q.lora_layer is not None

i.e., we only check whether to_q.lora_layer is not None. Is there a better alternative?
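
One hypothetical way to address this (illustrative only, not what this PR implements) would be a small helper that requires all four projections to agree:

def lora_is_consistently_set(attn) -> bool:
    # returns True if LoRA is set on every projection, False if on none,
    # and raises if only some projections carry a lora_layer
    layers = [attn.to_q, attn.to_k, attn.to_v, attn.to_out[0]]
    has_lora = [getattr(layer, "lora_layer", None) is not None for layer in layers]
    if any(has_lora) and not all(has_lora):
        raise ValueError("Expected LoRA on all of to_q, to_k, to_v and to_out[0], or on none of them.")
    return all(has_lora)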

f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
)

# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
sayakpaul (Member) commented Aug 28, 2023

This bit is nasty but needed for ensuring backward compatibility.

@@ -1678,3 +1354,287 @@ def forward(self, f, zq):
        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f


## Deprecated
Member

Do you want to directly throw a deprecation message when the class is initialized or would that be too brutal?

Contributor Author

I thought about this too. I think it would be a bit too much at the moment, though, because it would mean that every time we call unet.attn_processors we would get a deprecation warning, and that happens every time we save LoRAs. I'd suggest only starting to throw aggressive deprecation warnings when we do the PEFT integration.
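
For illustration, a rough sketch of the deprecation pattern being described (the class name is made up and this is not the actual diffusers code; the delegation to the regular processor mirrors the review snippets quoted below):

import warnings


class DeprecatedLoRAAttnProcessorSketch:
    def __init__(self, regular_processor_cls):
        # no warning here on purpose: this object may be rebuilt purely to keep
        # the old saving format intact when unet.attn_processors is queried
        self.regular_processor_cls = regular_processor_cls

    def __call__(self, attn, hidden_states, *args, **kwargs):
        # warn only when someone actually runs the deprecated forward path
        warnings.warn(
            "Calling a LoRA attention processor directly is deprecated; the LoRA "
            "computation now runs through the LoRA-compatible linear layers.",
            FutureWarning,
        )
        # hand off to the regular (non-LoRA) processor
        attn.processor = self.regular_processor_cls()
        return attn.processor(attn, hidden_states, *args, **kwargs)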

return attn.processor(attn, hidden_states, *args, **kwargs)


class LoRAAttnProcessor2_0(nn.Module):
Member

Likewise.

attn.processor = AttnProcessor2_0()
return attn.processor(attn, hidden_states, *args, **kwargs)


Member

Likewise.

sayakpaul (Member) left a comment

Hard work. Thanks a lot.

Left some minor design-related comments but I don't think they block this PR from getting merged. So, I am gonna merge this to be able to focus on #4473.

sayakpaul merged commit 766aa50 into main on Aug 28, 2023
sayakpaul deleted the refactor_lora_attn branch on Aug 28, 2023 at 05:08
sayakpaul mentioned this pull request on Aug 28, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* [LoRA Attn] Refactor LoRA attn

* correct for network alphas

* fix more

* fix more tests

* fix more tests

* Move below

* Finish

* better version

* correct serialization format

* fix

* fix more

* fix more

* fix more

* Apply suggestions from code review

* Update src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

* deprecation

* relax atol for slow test slighly

* Finish tests

* make style

* make style
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024