Skip to content

Commit af46ce4

Browse files
committed
Update LoRA GUI
Various improvements
1 parent ee2499d commit af46ce4

6 files changed

+61
-34
lines changed

README.md

+25-5
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,38 @@ You can find the dreambooth solution spercific [Dreambooth README](README_dreamb
1212

1313
You can find the finetune solution specific [Finetune README](README_finetune.md)
1414

15+
## LoRA
16+
17+
You can create LoRA network by running the dedicated GUI with:
18+
19+
```
20+
python lora_gui.py
21+
```
22+
23+
or via the all-in-one GUI:
24+
25+
```
26+
python kohya_gui.py
27+
```
28+
29+
Once you have created the LoRA network you can generate images via auto1111 by installing the extension found here: https://github.com/kohya-ss/sd-webui-additional-networks
30+
1531
## Change history
1632

17-
* 12/30 (v19) update:
33+
* 2023/01/01 (v19.1) update:
34+
- merge kohya_ss upstream code updates
35+
- rework Dreambooth LoRA GUI
36+
- fix bug where LoRA network weights were not loaded to properly resume training
37+
* 2022/12/30 (v19) update:
1838
- support for LoRA network training in kohya_gui.py.
19-
* 12/23 (v18.8) update:
39+
* 2022/12/23 (v18.8) update:
2040
- Fix for conversion tool issue when the source was an sd1.x diffuser model
2141
- Other minor code and GUI fix
22-
* 12/22 (v18.7) update:
42+
* 2022/12/22 (v18.7) update:
2343
- Merge dreambooth and finetune in a common GUI
2444
- General bug fixes and code improvements
25-
* 12/21 (v18.6.1) update:
45+
* 2022/12/21 (v18.6.1) update:
2646
- fix issue with dataset balancing when the number of detected images in the folder is 0
2747

28-
* 12/21 (v18.6) update:
48+
* 2022/12/21 (v18.6) update:
2949
- add optional GUI authentication support via: `python fine_tune.py --username=<name> --password=<password>`

dreambooth_gui.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
get_folder_path,
1616
remove_doublequote,
1717
get_file_path,
18+
get_any_file_path,
1819
get_saveasfile_path,
1920
)
2021
from library.dreambooth_folder_creation_gui import (
@@ -236,7 +237,7 @@ def train_model(
236237
seed,
237238
num_cpu_threads_per_process,
238239
cache_latent,
239-
caption_extention,
240+
caption_extension,
240241
enable_bucket,
241242
gradient_checkpointing,
242243
full_fp16,
@@ -396,7 +397,8 @@ def save_inference_file(output_dir, v2, v_parameterization):
396397
run_cmd += f' --seed={seed}'
397398
run_cmd += f' --save_precision={save_precision}'
398399
run_cmd += f' --logging_dir={logging_dir}'
399-
run_cmd += f' --caption_extention={caption_extention}'
400+
if not caption_extension == '':
401+
run_cmd += f' --caption_extension={caption_extension}'
400402
if not stop_text_encoder_training == 0:
401403
run_cmd += (
402404
f' --stop_text_encoder_training={stop_text_encoder_training}'
@@ -542,7 +544,7 @@ def dreambooth_tab(
542544
document_symbol, elem_id='open_folder_small'
543545
)
544546
pretrained_model_name_or_path_fille.click(
545-
get_file_path,
547+
get_any_file_path,
546548
inputs=[pretrained_model_name_or_path_input],
547549
outputs=pretrained_model_name_or_path_input,
548550
)

examples/word_frequency.ps1

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
$txt_files_folder = "D:\dreambooth\training_twq\mad_hatter\all"
2-
$txt_prefix_to_ignore = "asd"
3-
$txt_postfix_ti_ignore = "asd"
1+
$txt_files_folder = "D:\dataset\metart_g1\img\100_asd girl"
2+
$txt_prefix_to_ignore = "asds"
3+
$txt_postfix_ti_ignore = "asds"
44

55
# Should not need to touch anything below
66

77
# (Get-Content $txt_files_folder"\*.txt" ).Replace(",", "") -Split '\W' | Group-Object -NoElement | Sort-Object -Descending -Property Count
88

9-
$combined_txt = Get-Content $txt_files_folder"\*.txt"
9+
$combined_txt = Get-Content $txt_files_folder"\*.cap"
1010
$combined_txt = $combined_txt.Replace(",", "")
1111
$combined_txt = $combined_txt.Replace("$txt_prefix_to_ignore", "")
1212
$combined_txt = $combined_txt.Replace("$txt_postfix_ti_ignore", "") -Split '\W' | Group-Object -NoElement | Sort-Object -Descending -Property Count

finetune_gui.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from library.common_gui import (
1010
get_folder_path,
1111
get_file_path,
12+
get_any_file_path,
1213
get_saveasfile_path,
1314
)
1415
from library.utilities import utilities_tab
@@ -436,7 +437,7 @@ def finetune_tab():
436437
document_symbol, elem_id='open_folder_small'
437438
)
438439
pretrained_model_name_or_path_file.click(
439-
get_file_path,
440+
get_any_file_path,
440441
inputs=pretrained_model_name_or_path_input,
441442
outputs=pretrained_model_name_or_path_input,
442443
)

library/basic_caption_gui.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def gradio_basic_caption_gui_tab():
7575
)
7676
with gr.Row():
7777
prefix = gr.Textbox(
78-
label='Prefix to add to txt caption',
78+
label='Prefix to add to caption',
7979
placeholder='(Optional)',
8080
interactive=True,
8181
)
@@ -85,7 +85,7 @@ def gradio_basic_caption_gui_tab():
8585
interactive=True,
8686
)
8787
postfix = gr.Textbox(
88-
label='Postfix to add to txt caption',
88+
label='Postfix to add to caption',
8989
placeholder='(Optional)',
9090
interactive=True,
9191
)

