
Commit 15136ca

Fix _grad_magnitude_ema_up / _grad_magnitude_ema_down getting saved to LoRA
1 parent 3f47806 commit 15136ca

File tree

1 file changed: +8 -9 lines changed

networks/lora_flux.py

Lines changed: 8 additions & 9 deletions
@@ -126,8 +126,8 @@ def __init__(
         self.mgpo_beta = mgpo_beta

         # EMA of gradient magnitudes for adaptive normalization
-        self._grad_magnitude_ema_down = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
-        self._grad_magnitude_ema_up = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
+        self.register_buffer('_grad_magnitude_ema_down', torch.tensor(1.0), persistent=False)
+        self.register_buffer('_grad_magnitude_ema_up', torch.tensor(1.0), persistent=False)

         self.optimizer: torch.optim.Optimizer | None = None

@@ -337,24 +337,23 @@ def update_grad_norms(self):
     def update_gradient_ema(self):
         """
         Update EMA of gradient magnitudes for adaptive perturbation normalization
-
         Formula: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇ΔWₗL||₂
         """
         if self.mgpo_beta is None:
             return
-
+
         # Update EMA for lora_down gradient magnitude
         if self.lora_down.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_down.weight.grad, p=2)
-            self._grad_magnitude_ema_down.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_down.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_down.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )
-
+
         # Update EMA for lora_up gradient magnitude
         if self.lora_up.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_up.weight.grad, p=2)
-            self._grad_magnitude_ema_up.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_up.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_up.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )

     def get_mgpo_output_perturbation(self, x: Tensor) -> Tensor | None:
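With the EMA scalars now plain buffers rather than Parameters, the .data indirection is no longer needed and the update can use in-place ops: t.mul_(beta).add_(x, alpha=1 - beta) computes beta * t + (1 - beta) * x in place, matching the docstring's formula. A quick standalone check (beta and the gradient values are illustrative):

import torch

beta = 0.9  # illustrative stand-in for mgpo_beta
ema = torch.tensor(1.0)
grad = torch.randn(8, 8)

current_grad_norm = torch.norm(grad, p=2)
expected = beta * ema + (1 - beta) * current_grad_norm

# In-place EMA update, as in the hunk above
ema.mul_(beta).add_(current_grad_norm, alpha=(1 - beta))

assert torch.allclose(ema, expected)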
