Add block_lr to finetune
bmaltais committed Aug 15, 2023
1 parent c19bdc3 commit 1cd8849
Showing 2 changed files with 12 additions and 22 deletions.
19 changes: 12 additions & 7 deletions finetune_gui.py
@@ -93,7 +93,7 @@ def save_configuration(
     save_state,
     resume,
     gradient_checkpointing,
-    gradient_accumulation_steps,
+    gradient_accumulation_steps,block_lr,
     mem_eff_attn,
     shuffle_caption,
     output_name,
@@ -219,7 +219,7 @@ def open_configuration(
     save_state,
     resume,
     gradient_checkpointing,
-    gradient_accumulation_steps,
+    gradient_accumulation_steps,block_lr,
     mem_eff_attn,
     shuffle_caption,
     output_name,
@@ -338,7 +338,7 @@ def train_model(
     save_state,
     resume,
     gradient_checkpointing,
-    gradient_accumulation_steps,
+    gradient_accumulation_steps,block_lr,
     mem_eff_attn,
     shuffle_caption,
     output_name,
@@ -544,10 +544,9 @@ def train_model(
         run_cmd += f' --save_model_as={save_model_as}'
     if int(gradient_accumulation_steps) > 1:
         run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
-    # if save_state:
-    # run_cmd += ' --save_state'
-    # if not resume == '':
-    # run_cmd += f' --resume={resume}'
+    if not block_lr == '':
+        run_cmd += f' --block_lr="{block_lr}"'
+
     if not output_name == '':
         run_cmd += f' --output_name="{output_name}"'
     if int(max_token_length) > 75:
@@ -829,6 +828,11 @@ def finetune_tab(headless=False):
     gradient_accumulation_steps = gr.Number(
         label='Gradient accumulate steps', value='1'
     )
+    block_lr = gr.Textbox(
+        label='Block LR',
+        placeholder='(Optional)',
+        info='Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3',
+    )
     advanced_training = AdvancedTraining(
         headless=headless, finetuning=True
     )
@@ -906,6 +910,7 @@ def finetune_tab(headless=False):
     advanced_training.resume,
     advanced_training.gradient_checkpointing,
     gradient_accumulation_steps,
+    block_lr,
     advanced_training.mem_eff_attn,
     advanced_training.shuffle_caption,
     output_name,
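Note: the sketch below is not part of the commit; it only illustrates, assuming the GUI shells out to sd-scripts' fine_tune.py via accelerate, how the new 'Block LR' textbox value ends up on the training command line. The helper name append_block_lr and the sample 23-value string are hypothetical.

def append_block_lr(run_cmd, block_lr):
    # Mirrors the check added to train_model(): only pass --block_lr when the
    # 'Block LR' textbox is non-empty.
    if not block_lr == '':
        run_cmd += f' --block_lr="{block_lr}"'
    return run_cmd

# Hypothetical input: one learning rate per U-Net block, 23 comma-separated values.
block_lr = ','.join(['1e-3'] * 23)
print(append_block_lr('accelerate launch "./fine_tune.py"', block_lr))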
15 changes: 0 additions & 15 deletions lora_gui.py
@@ -150,7 +150,6 @@ def save_configuration(
     block_lr_zero_threshold,
     block_dims,
     block_alphas,
-    block_lr,
     conv_block_dims,
     conv_block_alphas,
     weighted_captions,
@@ -303,7 +302,6 @@ def open_configuration(
     block_lr_zero_threshold,
     block_dims,
     block_alphas,
-    block_lr,
     conv_block_dims,
     conv_block_alphas,
     weighted_captions,
@@ -478,7 +476,6 @@ def train_model(
     block_lr_zero_threshold,
     block_dims,
     block_alphas,
-    block_lr,
     conv_block_dims,
     conv_block_alphas,
     weighted_captions,
@@ -803,7 +800,6 @@ def train_model(
     'block_lr_zero_threshold',
     'block_dims',
     'block_alphas',
-    'block_lr',
     'conv_block_dims',
     'conv_block_alphas',
     'rank_dropout',
@@ -838,7 +834,6 @@ def train_model(
     'block_lr_zero_threshold',
     'block_dims',
     'block_alphas',
-    'block_lr',
     'conv_block_dims',
     'conv_block_alphas',
     'rank_dropout',
@@ -873,7 +868,6 @@ def train_model(
     'block_lr_zero_threshold',
     'block_dims',
     'block_alphas',
-    'block_lr',
     'conv_block_dims',
     'conv_block_alphas',
     'rank_dropout',
@@ -897,9 +891,6 @@ def train_model(
     if network_args:
         run_cmd += f' --network_args{network_args}'

-    # if not block_lr == '':
-    # run_cmd += f' --block_lr="{block_lr}"'
-
     if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
         if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
             run_cmd += f' --text_encoder_lr={text_encoder_lr}'
@@ -1474,11 +1465,6 @@ def update_LoRA_settings(LoRA_type):
     placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
     info='Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.',
 )
-block_lr = gr.Textbox(
-    label='Block LR',
-    placeholder='(Optional)',
-    info='Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3',
-)
 with gr.Tab(label='Conv'):
     with gr.Row(visible=True):
         conv_block_dims = gr.Textbox(
@@ -1652,7 +1638,6 @@ def update_LoRA_settings(LoRA_type):
     block_lr_zero_threshold,
     block_dims,
     block_alphas,
-    block_lr,
     conv_block_dims,
     conv_block_alphas,
     advanced_training.weighted_captions,
