
Commit 1bc6f3d

Authored by linoytsaban, github-actions[bot], and sayakpaul
[LoRA training] update metadata use for lora alpha + README (#11723)
* lora alpha
* Apply style fixes
* Update examples/advanced_diffusion_training/README_flux.md
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* fix readme format
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 79bd7ec commit 1bc6f3d

File tree: 4 files changed, +98 −3 lines


examples/advanced_diffusion_training/README_flux.md

Lines changed: 18 additions & 0 deletions
@@ -76,6 +76,24 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
 > `pip install wandb`
 > Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.

+### LoRA Rank and Alpha
+Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
+- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
+- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / rank`.
+- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength:
+  - `lora_alpha == rank`: scaling factor is 1; the LoRA is applied at its learned strength (e.g., alpha=16, rank=16).
+  - `lora_alpha < rank`: scaling factor < 1; reduces the LoRA's impact, useful for subtle changes or to avoid overpowering the base model (e.g., alpha=8, rank=16).
+  - `lora_alpha > rank`: scaling factor > 1; amplifies the LoRA's impact, letting a lower-rank LoRA have a stronger effect (e.g., alpha=32, rank=16).
+
+> [!TIP]
+> A common starting point is to set `lora_alpha` equal to `rank`.
+> Some also set `lora_alpha` to twice the `rank` (e.g., lora_alpha=32 for rank=16)
+> to give the LoRA updates more influence without increasing the parameter count.
+> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
+> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
+
 ### Target Modules
 When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
 More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
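To make the ratio concrete, here is a minimal, self-contained sketch of the scaling described above (hypothetical tensor shapes, not code from this repository): the effective update added to a frozen weight is the low-rank product scaled by `lora_alpha / rank`.

```python
import torch

# Illustrative values for --rank and --lora_alpha (alpha > rank amplifies the update).
rank, lora_alpha = 16, 32
d_out, d_in = 64, 64  # dimensions of the frozen layer being adapted

W = torch.randn(d_out, d_in)   # frozen base weight
A = torch.randn(rank, d_in)    # trainable LoRA "down" projection
B = torch.zeros(d_out, rank)   # trainable LoRA "up" projection (typically zero-initialized)

scaling = lora_alpha / rank    # 2.0 here
W_effective = W + scaling * (B @ A)
```

With `lora_alpha == rank` the scaling is 1.0 and the learned update is applied as-is; halving `lora_alpha` halves the update's influence without changing the number of trainable parameters.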

examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py

Lines changed: 45 additions & 0 deletions
@@ -13,13 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import logging
 import os
 import sys
 import tempfile

 import safetensors

+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+

 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -281,3 +284,45 @@ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_mult
         run_command(self._launch_args + resume_run_args)

         self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+    def test_dreambooth_lora_with_metadata(self):
+        # Use a `lora_alpha` that is different from `rank`.
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
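Outside the test harness, the same metadata can be inspected on any LoRA file written by the updated scripts. A minimal sketch, assuming a checkpoint exists at the hypothetical path below:

```python
import json

from safetensors import safe_open

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

# Hypothetical location of a checkpoint saved by the training script.
state_dict_file = "output/pytorch_lora_weights.safetensors"

with safe_open(state_dict_file, framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw is not None:
    adapter_metadata = json.loads(raw)
    # Keys are prefixed per component and mirror the PEFT LoraConfig fields,
    # e.g. "transformer.r" and "transformer.lora_alpha", as asserted in the test above.
    print(adapter_metadata.get("transformer.r"), adapter_metadata.get("transformer.lora_alpha"))
```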

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 18 additions & 3 deletions
@@ -55,6 +55,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     _set_state_dict_into_text_encoder,
     cast_training_params,
     compute_density_for_timestep_sampling,
@@ -431,6 +432,13 @@ def parse_args(input_args=None):
         help=("The dimension of the LoRA update matrices."),
     )

+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
+
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

     parser.add_argument(
@@ -1556,7 +1564,7 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1565,7 +1573,7 @@ def main(args):
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
-            lora_alpha=args.rank,
+            lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout,
             init_lora_weights="gaussian",
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1582,13 +1590,15 @@ def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
-
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     if args.train_text_encoder:  # when --train_text_encoder_ti we don't save the layers
                         text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                        modules_to_save["text_encoder"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_two))):
                     pass  # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers
                 else:
@@ -1601,6 +1611,7 @@ def save_model_hook(models, weights, output_dir):
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
         if args.train_text_encoder_ti:
             embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
@@ -2359,16 +2370,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         transformer = unwrap_model(transformer)
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer

         if args.train_text_encoder:
             text_encoder_one = unwrap_model(text_encoder_one)
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            modules_to_save["text_encoder"] = text_encoder_one
         else:
             text_encoder_lora_layers = None

@@ -2377,6 +2391,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )

         if args.train_text_encoder_ti:
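For orientation, the decoupled configuration the script now builds for the text encoder looks roughly like the standalone sketch below; the literal values stand in for `args.rank`, `args.lora_alpha`, and `args.lora_dropout`, which the script fills from the CLI flags.

```python
from peft import LoraConfig

# Stand-ins for args.rank / args.lora_alpha / args.lora_dropout.
rank, lora_alpha, lora_dropout = 16, 16, 0.0

text_lora_config = LoraConfig(
    r=rank,                 # dimension of the LoRA update matrices (--rank)
    lora_alpha=lora_alpha,  # scaling numerator (--lora_alpha), no longer hard-wired to rank
    lora_dropout=lora_dropout,
    init_lora_weights="gaussian",
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
```

Before this change, `lora_alpha` was set to `args.rank`, which pinned the effective scaling factor to 1; exposing `--lora_alpha` lets users choose the ratio discussed in the README sections above.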

examples/dreambooth/README_flux.md

Lines changed: 17 additions & 0 deletions
@@ -170,6 +170,23 @@ accelerate launch train_dreambooth_lora_flux.py \
   --push_to_hub
 ```

+### LoRA Rank and Alpha
+Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
+- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
+- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / rank`.
+- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength:
+  - `lora_alpha == rank`: scaling factor is 1; the LoRA is applied at its learned strength (e.g., alpha=16, rank=16).
+  - `lora_alpha < rank`: scaling factor < 1; reduces the LoRA's impact, useful for subtle changes or to avoid overpowering the base model (e.g., alpha=8, rank=16).
+  - `lora_alpha > rank`: scaling factor > 1; amplifies the LoRA's impact, letting a lower-rank LoRA have a stronger effect (e.g., alpha=32, rank=16).
+
+> [!TIP]
+> A common starting point is to set `lora_alpha` equal to `rank`.
+> Some also set `lora_alpha` to twice the `rank` (e.g., lora_alpha=32 for rank=16)
+> to give the LoRA updates more influence without increasing the parameter count.
+> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
+> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
+
 ### Target Modules
 When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
 More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
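As a quick sanity check on the three regimes listed above, the effective scaling factor is just the ratio of the two flags (purely illustrative):

```python
# Effective LoRA scaling factor for the example pairs from the README.
for lora_alpha, rank in [(16, 16), (8, 16), (32, 16)]:
    print(f"alpha={lora_alpha}, rank={rank} -> scaling factor {lora_alpha / rank:g}")
# alpha=16, rank=16 -> scaling factor 1
# alpha=8, rank=16 -> scaling factor 0.5
# alpha=32, rank=16 -> scaling factor 2
```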
