Skip to content

Commit 5878bd6

Browse files
committed
Move config for gui to seperate class
Add support for more config paths
1 parent 6386a72 commit 5878bd6

15 files changed

+233
-266
lines changed

.gitignore

+2-5
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,12 @@ test/ft
4141
# Temporary requirements
4242
requirements_tmp_for_setup.txt
4343

44-
# Version specific
45-
0.13.3
46-
4744
*.npz
4845
presets/*/user_presets/*
4946
inputs
5047
outputs
5148
dataset/**
5249
!dataset/**/
5350
!dataset/**/.gitkeep
54-
# models
55-
# data
51+
models
52+
data

config example.toml

+10-7
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
# Edit the values to suit your needs
33

44
# Default folders location
5-
models_dir = "./models" # Pretrained model name or path
6-
train_data_dir = "./data" # Image folder (containing training images subfolders) / Image folder (containing training images)
7-
output_dir = "./outputs" # Output directory for trained model
8-
reg_data_dir = "./data/reg" # Regularisation directory
9-
logging_dir = "./logs" # Logging directory
10-
config_dir = "./presets" # Load/Save Config file
5+
models_dir = "./models" # Pretrained model name or path
6+
train_data_dir = "./data" # Image folder (containing training images subfolders) / Image folder (containing training images)
7+
output_dir = "./outputs" # Output directory for trained model
8+
reg_data_dir = "./data/reg" # Regularisation directory
9+
logging_dir = "./logs" # Logging directory
10+
config_dir = "./presets" # Load/Save Config file
11+
log_tracker_config_dir = "./logs" # Log tracker configs directory
12+
state_dir = "./outputs" # Resume from saved training state
13+
vae_dir = "./models/vae" # VAEs folder path
1114

1215
# Example custom folder location
13-
# models_dir = "e:/models" # Pretrained model name or path
16+
# models_dir = "e:/models" # Pretrained model name or path

config.toml

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copy this file and name it config.toml
2+
# Edit the values to suit your needs
3+
4+
# Default folders location
5+
models_dir = "./models" # Pretrained model name or path
6+
train_data_dir = "./data" # Image folder (containing training images subfolders) / Image folder (containing training images)
7+
output_dir = "./outputs" # Output directory for trained model
8+
reg_data_dir = "./data/reg" # Regularisation directory
9+
logging_dir = "./logs" # Logging directory
10+
config_dir = "./presets" # Load/Save Config file
11+
log_tracker_config_dir = "./logs" # Log tracker configs directory
12+
state_dir = "e:/models" # Resume from saved training state
13+
vae_dir = "./models/vae" # VAEs folder path
14+
15+
# Example custom folder location
16+
# models_dir = "e:/models" # Pretrained model name or path

kohya_gui.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import gradio as gr
22
import os
33
import argparse
4+
from kohya_gui.class_gui_config import KohyaSSGUIConfig
45
from dreambooth_gui import dreambooth_tab
56
from finetune_gui import finetune_tab
67
from textual_inversion_gui import ti_tab
@@ -24,7 +25,7 @@ def UI(**kwargs):
2425

2526
if os.path.exists("./style.css"):
2627
with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
27-
log.info("Load CSS...")
28+
log.debug("Load CSS...")
2829
css += file.read() + "\n"
2930

3031
if os.path.exists("./.release"):
@@ -38,6 +39,8 @@ def UI(**kwargs):
3839
interface = gr.Blocks(
3940
css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
4041
)
42+
43+
config = KohyaSSGUIConfig()
4144