lora_gui.py

+23-19
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def save_configuration(
6464
shuffle_caption,
6565
save_state,
6666
resume,
67-
prior_loss_weight, text_encoder_lr, unet_lr, network_dim
67+
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
6868
):
6969
original_file_path = file_path
7070

@@ -118,7 +118,8 @@ def save_configuration(
118118
'prior_loss_weight': prior_loss_weight,
119119
'text_encoder_lr': text_encoder_lr,
120120
'unet_lr': unet_lr,
121-
'network_dim': network_dim
121+
'network_dim': network_dim,
122+
'lora_network_weights': lora_network_weights,
122123
}
123124

124125
# Save the data to the selected file
@@ -160,7 +161,7 @@ def open_configuration(
160161
shuffle_caption,
161162
save_state,
162163
resume,
163-
prior_loss_weight, text_encoder_lr, unet_lr, network_dim
164+
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
164165
):
165166

166167
original_file_path = file_path
@@ -216,6 +217,7 @@ def open_configuration(
216217
my_data.get('text_encoder_lr', text_encoder_lr),
217218
my_data.get('unet_lr', unet_lr),
218219
my_data.get('network_dim', network_dim),
220+
my_data.get('lora_network_weights', lora_network_weights),
219221
)
220222

221223

@@ -250,7 +252,7 @@ def train_model(
250252
shuffle_caption,
251253
save_state,
252254
resume,
253-
prior_loss_weight, text_encoder_lr, unet_lr, network_dim
255+
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
254256
):
255257
def save_inference_file(output_dir, v2, v_parameterization):
256258
# Copy inference model for v2 if required
@@ -432,6 +434,7 @@ def save_inference_file(output_dir, v2, v_parameterization):
432434
# elif network_train == 'Unet only':
433435
# run_cmd += f' --network_train_unet_only'
434436
run_cmd += f' --network_dim={network_dim}'
437+
run_cmd += f' --network_weights={lora_network_weights}'
435438

436439

437440
print(run_cmd)
@@ -568,7 +571,7 @@ def lora_tab(
568571
document_symbol, elem_id='open_folder_small'
569572
)
570573
pretrained_model_name_or_path_file.click(
571-
get_file_path,
574+
get_any_file_path,
572575
inputs=[pretrained_model_name_or_path_input],
573576
outputs=pretrained_model_name_or_path_input,
574577
)
@@ -602,19 +605,7 @@ def lora_tab(
602605
],
603606
value='same as source model',
604607
)
605-
with gr.Row():
606-
lora_network_weights = gr.Textbox(
607-
label='LoRA network weights',
608-
placeholder='{Optional) Path to existing LoRA network weights to resume training}',
609-
)
610-
lora_network_weights_file = gr.Button(
611-
document_symbol, elem_id='open_folder_small'
612-
)
613-
lora_network_weights_file.click(
614-
get_any_file_path,
615-
inputs=[lora_network_weights],
616-
outputs=lora_network_weights,
617-
)
608+
618609
with gr.Row():
619610
v2_input = gr.Checkbox(label='v2', value=True)
620611
v_parameterization_input = gr.Checkbox(
@@ -699,6 +690,19 @@ def lora_tab(
699690
outputs=[logging_dir_input],
700691
)
701692
with gr.Tab('Training parameters'):
693+
with gr.Row():
694+
lora_network_weights = gr.Textbox(
695+
label='LoRA network weights',
696+
placeholder='(Optional) Path to existing LoRA network weights to resume training',
697+
)
698+
lora_network_weights_file = gr.Button(
699+
document_symbol, elem_id='open_folder_small'
700+
)
701+
lora_network_weights_file.click(
702+
get_any_file_path,
703+
inputs=[lora_network_weights],
704+
outputs=lora_network_weights,
705+
)
702706
with gr.Row():
703707
# learning_rate_input = gr.Textbox(label='Learning rate', value=1e-4, visible=False)
704708
lr_scheduler_input = gr.Dropdown(
@@ -874,7 +878,7 @@ def lora_tab(
874878
shuffle_caption,
875879
save_state,
876880
resume,
877-
prior_loss_weight, text_encoder_lr, unet_lr, network_dim
881+
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
878882
]
879883

880884
button_open_config.click(

0 commit comments

Comments
 (0)