From 1cd884955923f4109ec586a3cc2d70f02277caca Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 15 Aug 2023 19:14:01 -0400 Subject: [PATCH] Add block_lr to finetune --- finetune_gui.py | 19 ++++++++++++------- lora_gui.py | 15 --------------- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/finetune_gui.py b/finetune_gui.py index 5d6c784af..da94a09d0 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -93,7 +93,7 @@ def save_configuration( save_state, resume, gradient_checkpointing, - gradient_accumulation_steps, + gradient_accumulation_steps,block_lr, mem_eff_attn, shuffle_caption, output_name, @@ -219,7 +219,7 @@ def open_configuration( save_state, resume, gradient_checkpointing, - gradient_accumulation_steps, + gradient_accumulation_steps,block_lr, mem_eff_attn, shuffle_caption, output_name, @@ -338,7 +338,7 @@ def train_model( save_state, resume, gradient_checkpointing, - gradient_accumulation_steps, + gradient_accumulation_steps,block_lr, mem_eff_attn, shuffle_caption, output_name, @@ -544,10 +544,9 @@ def train_model( run_cmd += f' --save_model_as={save_model_as}' if int(gradient_accumulation_steps) > 1: run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' - # if save_state: - # run_cmd += ' --save_state' - # if not resume == '': - # run_cmd += f' --resume={resume}' + if not block_lr == '': + run_cmd += f' --block_lr="{block_lr}"' + if not output_name == '': run_cmd += f' --output_name="{output_name}"' if int(max_token_length) > 75: @@ -829,6 +828,11 @@ def finetune_tab(headless=False): gradient_accumulation_steps = gr.Number( label='Gradient accumulate steps', value='1' ) + block_lr = gr.Textbox( + label='Block LR', + placeholder='(Optional)', + info='Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3', + ) advanced_training = AdvancedTraining( headless=headless, finetuning=True ) @@ -906,6 +910,7 @@ def finetune_tab(headless=False): advanced_training.resume, advanced_training.gradient_checkpointing, gradient_accumulation_steps, + block_lr, advanced_training.mem_eff_attn, advanced_training.shuffle_caption, output_name, diff --git a/lora_gui.py b/lora_gui.py index 39ab04ff9..3e8436f3a 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -150,7 +150,6 @@ def save_configuration( block_lr_zero_threshold, block_dims, block_alphas, - block_lr, conv_block_dims, conv_block_alphas, weighted_captions, @@ -303,7 +302,6 @@ def open_configuration( block_lr_zero_threshold, block_dims, block_alphas, - block_lr, conv_block_dims, conv_block_alphas, weighted_captions, @@ -478,7 +476,6 @@ def train_model( block_lr_zero_threshold, block_dims, block_alphas, - block_lr, conv_block_dims, conv_block_alphas, weighted_captions, @@ -803,7 +800,6 @@ def train_model( 'block_lr_zero_threshold', 'block_dims', 'block_alphas', - 'block_lr', 'conv_block_dims', 'conv_block_alphas', 'rank_dropout', @@ -838,7 +834,6 @@ def train_model( 'block_lr_zero_threshold', 'block_dims', 'block_alphas', - 'block_lr', 'conv_block_dims', 'conv_block_alphas', 'rank_dropout', @@ -873,7 +868,6 @@ def train_model( 'block_lr_zero_threshold', 'block_dims', 'block_alphas', - 'block_lr', 'conv_block_dims', 'conv_block_alphas', 'rank_dropout', @@ -897,9 +891,6 @@ def train_model( if network_args: run_cmd += f' --network_args{network_args}' - # if not block_lr == '': - # run_cmd += f' --block_lr="{block_lr}"' - if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0): if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0): run_cmd += f' --text_encoder_lr={text_encoder_lr}' @@ -1474,11 +1465,6 @@ def update_LoRA_settings(LoRA_type): placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2', info='Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.', ) - block_lr = gr.Textbox( - label='Block LR', - placeholder='(Optional)', - info='Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3', - ) with gr.Tab(label='Conv'): with gr.Row(visible=True): conv_block_dims = gr.Textbox( @@ -1652,7 +1638,6 @@ def update_LoRA_settings(LoRA_type): block_lr_zero_threshold, block_dims, block_alphas, - block_lr, conv_block_dims, conv_block_alphas, advanced_training.weighted_captions,