@@ -493,17 +493,24 @@ def train(self, args):
         # before resuming make hook for saving/loading to save/load the network weights only
         def save_model_hook(models, weights, output_dir):
             # pop weights of other models than network to save only network weights
-            # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
-            if accelerator.is_main_process or args.deepspeed:
+            if accelerator.is_main_process:
                 remove_indices = []
                 for i, model in enumerate(models):
                     if not isinstance(model, type(accelerator.unwrap_model(network))):
                         remove_indices.append(i)
                 for i in reversed(remove_indices):
-                    if len(weights) > i:
-                        weights.pop(i)
+                    weights.pop(i)
                 # print(f"save model hook: {len(weights)} weights will be saved")

+            # save current epoch and step
+            train_state_file = os.path.join(output_dir, "train_state.json")
+            # +1 is needed because the state is saved before current_step is set from global_step
+            logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value + 1}")
+            with open(train_state_file, "w", encoding="utf-8") as f:
+                json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
+
+        steps_from_state = None
+
         def load_model_hook(models, input_dir):
             # remove models except network
             remove_indices = []
@@ -514,6 +521,15 @@ def load_model_hook(models, input_dir):
                 models.pop(i)
             # print(f"load model hook: {len(models)} models will be loaded")

+            # load current epoch and step
+            nonlocal steps_from_state
+            train_state_file = os.path.join(input_dir, "train_state.json")
+            if os.path.exists(train_state_file):
+                with open(train_state_file, "r", encoding="utf-8") as f:
+                    data = json.load(f)
+                steps_from_state = data["current_step"]
+                logger.info(f"load train state from {train_state_file}: {data}")
+
         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)

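For reference, a minimal sketch of the accelerate hook mechanism the two hunks above rely on (not part of the patch; the plain state dict stands in for the shared current_epoch/current_step values): register_save_state_pre_hook and register_load_state_pre_hook make a sidecar train_state.json travel with every save_state/load_state checkpoint.

    # Minimal sketch, assuming a bare Accelerator and a plain dict for the counters.
    import json
    import os

    from accelerate import Accelerator

    accelerator = Accelerator()
    state = {"current_epoch": 0, "current_step": 0}  # stand-in for current_epoch/current_step

    def save_hook(models, weights, output_dir):
        # called by accelerator.save_state(output_dir) before the weights are written
        with open(os.path.join(output_dir, "train_state.json"), "w", encoding="utf-8") as f:
            json.dump(state, f)

    def load_hook(models, input_dir):
        # called by accelerator.load_state(input_dir) before the weights are restored
        path = os.path.join(input_dir, "train_state.json")
        if os.path.exists(path):
            with open(path, "r", encoding="utf-8") as f:
                state.update(json.load(f))

    accelerator.register_save_state_pre_hook(save_hook)
    accelerator.register_load_state_pre_hook(load_hook)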
@@ -757,7 +773,53 @@ def load_model_hook(models, input_dir):
             if key in metadata:
                 minimum_metadata[key] = metadata[key]

-        progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+        # calculate steps to skip when resuming or starting from a specific step
+        initial_step = 0
+        if args.initial_epoch is not None or args.initial_step is not None:
+            # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
+            if steps_from_state is not None:
+                logger.warning(
+                    "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
+                )
+            if args.initial_step is not None:
+                initial_step = args.initial_step
+            else:
+                # num steps per epoch is calculated by num_processes and gradient_accumulation_steps
+                initial_step = (args.initial_epoch - 1) * math.ceil(
+                    len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+                )
+        else:
+            # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
+            if steps_from_state is not None:
+                initial_step = steps_from_state
+                steps_from_state = None
+
+        if initial_step > 0:
+            assert (
+                args.max_train_steps > initial_step
+            ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
+
+        progress_bar = tqdm(
+            range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
+        )
+
+        epoch_to_start = 0
+        if initial_step > 0:
+            if args.skip_until_initial_step:
+                # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
+                if not args.resume:
+                    logger.info(
+                        f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
+                    )
+                logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
+                initial_step *= args.gradient_accumulation_steps
+            else:
+                # if not, only the epoch number is skipped for informative purposes
+                epoch_to_start = initial_step // math.ceil(
+                    len(train_dataloader) / args.gradient_accumulation_steps
+                )
+                initial_step = 0  # do not skip
+
         global_step = 0

         noise_scheduler = DDPMScheduler(
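To make the initial_step arithmetic above concrete, a small worked example with illustrative numbers (not from the patch): with 1000 batches in train_dataloader, 2 processes, and gradient_accumulation_steps of 4, one epoch is ceil(1000 / 2 / 4) = 125 optimizer steps, so --initial_epoch 3 maps to initial_step = (3 - 1) * 125 = 250.

    # Worked example of the per-epoch step count; all numbers are illustrative.
    import math

    batches_per_epoch = 1000  # len(train_dataloader)
    num_processes = 2         # accelerator.num_processes
    grad_accum = 4            # args.gradient_accumulation_steps

    steps_per_epoch = math.ceil(batches_per_epoch / num_processes / grad_accum)  # 125
    initial_step = (3 - 1) * steps_per_epoch  # --initial_epoch 3 -> 250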
@@ -816,15 +878,24 @@ def remove_model(old_ckpt_name):
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

         # training loop
-        for epoch in range(num_train_epochs):
+        for skip_epoch in range(epoch_to_start):  # skip epochs
+            logger.info(f"skipping epoch {skip_epoch + 1} because initial_step (multiplied) is {initial_step}")
+            initial_step -= len(train_dataloader)
+
+        for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch + 1}/{num_train_epochs}")
             current_epoch.value = epoch + 1

             metadata["ss_epoch"] = str(epoch + 1)

             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

-            for step, batch in enumerate(train_dataloader):
+            skipped_dataloader = None
+            if initial_step > 0:
+                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
+                initial_step = 1
+
+            for step, batch in enumerate(skipped_dataloader or train_dataloader):
                 current_step.value = global_step
                 with accelerator.accumulate(training_model):
                     on_step_start(text_encoder, unet)
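The dataloader fast-forward above uses accelerate's Accelerator.skip_first_batches, which returns a dataloader that omits the first N batches so the resumed epoch continues from the same position in the data. A minimal sketch of just that call, on a toy dataset (not from the patch):

    from accelerate import Accelerator
    from torch.utils.data import DataLoader

    accelerator = Accelerator()
    loader = accelerator.prepare(DataLoader(list(range(10)), batch_size=2))

    # Skip the first 3 of 5 batches; iteration resumes at the 4th batch.
    resumed = accelerator.skip_first_batches(loader, num_batches=3)
    for batch in resumed:
        print(batch)  # tensor([6, 7]) then tensor([8, 9])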
@@ -1126,6 +1197,25 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    parser.add_argument(
+        "--skip_until_initial_step",
+        action="store_true",
+        help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
+    )
+    parser.add_argument(
+        "--initial_epoch",
+        type=int,
+        default=None,
+        help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
+        + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
+    )
+    parser.add_argument(
+        "--initial_step",
+        type=int,
+        default=None,
+        help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+        + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
+    )
     # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
     # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
     # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
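Taken together, the intended invocation (paths and numbers illustrative, assuming this parser belongs to train_network.py) is to pass --resume <state_dir> along with --skip_until_initial_step and either --initial_step N or --initial_epoch M: the saved train_state.json supplies the step count when neither flag is given, --initial_step takes precedence over --initial_epoch when both are set, and without --resume the lr scheduler restarts from zero even though the data loader is fast-forwarded.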