[LoRA] feat: support loading loras into 4bit quantized Flux models.#10578
Conversation
```python
if quantization_config.load_in_4bit:
    expansion_shape = torch.Size(expansion_shape).numel()
    expansion_shape = ((expansion_shape + 1) // 2, 1)
```
Only 4bit bnb models flatten.
How about adding a comment along the lines of: "Handle 4bit bnb weights, which are flattened and compress 2 params into 1".
I'm not quite sure why we need (shape+1) // 2, maybe this could be added to the comment too.
Yeah, this comes from bitsandbytes. Cc: @matthewdouglas
This is for rounding: two 4-bit values are packed into each 8-bit element, so if `expansion_shape` has an odd number of elements, the last 8-bit element holds just one packed value.
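The packing arithmetic can be sketched standalone (the helper name `packed_4bit_shape` is hypothetical; only the `(numel + 1) // 2` formula comes from the discussion above):

```python
from math import prod

def packed_4bit_shape(original_shape):
    # bitsandbytes stores 4-bit weights flattened, two values per uint8
    # element, so the storage length is ceil(numel / 2).
    numel = prod(original_shape)
    return ((numel + 1) // 2, 1)

print(packed_4bit_shape((3072, 4096)))  # (6291456, 1) -- even numel, exact halving
print(packed_4bit_shape((3, 3)))        # (5, 1) -- odd numel rounds up by one element
```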
BenjaminBossan left a comment
Thanks for fixing the issue with loading LoRA into Flux models that are quantized with 4 bit bnb.
The issue of the actual parameter shape is a common trap. It would be great if the original shape could be retrieved from bnb; I'm not sure if there is a method for that.
Sayak, could you quickly sketch why this used to work and no longer does? IIRC, there was no special handling for 4 bit bnb previously, was there?
```python
module_weight_shape = module_weight.shape
expansion_shape = (out_features, in_features)
quantization_config = getattr(transformer, "quantization_config", None)
if quantization_config and quantization_config.quant_method == "bitsandbytes":
```
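The detection branch can be exercised with stubs; the stub classes here are hypothetical, but the attribute names (`quant_method`, `load_in_4bit`) match what the snippet above checks:

```python
class StubQuantConfig:
    # Mimics the attributes checked on a bnb quantization config.
    quant_method = "bitsandbytes"
    load_in_4bit = True

class StubTransformer:
    quantization_config = StubQuantConfig()

transformer = StubTransformer()
# Same detection pattern as above: the config may be absent entirely,
# so getattr with a None default guards the attribute access.
quantization_config = getattr(transformer, "quantization_config", None)
is_bnb = bool(quantization_config and quantization_config.quant_method == "bitsandbytes")
print(is_bnb)  # True
```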
Would it make sense to have a utility function to get the shape to avoid code duplication?
Yes, this needs to happen. As I mentioned, this PR is very much a PoC to gather feedback and I will refine it. But I wanted to first explore whether this is a good way to approach the problem.
It used to work because we didn't have any support for loading Flux Control LoRA. Some relevant pieces of vital code:
- `src/diffusers/loaders/lora_pipeline.py`, line 1941 (at `be62c85`)
- `src/diffusers/loaders/lora_pipeline.py`, line 2055 (at `be62c85`)
Okay, so Flux control LoRA + 4bit bnb was never possible. From the initial description, I got the wrong impression that this is a full regression.
We fully agree there. What I meant is that I misunderstood your original message to mean that Flux control LoRA + 4 bit bnb used to work and was trying to understand why it breaks now. This is what I meant by "full regression".
@BenjaminBossan That's a good point. The
Thanks, @BenjaminBossan. I will work on the suggestions to make it ready for another review pass.
Thanks @matthewdouglas! Will try to see if we can use it here.
Just confirmed that this works:

```python
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
print(model.context_embedder.weight.quant_state.shape)
# torch.Size([3072, 4096])
```
@BenjaminBossan I have addressed the rest of the feedback. PTAL. @DN6, ready for your review too. Just as an FYI, I have run all the LoRA-related integration tests for Flux and they all pass.
```python
    base_weight_param_name: str = None,
) -> "torch.Size":
    def _get_weight_shape(weight: torch.Tensor):
        return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
```
8bit params preserve the original shape?
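The dispatch in `_get_weight_shape` can be illustrated without bitsandbytes installed, using stand-in classes (the stubs below are hypothetical; in real bnb, `Params4bit` carries the original shape on `quant_state`, while 8-bit params keep their 2D `.shape`):

```python
from types import SimpleNamespace

class Params4bit:
    """Stand-in for bitsandbytes' Params4bit: storage is flattened and
    packed two values per byte, but the original shape survives on
    .quant_state."""
    def __init__(self, original_shape):
        numel = 1
        for d in original_shape:
            numel *= d
        self.shape = ((numel + 1) // 2, 1)  # packed storage shape
        self.quant_state = SimpleNamespace(shape=original_shape)

class Int8Params:
    """Stand-in for bnb 8-bit params, which keep the original 2D shape."""
    def __init__(self, original_shape):
        self.shape = original_shape

def get_weight_shape(weight):
    # 4-bit params flatten the storage; recover the true shape via quant_state.
    if weight.__class__.__name__ == "Params4bit":
        return weight.quant_state.shape
    return weight.shape

print(get_weight_shape(Params4bit((3072, 4096))))  # (3072, 4096)
print(get_weight_shape(Int8Params((3072, 4096))))  # (3072, 4096)
```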
```python
module_path = (
    base_weight_param_name.rsplit(".weight", 1)[0]
    if base_weight_param_name.endswith(".weight")
    else base_weight_param_name
)
```
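A quick standalone check of the `rsplit` normalization (the parameter names are made up for illustration):

```python
def module_path_from_param(base_weight_param_name):
    # Strip a trailing ".weight" to recover the module path; other
    # names pass through unchanged.
    return (
        base_weight_param_name.rsplit(".weight", 1)[0]
        if base_weight_param_name.endswith(".weight")
        else base_weight_param_name
    )

print(module_path_from_param("transformer_blocks.0.attn.to_q.weight"))
# transformer_blocks.0.attn.to_q
print(module_path_from_param("context_embedder.bias"))
# context_embedder.bias
```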
In what case would this param name not end with `weight`? The subsequent call to `_get_weight_shape` assumes that there is a weight attribute.
True. Let me modify that. Thanks for catching.
@DN6 applied changes and re-ran all the required tests and they passed.
Thanks for fixing this, Sayak.
…10578)
* feat: support loading loras into 4bit quantized models.
* updates
* update
* remove weight check.
What does this PR do?
We broke Flux LoRA (not Control LoRA) loading for 4bit BnB Flux models in 0.32.0, when we added support for Flux Control LoRAs (which only apply to Flux).
To reproduce:
Code
The above code will run on the `v0.31.0-release` branch but will fail with `v0.32.0-release` as well as `main`. This went uncaught because we don't test for it.

This PR attempts to partially fix the problem so that we can at least resort to a behavior similar to what was happening in the `v0.31.0-release` branch. Want to use this PR to refine how we're doing that. I want to ship this PR first and tackle the TODOs in a follow-up. Once this PR is done, we might have to do a patch release.

Related issue: #10550