Skip to content

Commit 4f59a28

Browse files
committed
Update config code
1 parent b2f9c22 commit 4f59a28

File tree

3 files changed

+35
-10
lines changed

3 files changed

+35
-10
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,4 @@ dataset/**
5050
!dataset/**/.gitkeep
5151
models
5252
data
53+
config.toml

kohya_gui/class_advanced_training.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from .common_gui import (
55
get_folder_path,
66
get_any_file_path,
7-
scriptdir,
87
list_files,
98
list_dirs,
109
create_refresh_button,
@@ -30,7 +29,7 @@ def __init__(
3029
headless: bool = False,
3130
finetuning: bool = False,
3231
training_type: str = "",
33-
config:dict = {},
32+
config: dict = {},
3433
) -> None:
3534
"""
3635
Initializes the AdvancedTraining class with given settings.
@@ -50,7 +49,9 @@ def __init__(
5049
# Determine the current directories for VAE and output, falling back to defaults if not specified.
5150
self.current_vae_dir = self.config.get("vae_dir", "./models/vae")
5251
self.current_state_dir = self.config.get("state_dir", "./outputs")
53-
self.current_log_tracker_config_dir = self.config.get("log_tracker_config_dir", "./logs")
52+
self.current_log_tracker_config_dir = self.config.get(
53+
"log_tracker_config_dir", "./logs"
54+
)
5455

5556
# Define the behavior for changing noise offset type.
5657
def noise_offset_type_change(
@@ -406,10 +407,11 @@ def list_state_dirs(path):
406407
info="The name of the specific wandb session",
407408
)
408409
with gr.Group(), gr.Row():
410+
409411
def list_log_tracker_config_files(path):
410412
self.current_log_tracker_config_dir = path if not path == "" else "."
411413
return list(list_files(path, exts=[".json"], all=True))
412-
414+
413415
self.log_tracker_name = gr.Textbox(
414416
label="Log tracker name",
415417
value="",
@@ -418,7 +420,8 @@ def list_log_tracker_config_files(path):
418420
)
419421
self.log_tracker_config = gr.Dropdown(
420422
label="Log tracker config",
421-
choices=[""] + list_log_tracker_config_files(self.current_log_tracker_config_dir),
423+
choices=[""]
424+
+ list_log_tracker_config_files(self.current_log_tracker_config_dir),
422425
value="",
423426
info="Path to tracker config file to use for logging",
424427
interactive=True,
@@ -427,7 +430,10 @@ def list_log_tracker_config_files(path):
427430
create_refresh_button(
428431
self.log_tracker_config,
429432
lambda: None,
430-
lambda: {"choices": [""] + list_log_tracker_config_files(self.current_log_tracker_config_dir)},
433+
lambda: {
434+
"choices": [""]
435+
+ list_log_tracker_config_files(self.current_log_tracker_config_dir)
436+
},
431437
"open_folder_small",
432438
)
433439
self.log_tracker_config_button = gr.Button(
@@ -439,7 +445,9 @@ def list_log_tracker_config_files(path):
439445
show_progress=False,
440446
)
441447
self.log_tracker_config.change(
442-
fn=lambda path: gr.Dropdown(choices=[""] + list_log_tracker_config_files(path)),
448+
fn=lambda path: gr.Dropdown(
449+
choices=[""] + list_log_tracker_config_files(path)
450+
),
443451
inputs=self.log_tracker_config,
444452
outputs=self.log_tracker_config,
445453
show_progress=False,

kohya_gui/class_gui_config.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,22 @@ def get(self, key, default=None):
4343
Returns:
4444
The value associated with the key, or the default value if the key is not found.
4545
"""
46-
if key not in self.config:
47-
log.debug(f"Key '{key}' not found in configuration. Returning default value.")
48-
return self.config.get(key, default)
46+
# Split the key into a list of keys if it contains a dot (.)
47+
keys = key.split(".")
48+
# Initialize `data` with the entire configuration data
49+
data = self.config
50+
51+
# Iterate over the keys to access nested values
52+
for k in keys:
53+
log.debug(k)
54+
# If the key is not found in the current data, return the default value
55+
if k not in data:
56+
log.debug(f"Key '{key}' not found in configuration. Returning default value.")
57+
return default
58+
59+
# Update `data` to the value associated with the current key
60+
data = data.get(k)
61+
62+
# Return the final value
63+
log.debug(f"Returned {data}")
64+
return data

0 commit comments

Comments
 (0)