@@ -357,27 +357,37 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     accelerator.print("prepare optimizer, data loader etc.")
 
     if args.fused_optimizer_groups:
+        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
+        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
+        # This balances memory usage and management complexity.
+
         # calculate total number of parameters
         n_total_params = sum(len(params["params"]) for params in params_to_optimize)
         params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
 
-        # split params into groups
+        # split params into groups, keeping the learning rate the same for all params in a group
+        # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
         grouped_params = []
         param_group = []
         param_group_lr = -1
         for group in params_to_optimize:
             lr = group["lr"]
             for p in group["params"]:
+                # if the learning rate is different for different params, start a new group
                 if lr != param_group_lr:
                     if param_group:
                         grouped_params.append({"params": param_group, "lr": param_group_lr})
                         param_group = []
                     param_group_lr = lr
+
                 param_group.append(p)
+
+                # if the group has enough parameters, start a new group
                 if len(param_group) == params_per_group:
                     grouped_params.append({"params": param_group, "lr": param_group_lr})
                     param_group = []
                     param_group_lr = -1
+
         if param_group:
             grouped_params.append({"params": param_group, "lr": param_group_lr})
 
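The comments above describe the split: parameters are divided into roughly `args.fused_optimizer_groups` equal groups, and a group is also closed whenever the learning rate changes. A minimal, self-contained sketch of that logic (the `group_params` helper and the dummy tensors are illustrative, not part of the script):

```python
import math

import torch


def group_params(params_to_optimize, n_groups):
    """Split parameter groups into ~n_groups chunks without mixing learning rates."""
    n_total_params = sum(len(g["params"]) for g in params_to_optimize)
    params_per_group = math.ceil(n_total_params / n_groups)

    grouped_params = []
    param_group, param_group_lr = [], -1
    for group in params_to_optimize:
        lr = group["lr"]
        for p in group["params"]:
            # a change of learning rate closes the current chunk
            if lr != param_group_lr:
                if param_group:
                    grouped_params.append({"params": param_group, "lr": param_group_lr})
                    param_group = []
                param_group_lr = lr
            param_group.append(p)
            # a full chunk is closed as well
            if len(param_group) == params_per_group:
                grouped_params.append({"params": param_group, "lr": param_group_lr})
                param_group, param_group_lr = [], -1
    if param_group:
        grouped_params.append({"params": param_group, "lr": param_group_lr})
    return grouped_params


# e.g. 5 "U-Net" params at lr 1e-4 and 3 "text encoder" params at lr 5e-5,
# requested 4 groups -> ceil(8 / 4) = 2 params per chunk, but the lr boundary
# forces an extra chunk: sizes come out as [2, 2, 1, 2, 1]
params_to_optimize = [
    {"params": [torch.zeros(1) for _ in range(5)], "lr": 1e-4},
    {"params": [torch.zeros(1) for _ in range(3)], "lr": 5e-5},
]
for g in group_params(params_to_optimize, 4):
    print(len(g["params"]), g["lr"])
```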
@@ -388,7 +398,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             optimizers.append(optimizer)
         optimizer = optimizers[0]  # avoid error in the following code
 
-        print(len(grouped_params))
         logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
 
     else:
@@ -420,6 +429,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     # prepare lr scheduler
     if args.fused_optimizer_groups:
+        # prepare lr schedulers for each optimizer
         lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
         lr_scheduler = lr_schedulers[0]  # avoid error in the following code
     else:
@@ -472,6 +482,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     if args.fused_backward_pass:
+        # use fused optimizer for backward pass: other optimizers will be supported in the future
         import library.adafactor_fused
 
         library.adafactor_fused.patch_adafactor_fused(optimizer)
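For context, the fused backward pass follows the pattern from the PyTorch tutorial linked above: each parameter gets a post-accumulate-grad hook that steps and zeroes its optimizer as soon as that parameter's gradient is ready, so full-model gradients never have to be held at once. A rough sketch of that general pattern with a toy model and plain SGD (not the `library.adafactor_fused` patch itself, which applies the same idea to Adafactor):

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the real networks

# one tiny optimizer per parameter, as in the tutorial (SGD here for simplicity)
optimizer_dict = {p: torch.optim.SGD([p], lr=1e-3) for p in model.parameters()}


def optimizer_hook(parameter: torch.Tensor) -> None:
    # step and free this parameter's gradient as soon as it has been accumulated
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad(set_to_none=True)


for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)

# backward() now also applies the updates; there is no separate optimizer.step()
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
```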
@@ -488,16 +499,20 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                     parameter.register_post_accumulate_grad_hook(__grad_hook)
 
     elif args.fused_optimizer_groups:
+        # prepare for additional optimizers and lr schedulers
         for i in range(1, len(optimizers)):
             optimizers[i] = accelerator.prepare(optimizers[i])
             lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
 
+        # counters are used to determine when to step the optimizer
         global optimizer_hooked_count
         global num_parameters_per_group
         global parameter_optimizer_map
+
         optimizer_hooked_count = {}
         num_parameters_per_group = [0] * len(optimizers)
         parameter_optimizer_map = {}
+
         for opt_idx, optimizer in enumerate(optimizers):
             for param_group in optimizer.param_groups:
                 for parameter in param_group["params"]:
@@ -511,7 +526,7 @@ def optimizer_hook(parameter: torch.Tensor):
                             optimizer_hooked_count[i] += 1
                             if optimizer_hooked_count[i] == num_parameters_per_group[i]:
                                 optimizers[i].step()
-                                optimizers[i].zero_grad()
+                                optimizers[i].zero_grad(set_to_none=True)
 
                         parameter.register_post_accumulate_grad_hook(optimizer_hook)
                         parameter_optimizer_map[parameter] = opt_idx
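The fused_optimizer_groups branch applies the same hook idea per group: a counter per optimizer tracks how many of its parameters have accumulated gradients, and the group's optimizer steps once the count reaches the group size. A toy, self-contained sketch of that counter logic (the two-group split of a single Linear layer is purely illustrative):

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the real networks
params = [p for p in model.parameters() if p.requires_grad]

# two hypothetical groups (weight / bias here), one optimizer per group
groups = [params[: len(params) // 2], params[len(params) // 2 :]]
optimizers = [torch.optim.SGD(g, lr=1e-3) for g in groups]

parameter_optimizer_map = {p: i for i, g in enumerate(groups) for p in g}
num_parameters_per_group = [len(g) for g in groups]
optimizer_hooked_count = {i: 0 for i in range(len(groups))}  # reset every step


def optimizer_hook(parameter: torch.Tensor) -> None:
    i = parameter_optimizer_map[parameter]
    optimizer_hooked_count[i] += 1
    # once every parameter of the group has produced its gradient, step the group
    if optimizer_hooked_count[i] == num_parameters_per_group[i]:
        optimizers[i].step()
        optimizers[i].zero_grad(set_to_none=True)


for p in params:
    p.register_post_accumulate_grad_hook(optimizer_hook)

# one toy training step: both group optimizers step inside backward()
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
```

In the real script the counters are reset at the start of every training step, as the next hunk shows.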
@@ -593,7 +608,7 @@ def optimizer_hook(parameter: torch.Tensor):
             current_step.value = global_step
 
             if args.fused_optimizer_groups:
-                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step
 
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
@@ -725,14 +740,14 @@ def optimizer_hook(parameter: torch.Tensor):
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                     optimizer.step()
-                elif args.fused_optimizer_groups:
-                    for i in range(1, len(optimizers)):
-                        lr_schedulers[i].step()
-
-                lr_scheduler.step()
-
-                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                    lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
+                else:
+                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
+                    lr_scheduler.step()
+                    if args.fused_optimizer_groups:
+                        for i in range(1, len(optimizers)):
+                            lr_schedulers[i].step()
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
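One detail worth noting in the final branch: `lr_scheduler` is an alias for `lr_schedulers[0]` (set where the schedulers are created), so the inner loop starts at index 1 to avoid stepping scheduler 0 twice. A tiny stand-in illustration (the `_CountingScheduler` class is hypothetical and only counts calls):

```python
class _CountingScheduler:
    """Hypothetical stand-in that only counts step() calls."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


lr_schedulers = [_CountingScheduler() for _ in range(3)]
lr_scheduler = lr_schedulers[0]  # same aliasing as in the training script

lr_scheduler.step()
for i in range(1, len(lr_schedulers)):
    lr_schedulers[i].step()

assert [s.steps for s in lr_schedulers] == [1, 1, 1]  # each stepped exactly once
```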