Skip to content

Commit 6386a72

Browse files
committed
Add support for user managed path config
1 parent 0e4582c commit 6386a72

8 files changed

+190
-129
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,5 @@ outputs
5151
dataset/**
5252
!dataset/**/
5353
!dataset/**/.gitkeep
54+
# models
55+
# data

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ The documentation in this section will be moved to a separate document later.
373373
- Add support for `wandb_run_name`, `log_tracker_name` and `log_tracker_config` parameters under the advanced section.
374374
- Update sd-scripts to v0.8.5
375375
- Improve code
376+
- Add support for custom path defaults. Simply copy the `config example.toml` file found in the root of the repo to `config.toml` and edit the different values to your taste.
376377

377378
### 2024/03/13 (v23.0.11)
378379

config example.toml

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copy this file and name it config.toml
2+
# Edit the values to suit your needs
3+
4+
# Default folders location
5+
models_dir = "./models" # Pretrained model name or path
6+
train_data_dir = "./data" # Image folder (containing training images subfolders) / Image folder (containing training images)
7+
output_dir = "./outputs" # Output directory for trained model
8+
reg_data_dir = "./data/reg" # Regularisation directory
9+
logging_dir = "./logs" # Logging directory
10+
config_dir = "./presets" # Load/Save Config file
11+
12+
# Example custom folder location
13+
# models_dir = "e:/models" # Pretrained model name or path

kohya_gui/class_configuration_file.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import gradio as gr
22
import os
3+
import toml
4+
from .common_gui import list_files, scriptdir, create_refresh_button, load_kohya_ss_gui_config
5+
from .custom_logging import setup_logging
36

4-
from .common_gui import list_files, scriptdir, create_refresh_button
7+
# Set up logging
8+
log = setup_logging()
59

610

711
class ConfigurationFile:
@@ -19,11 +23,11 @@ def __init__(self, headless: bool = False, config_dir: str = None):
1923
"""
2024

2125
self.headless = headless
26+
27+
config = load_kohya_ss_gui_config()
2228

2329
# Sets the directory for storing configuration files, defaults to a 'presets' folder within the script directory.
24-
self.current_config_dir = (
25-
config_dir if config_dir is not None else os.path.join(scriptdir, "presets")
26-
)
30+
self.current_config_dir = config.get('config_dir', os.path.join(scriptdir, "presets"))
2731

2832
# Initialize the GUI components for configuration.
2933
self.create_config_gui()
@@ -38,7 +42,7 @@ def list_config_dir(self, path: str) -> list:
3842
Returns:
3943
- list: A list of directories.
4044
"""
41-
self.current_config_dir = path
45+
self.current_config_dir = path if not path == "" else "."
4246
# Lists all .json files in the current configuration directory, used for populating dropdown choices.
4347
return list(list_files(self.current_config_dir, exts=[".json"], all=True))
4448

kohya_gui/class_folders.py

+13-27
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,29 @@
11
import gradio as gr
22
import os
3-
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button
3+
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button, load_kohya_ss_gui_config
44

55
class Folders:
66
"""
77
A class to handle folder operations in the GUI.
88
"""
9-
def __init__(self, finetune: bool = False, data_dir: str = None, output_dir: str = None, logging_dir: str = None, reg_data_dir: str = None, headless: bool = False):
9+
def __init__(self, finetune: bool = False, headless: bool = False):
1010
"""
1111
Initialize the Folders class.
1212
1313
Parameters:
1414
- finetune (bool): Whether to finetune the model.
15-
- data_dir (str): The directory for data.
16-
- output_dir (str): The directory for output.
17-
- logging_dir (str): The directory for logging.
18-
- reg_data_dir (str): The directory for regularization data.
1915
- headless (bool): Whether to run in headless mode.
2016
"""
2117
self.headless = headless
2218
self.finetune = finetune
19+
20+
# Load kohya_ss GUI configs from config.toml if it exist
21+
config = load_kohya_ss_gui_config()
2322

2423
# Set default directories if not provided
25-
self.current_data_dir = data_dir if data_dir is not None else os.path.join(scriptdir, "data")
26-
self.current_output_dir = output_dir if output_dir is not None else os.path.join(scriptdir, "outputs")
27-
self.current_logging_dir = logging_dir if logging_dir is not None else os.path.join(scriptdir, "logs")
28-
self.current_reg_data_dir = reg_data_dir if reg_data_dir is not None else os.path.join(scriptdir, "reg")
24+
self.current_output_dir = config.get('output_dir', os.path.join(scriptdir, "outputs"))
25+
self.current_logging_dir = config.get('logging_dir', os.path.join(scriptdir, "logs"))
26+
self.current_reg_data_dir = config.get('reg_data_dir', os.path.join(scriptdir, "reg"))
2927

3028
# Create directories if they don't exist
3129
self.create_directory_if_not_exists(self.current_output_dir)
@@ -44,18 +42,6 @@ def create_directory_if_not_exists(self, directory: str) -> None:
4442
if directory is not None and directory.strip() != "" and not os.path.exists(directory):
4543
os.makedirs(directory, exist_ok=True)
4644

47-
def list_data_dirs(self, path: str) -> list:
48-
"""
49-
List directories in the data directory.
50-
51-
Parameters:
52-
- path (str): The path to list directories from.
53-
54-
Returns:
55-
- list: A list of directories.
56-
"""
57-
self.current_data_dir = path
58-
return list(list_dirs(path))
5945

6046
def list_output_dirs(self, path: str) -> list:
6147
"""
@@ -67,7 +53,7 @@ def list_output_dirs(self, path: str) -> list:
6753
Returns:
6854
- list: A list of directories.
6955
"""
70-
self.current_output_dir = path
56+
self.current_output_dir = path if not path == "" else "."
7157
return list(list_dirs(path))
7258

7359
def list_logging_dirs(self, path: str) -> list:
@@ -80,7 +66,7 @@ def list_logging_dirs(self, path: str) -> list:
8066
Returns:
8167
- list: A list of directories.
8268
"""
83-
self.current_logging_dir = path
69+
self.current_logging_dir = path if not path == "" else "."
8470
return list(list_dirs(path))
8571

8672
def list_reg_data_dirs(self, path: str) -> list:
@@ -93,7 +79,7 @@ def list_reg_data_dirs(self, path: str) -> list:
9379
Returns:
9480
- list: A list of directories.
9581
"""
96-
self.current_reg_data_dir = path
82+
self.current_reg_data_dir = path if not path == "" else "."
9783
return list(list_dirs(path))
9884

9985
def create_folders_gui(self) -> None:
@@ -131,7 +117,7 @@ def create_folders_gui(self) -> None:
131117
allow_custom_value=True,
132118
)
133119
# Refresh button for regularisation directory
134-
create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_data_dirs(self.current_data_dir)}, "open_folder_small")
120+
create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_reg_data_dirs(self.current_reg_data_dir)}, "open_folder_small")
135121
# Regularisation directory button
136122
self.reg_data_dir_folder = gr.Button(
137123
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
@@ -173,7 +159,7 @@ def create_folders_gui(self) -> None:
173159
)
174160
# Change event for regularisation directory dropdown
175161
self.reg_data_dir.change(
176-
fn=lambda path: gr.Dropdown(choices=[""] + self.list_data_dirs(path)),
162+
fn=lambda path: gr.Dropdown(choices=[""] + self.list_reg_data_dirs(path)),
177163
inputs=self.reg_data_dir,
178164
outputs=self.reg_data_dir,
179165
show_progress=False,

0 commit comments

Comments
 (0)