File tree Expand file tree Collapse file tree 2 files changed +29
-1
lines changed
Expand file tree Collapse file tree 2 files changed +29
-1
lines changed Original file line number Diff line number Diff line change @@ -3246,6 +3246,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
32463246 default = None ,
32473247 help = "directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)" ,
32483248 )
3249+ parser .add_argument (
3250+ "--num_last_block_to_freeze" ,
3251+ type = int ,
3252+ default = None ,
3253+ help = "num_last_block_to_freeze" ,
3254+ )
32493255
32503256
32513257def add_optimizer_arguments (parser : argparse .ArgumentParser ):
@@ -5758,6 +5764,21 @@ def sample_image_inference(
57585764 pass
57595765
57605766
5767+ def freeze_blocks (model , num_last_block_to_freeze , block_name = "x_block" ):
5768+
5769+ filtered_blocks = [(name , param ) for name , param in model .named_parameters () if block_name in name ]
5770+ print (f"filtered_blocks: { len (filtered_blocks )} " )
5771+
5772+ num_blocks_to_freeze = min (len (filtered_blocks ), num_last_block_to_freeze )
5773+
5774+ print (f"freeze_blocks: { num_blocks_to_freeze } " )
5775+
5776+ start_freezing_from = max (0 , len (filtered_blocks ) - num_blocks_to_freeze )
5777+
5778+ for i in range (start_freezing_from , len (filtered_blocks )):
5779+ _ , param = filtered_blocks [i ]
5780+ param .requires_grad = False
5781+
57615782# endregion
57625783
57635784
Original file line number Diff line number Diff line change @@ -368,12 +368,19 @@ def train(args):
368368 vae .eval ()
369369 vae .to (accelerator .device , dtype = vae_dtype )
370370
371+ mmdit .requires_grad_ (train_mmdit )
372+ if not train_mmdit :
373+ mmdit .to (accelerator .device , dtype = weight_dtype ) # because of unet is not prepared
374+
375+ if args .num_last_block_to_freeze :
376+ train_util .freeze_blocks (mmdit ,num_last_block_to_freeze = args .num_last_block_to_freeze )
377+
371378 training_models = []
372379 params_to_optimize = []
373380 # if train_unet:
374381 training_models .append (mmdit )
375382 # if block_lrs is None:
376- params_to_optimize .append ({"params" : list (mmdit .parameters ()), "lr" : args .learning_rate })
383+ params_to_optimize .append ({"params" : list (filter ( lambda p : p . requires_grad , mmdit .parameters () )), "lr" : args .learning_rate })
377384 # else:
378385 # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))
379386
You can’t perform that action at this time.
0 commit comments