4245
with interface:
4346
with gr.Tab("Dreambooth"):
@@ -46,13 +49,13 @@ def UI(**kwargs):
4649
reg_data_dir_input,
4750
output_dir_input,
4851
logging_dir_input,
49-
) = dreambooth_tab(headless=headless)
52+
) = dreambooth_tab(headless=headless, config=config)
5053
with gr.Tab("LoRA"):
51-
lora_tab(headless=headless)
54+
lora_tab(headless=headless, config=config)
5255
with gr.Tab("Textual Inversion"):
53-
ti_tab(headless=headless)
56+
ti_tab(headless=headless, config=config)
5457
with gr.Tab("Finetuning"):
55-
finetune_tab(headless=headless)
58+
finetune_tab(headless=headless, config=config)
5659
with gr.Tab("Utilities"):
5760
utilities_tab(
5861
train_data_dir_input=train_data_dir_input,

kohya_gui/class_advanced_training.py

+19-27
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ def __init__(
3030
headless: bool = False,
3131
finetuning: bool = False,
3232
training_type: str = "",
33-
default_vae_dir: str = "",
34-
default_output_dir: str = "",
33+
config:dict = {},
3534
) -> None:
3635
"""
3736
Initializes the AdvancedTraining class with given settings.
@@ -46,16 +45,12 @@ def __init__(
4645
self.headless = headless
4746
self.finetuning = finetuning
4847
self.training_type = training_type
48+
self.config = config
4949

5050
# Determine the current directories for VAE and output, falling back to defaults if not specified.
51-
current_vae_dir = (
52-
default_vae_dir if default_vae_dir else os.path.join(scriptdir, "vae")
53-
)
54-
current_state_dir = (
55-
default_output_dir
56-
if default_output_dir
57-
else os.path.join(scriptdir, "outputs")
58-
)
51+
self.current_vae_dir = self.config.get("vae_dir", "./models/vae")
52+
self.current_state_dir = self.config.get("state_dir", "./outputs")
53+
self.current_log_tracker_config_dir = self.config.get("log_tracker_config_dir", "./logs")
5954

6055
# Define the behavior for changing noise offset type.
6156
def noise_offset_type_change(
@@ -95,21 +90,20 @@ def noise_offset_type_change(
9590
self.prior_loss_weight = gr.Number(label="Prior loss weight", value=1.0)
9691

9792
def list_vae_files(path):
98-
nonlocal current_vae_dir
99-
current_vae_dir = path
93+
self.current_vae_dir = path if not path == "" else "."
10094
return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True))
10195

10296
self.vae = gr.Dropdown(
103-
label="VAE (Optional. path to checkpoint of vae to replace for training)",
97+
label="VAE (Optional: Path to checkpoint of vae for training)",
10498
interactive=True,
105-
choices=[""] + list_vae_files(current_vae_dir),
99+
choices=[""] + list_vae_files(self.current_vae_dir),
106100
value="",
107101
allow_custom_value=True,
108102
)
109103
create_refresh_button(
110104
self.vae,
111105
lambda: None,
112-
lambda: {"choices": list_vae_files(current_vae_dir)},
106+
lambda: {"choices": [""] + list_vae_files(self.current_vae_dir)},
113107
"open_folder_small",
114108
)
115109
self.vae_button = gr.Button(
@@ -222,11 +216,6 @@ def full_options_update(full_fp16, full_bf16):
222216
label="Memory efficient attention", value=False
223217
)
224218
with gr.Row():
225-
# This use_8bit_adam element should be removed in a future release as it is no longer used
226-
# use_8bit_adam = gr.Checkbox(
227-
# label='Use 8bit adam', value=False, visible=False
228-
# )
229-
# self.xformers = gr.Checkbox(label='Use xformers', value=True, info='Use xformers for CrossAttention')
230219
self.xformers = gr.Dropdown(
231220
label="CrossAttention",
232221
choices=["none", "sdpa", "xformers"],
@@ -348,21 +337,20 @@ def full_options_update(full_fp16, full_bf16):
348337
self.save_state = gr.Checkbox(label="Save training state", value=False)
349338

350339
def list_state_dirs(path):
351-
nonlocal current_state_dir
352-
current_state_dir = path
340+
self.current_state_dir = path if not path == "" else "."
353341
return list(list_dirs(path))
354342

355343
self.resume = gr.Dropdown(
356344
label='Resume from saved training state (path to "last-state" state folder)',
357-
choices=[""] + list_state_dirs(current_state_dir),
345+
choices=[""] + list_state_dirs(self.current_state_dir),
358346
value="",
359347
interactive=True,
360348
allow_custom_value=True,
361349
)
362350
create_refresh_button(
363351
self.resume,
364352
lambda: None,
365-
lambda: {"choices": list_state_dirs(current_state_dir)},
353+
lambda: {"choices": [""] + list_state_dirs(self.current_state_dir)},
366354
"open_folder_small",
367355
)
368356
self.resume_button = gr.Button(
@@ -418,6 +406,10 @@ def list_state_dirs(path):
418406
info="The name of the specific wandb session",
419407
)
420408
with gr.Group(), gr.Row():
409+
def list_log_tracker_config_files(path):
410+
self.current_log_tracker_config_dir = path if not path == "" else "."
411+
return list(list_files(path, exts=[".json"], all=True))
412+
421413
self.log_tracker_name = gr.Textbox(
422414
label="Log tracker name",
423415
value="",
@@ -426,7 +418,7 @@ def list_state_dirs(path):
426418
)
427419
self.log_tracker_config = gr.Dropdown(
428420
label="Log tracker config",
429-
choices=[""] + list_state_dirs(current_state_dir),
421+
choices=[""] + list_log_tracker_config_files(self.current_log_tracker_config_dir),
430422
value="",
431423
info="Path to tracker config file to use for logging",
432424
interactive=True,
@@ -435,7 +427,7 @@ def list_state_dirs(path):
435427
create_refresh_button(
436428
self.log_tracker_config,
437429
lambda: None,
438-
lambda: {"choices": list_state_dirs(current_state_dir)},
430+
lambda: {"choices": [""] + list_log_tracker_config_files(self.current_log_tracker_config_dir)},
439431
"open_folder_small",
440432
)
441433
self.log_tracker_config_button = gr.Button(
@@ -447,7 +439,7 @@ def list_state_dirs(path):
447439
show_progress=False,
448440
)
449441
self.log_tracker_config.change(
450-
fn=lambda path: gr.Dropdown(choices=[""] + list_state_dirs(path)),
442+
fn=lambda path: gr.Dropdown(choices=[""] + list_log_tracker_config_files(path)),
451443
inputs=self.log_tracker_config,
452444
outputs=self.log_tracker_config,
453445
show_progress=False,

kohya_gui/class_configuration_file.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import gradio as gr
22
import os
3-
import toml
4-
from .common_gui import list_files, scriptdir, create_refresh_button, load_kohya_ss_gui_config
3+
from .common_gui import list_files, scriptdir, create_refresh_button
54
from .custom_logging import setup_logging
65

76
# Set up logging
@@ -13,7 +12,7 @@ class ConfigurationFile:
1312
A class to handle configuration file operations in the GUI.
1413
"""
1514

16-
def __init__(self, headless: bool = False, config_dir: str = None):
15+
def __init__(self, headless: bool = False, config_dir: str = None, config:dict = {}):
1716
"""
1817
Initialize the ConfigurationFile class.
1918
@@ -24,10 +23,10 @@ def __init__(self, headless: bool = False, config_dir: str = None):
2423

2524
self.headless = headless
2625

27-
config = load_kohya_ss_gui_config()
26+
self.config = config
2827

2928
# Sets the directory for storing configuration files, defaults to a 'presets' folder within the script directory.
30-
self.current_config_dir = config.get('config_dir', os.path.join(scriptdir, "presets"))
29+
self.current_config_dir = self.config.get('config_dir', os.path.join(scriptdir, "presets"))
3130

3231
# Initialize the GUI components for configuration.
3332
self.create_config_gui()

kohya_gui/class_dreambooth_gui.py

-110
This file was deleted.

0 commit comments

Comments
 (0)