23 changes: 23 additions & 0 deletions library/network_utils.py
@@ -0,0 +1,23 @@
import torch
from torch import Tensor

# LoRA Dropout as a Sparsity Regularizer for Overfitting Control
def lora_dropout_down(down: Tensor, x: Tensor, dropout_prob: float = 0.5):
    """A = A · diag(m_A), m_A ∼ Bern(1 − p)"""
    mask = torch.bernoulli(
        torch.ones(down.shape[1], device=down.device) * (1 - dropout_prob)
    )

    # Apply input dimension mask (columns of down-projection)
    lx = x @ (down * mask.view(1, -1)).t()
    return lx

def lora_dropout_up(up: Tensor, x: Tensor, dropout_prob: float = 0.5):
    """B = B⊤ · diag(m_B)⊤, m_B ∼ Bern(1 − p)"""
    mask = torch.bernoulli(
        torch.ones(up.shape[0], device=up.device) * (1 - dropout_prob)
    )

    # Apply output dimension mask (rows of up-projection)
    lx = x @ (up * mask.view(-1, 1)).t()
    return lx
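A minimal sketch (not part of the PR) of how these two helpers compose into a masked LoRA delta; the batch/rank/feature sizes and the 0.1 dropout probability below are illustrative only. Note that, unlike torch.nn.Dropout, neither helper rescales the surviving entries by 1/(1 − p).

import torch
from library.network_utils import lora_dropout_down, lora_dropout_up

x = torch.randn(4, 32)     # (batch, in_dim)
down = torch.randn(8, 32)  # A: (rank, in_dim); its columns are masked
up = torch.randn(16, 8)    # B: (out_dim, rank); its rows are masked

lx = lora_dropout_down(down, x, dropout_prob=0.1)   # -> (4, 8)
delta = lora_dropout_up(up, lx, dropout_prob=0.1)   # -> (4, 16), the LoRA contribution before scaling
print(delta.shape)  # torch.Size([4, 16])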
83 changes: 52 additions & 31 deletions networks/lora_flux.py
@@ -18,7 +18,7 @@
from torch import Tensor
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
from library.network_utils import lora_dropout_down, lora_dropout_up

setup_logging()
import logging
@@ -45,6 +45,7 @@ def __init__(
dropout=None,
rank_dropout=None,
module_dropout=None,
lora_dropout=None,
split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
@@ -106,6 +107,7 @@ def __init__(
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.lora_dropout = lora_dropout

self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
@@ -132,7 +134,11 @@ def forward(self, x):
return org_forwarded

if self.split_dims is None:
lx = self.lora_down(x)
# LoRA Dropout as a Sparsity Regularizer for Overfitting Control
if self.lora_dropout is not None and self.training and self.lora_dropout > 0:
lx = lora_dropout_down(self.lora_down.weight, x, dropout_prob=self.lora_dropout)
else:
lx = self.lora_down(x)

# normal dropout
if self.dropout is not None and self.training:
@@ -153,14 +159,26 @@
else:
scale = self.scale

lx = self.lora_up(lx)
# LoRA Dropout as a Sparsity Regularizer for Overfitting Control
if self.lora_dropout is not None and self.training and self.lora_dropout > 0:
lx = lora_dropout_up(self.lora_up.weight, lx, dropout_prob=self.lora_dropout)
else:
lx = self.lora_up(lx)

# LoRA Gradient-Guided Perturbation Optimization
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
if (
self.training
and self.ggpo_sigma is not None
and self.ggpo_beta is not None
and self.combined_weight_norms is not None
and self.grad_norms is not None
):
with torch.no_grad():
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + (
self.ggpo_beta * (self.grad_norms**2)
)
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation.mul_(perturbation_scale_factor)
perturbation_output = x @ perturbation.T # Result: (batch × n)
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
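Reading the hunk above: when GGPO is active, the return value adds a third term to the usual output, org_forwarded + multiplier · scale · lx + x @ εᵀ, where ε is Gaussian noise with the original weight's shape whose rows are scaled by ggpo_sigma · ‖w_row‖ + ggpo_beta · ‖g_row‖² (times the precomputed perturbation_norm_factor).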
@@ -197,24 +215,24 @@ def initialize_norm_cache(self, org_module_weight: Tensor):
# Choose a reasonable sample size
n_rows = org_module_weight.shape[0]
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller

# Sample random indices across all rows
indices = torch.randperm(n_rows)[:sample_size]

# Convert to a supported data type first, then index
# Use float32 for indexing operations
weights_float32 = org_module_weight.to(dtype=torch.float32)
sampled_weights = weights_float32[indices].to(device=self.device)

# Calculate sampled norms
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)

# Store the mean norm as our estimate
self.org_weight_norm_estimate = sampled_norms.mean()

# Optional: store standard deviation for confidence intervals
self.org_weight_norm_std = sampled_norms.std()

# Free memory
del sampled_weights, weights_float32

@@ -223,54 +241,54 @@ def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
# Calculate the true norm (this will be slow but it's just for validation)
true_norms = []
chunk_size = 1024 # Process in chunks to avoid OOM

for i in range(0, org_module_weight.shape[0], chunk_size):
end_idx = min(i + chunk_size, org_module_weight.shape[0])
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
true_norms.append(chunk_norms.cpu())
del chunk

true_norms = torch.cat(true_norms, dim=0)
true_mean_norm = true_norms.mean().item()

# Compare with our estimate
estimated_norm = self.org_weight_norm_estimate.item()

# Calculate error metrics
absolute_error = abs(true_mean_norm - estimated_norm)
relative_error = absolute_error / true_mean_norm * 100 # as percentage

if verbose:
logger.info(f"True mean norm: {true_mean_norm:.6f}")
logger.info(f"Estimated norm: {estimated_norm:.6f}")
logger.info(f"Absolute error: {absolute_error:.6f}")
logger.info(f"Relative error: {relative_error:.2f}%")

return {
'true_mean_norm': true_mean_norm,
'estimated_norm': estimated_norm,
'absolute_error': absolute_error,
'relative_error': relative_error
"true_mean_norm": true_mean_norm,
"estimated_norm": estimated_norm,
"absolute_error": absolute_error,
"relative_error": relative_error,
}


@torch.no_grad()
def update_norms(self):
# GGPO is not enabled, so there are no norms to update
if self.ggpo_beta is None or self.ggpo_sigma is None:
return

# only update norms when we are training
if self.training is False:
return

module_weights = self.lora_up.weight @ self.lora_down.weight
module_weights.mul_(self.scale)  # in-place; a bare .mul() would discard the scaled result

self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
torch.sum(module_weights**2, dim=1, keepdim=True))
self.combined_weight_norms = torch.sqrt(
(self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True)
)

@torch.no_grad()
def update_grad_norms(self):
@@ -293,7 +311,6 @@ def update_grad_norms(self):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)


@property
def device(self):
return next(self.parameters()).device
@@ -544,6 +561,9 @@ def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
lora_dropout = kwargs.get("lora_dropout", None)
if lora_dropout is not None:
lora_dropout = float(lora_dropout)
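(Presumably, like the other dropout options parsed here, lora_dropout is supplied as a key=value pair in the network arguments, e.g. lora_dropout=0.1; that usage is an assumption, not something shown in this diff.)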

# single or double blocks
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
@@ -564,7 +584,6 @@ def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
if ggpo_sigma is not None:
ggpo_sigma = float(ggpo_sigma)


# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
@@ -585,6 +604,7 @@ def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
lora_dropout=lora_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
train_blocks=train_blocks,
@@ -696,6 +716,7 @@ def __init__(
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
lora_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule,
@@ -722,6 +743,7 @@ def __init__(
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.lora_dropout = lora_dropout
self.train_blocks = train_blocks if train_blocks is not None else "all"
self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl
@@ -757,7 +779,6 @@ def __init__(
if self.train_blocks is not None:
logger.info(f"train {self.train_blocks} blocks only")


if train_t5xxl:
logger.info(f"train T5XXL as well")

@@ -876,6 +897,7 @@ def create_modules(
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
lora_dropout=lora_dropout,
split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
@@ -895,7 +917,7 @@ def create_modules(
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
break

logger.info(f"create LoRA for Text Encoder {index+1}:")
logger.info(f"create LoRA for Text Encoder {index + 1}:")

text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
@@ -976,7 +998,6 @@ def combined_weight_norms(self) -> Tensor:
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])


def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
1 change: 1 addition & 0 deletions pytest.ini
@@ -6,3 +6,4 @@ filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
ignore::FutureWarning
pythonpath = .
84 changes: 84 additions & 0 deletions tests/library/test_network_utils.py
@@ -0,0 +1,84 @@
import torch
import pytest
from library.network_utils import lora_dropout_down, lora_dropout_up


@pytest.fixture
def setup_lora_dropout_dimensions():
batch_size = 2
in_dim = 32
lora_dim = 8
out_dim = 16

x_in = torch.randn(batch_size, in_dim)
x_mid = torch.randn(batch_size, lora_dim)

down = torch.randn(lora_dim, in_dim)
up = torch.randn(out_dim, lora_dim)

return {
"batch_size": batch_size,
"in_dim": in_dim,
"lora_dim": lora_dim,
"out_dim": out_dim,
"x_in": x_in,
"x_mid": x_mid,
"down": down,
"up": up,
}


# Tests
def test_lora_dropout_dimensions(setup_lora_dropout_dimensions):
"""Test if output dimensions are correct"""
d = setup_lora_dropout_dimensions

# Apply dropout
mid_out = lora_dropout_down(d["down"], d["x_in"])
final_out = lora_dropout_up(d["up"], mid_out)

# Check dimensions
assert mid_out.shape == (d["batch_size"], d["lora_dim"])
assert final_out.shape == (d["batch_size"], d["out_dim"])


def test_lora_dropout_reproducibility():
"""Test if setting a seed makes dropout reproducible"""
in_dim = 50
lora_dim = 10
batch_size = 3

# Create sample inputs
x_in = torch.randn(batch_size, in_dim)

# Create weight matrix
down = torch.randn(lora_dim, in_dim)

# First run
torch.manual_seed(123)
result1 = lora_dropout_down(down, x_in)

# Second run with same seed
torch.manual_seed(123)
result2 = lora_dropout_down(down, x_in)

# They should be identical
assert torch.allclose(result1, result2)


def test_lora_dropout_full_forward_path(setup_lora_dropout_dimensions):
"""Test a complete LoRA path with dropout"""
torch.manual_seed(456)

d = setup_lora_dropout_dimensions

# Normal forward path without dropout
mid_normal = d["x_in"] @ d["down"].t()
out_normal = mid_normal @ d["up"].t()

# Forward path with dropout
mid_dropout = lora_dropout_down(d["down"], d["x_in"])
out_dropout = lora_dropout_up(d["up"], mid_dropout)

# The outputs should be different due to dropout
assert not torch.allclose(out_normal, out_dropout)
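With pythonpath = . now set in pytest.ini, these tests should be runnable from the repository root with a plain pytest invocation, e.g. pytest tests/library/test_network_utils.py (assuming torch and pytest are installed in the environment).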