@@ -126,8 +126,8 @@ def __init__(
         self.mgpo_beta = mgpo_beta

         # EMA of gradient magnitudes for adaptive normalization
-        self._grad_magnitude_ema_down = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
-        self._grad_magnitude_ema_up = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
+        self.register_buffer('_grad_magnitude_ema_down', torch.tensor(1.0), persistent=False)
+        self.register_buffer('_grad_magnitude_ema_up', torch.tensor(1.0), persistent=False)

         self.optimizer: torch.optim.Optimizer | None = None

@@ -337,24 +337,23 @@ def update_grad_norms(self):
     def update_gradient_ema(self):
         """
         Update EMA of gradient magnitudes for adaptive perturbation normalization
-
         Formula: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇ΔWₗL||₂
         """
         if self.mgpo_beta is None:
             return
-
+
         # Update EMA for lora_down gradient magnitude
         if self.lora_down.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_down.weight.grad, p=2)
-            self._grad_magnitude_ema_down.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_down.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_down.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )
-
+
         # Update EMA for lora_up gradient magnitude
         if self.lora_up.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_up.weight.grad, p=2)
-            self._grad_magnitude_ema_up.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_up.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_up.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )

     def get_mgpo_output_perturbation(self, x: Tensor) -> Tensor | None:
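Note on the buffer change: a non-trainable `nn.Parameter` still shows up in `parameters()` and in the checkpoint, whereas `register_buffer(..., persistent=False)` gives a tensor that follows the module's device/dtype moves but is invisible to optimizers and omitted from `state_dict()`. A minimal sketch illustrating the difference (the `EmaHolder` class is hypothetical, not part of this diff):

```python
import torch
import torch.nn as nn


class EmaHolder(nn.Module):
    def __init__(self):
        super().__init__()
        # Old approach: a Parameter with requires_grad=False is still registered,
        # so it appears in parameters() and in state_dict().
        self.as_param = nn.Parameter(torch.tensor(1.0), requires_grad=False)
        # New approach: a buffer moves with .to()/.cuda() like a parameter, but is
        # excluded from parameters(); persistent=False also keeps it out of state_dict().
        self.register_buffer("as_buffer", torch.tensor(1.0), persistent=False)


m = EmaHolder()
print([name for name, _ in m.named_parameters()])  # ['as_param']
print(list(m.state_dict().keys()))                 # ['as_param'] -- non-persistent buffer omitted
m = m.to(torch.float64)
print(m.as_buffer.dtype)                           # torch.float64 -- buffer follows module moves
```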
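Note on the EMA update: `mul_(beta).add_(norm, alpha=1 - beta)` computes the same value as the replaced out-of-place expression `beta * ema + (1 - beta) * norm`, but mutates the buffer in place instead of rebinding `.data`. A small standalone check of that equivalence (the names `beta`, `grad`, `ema` are illustrative, not taken from the PR):

```python
import torch

beta = 0.9
grad = torch.randn(8, 8)
grad_norm = torch.norm(grad, p=2)

ema = torch.tensor(1.0)

# Out-of-place reference, matching the removed code path.
expected = beta * ema + (1 - beta) * grad_norm

# In-place form used above: scale the buffer by beta, then add the scaled gradient norm.
ema.mul_(beta).add_(grad_norm, alpha=(1 - beta))

assert torch.allclose(ema, expected)
print(float(ema), float(expected))
```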