Skip to content

Commit

Permalink
Add validation of lr scheduler and optimizer arguments (bmaltais#2358)
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais authored Apr 20, 2024
1 parent 58e57a3 commit 8234e52
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 10 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
### 2024/04/220 (v24.0.6)

- Make start and stop buttons visible in headless
- Add validation for lr and optimizer arguments

### 2024/04/19 (v24.0.5)

Expand Down
15 changes: 14 additions & 1 deletion kohya_gui/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,4 +1498,17 @@ def print_command_and_toml(run_cmd, tmpfilename):
log.info(toml_file.read())
log.info(f"end of toml config file: {tmpfilename}")

save_to_file(command_to_run)
save_to_file(command_to_run)

def validate_args_setting(input_string):
# Regex pattern to handle multiple conditions:
# - Empty string is valid
# - Single or multiple key/value pairs with exactly one space between pairs
# - No spaces around '=' and no spaces within keys or values
pattern = r'^(\S+=\S+)( \S+=\S+)*$|^$'
if re.match(pattern, input_string):
return True
else:
log.info(f"'{input_string}' is not a valid settings string.")
log.info("A valid settings string must consist of one or more key/value pairs formatted as key=value, with no spaces around the equals sign or within the value. Multiple pairs should be separated by a space.")
return False
28 changes: 19 additions & 9 deletions kohya_gui/dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
scriptdir,
update_my_data,
validate_paths,
validate_args_setting,
)
from .class_accelerate_launch import AccelerateLaunch
from .class_configuration_file import ConfigurationFile
Expand Down Expand Up @@ -491,19 +492,27 @@ def train_model(
# Get list of function parameters and values
parameters = list(locals().items())
global train_state_value

TRAIN_BUTTON_VISIBLE = [
gr.Button(visible=True),
gr.Button(visible=False or headless),
gr.Textbox(value=train_state_value),
]

if executor.is_running():
log.error("Training is already running. Can't start another training session.")
return TRAIN_BUTTON_VISIBLE

log.info(f"Start training Dreambooth...")

log.info(f"Validating lr scheduler arguments...")
if not validate_args_setting(lr_scheduler_args):
return

log.info(f"Validating optimizer arguments...")
if not validate_args_setting(optimizer_args):
return

# This function validates files or folder paths. Simply add new variables containing file of folder path
# to validate below
if not validate_paths(
Expand Down Expand Up @@ -808,9 +817,9 @@ def train_model(
for key, value in config_toml_data.items()
if value not in ["", False, None]
}

config_toml_data["max_data_loader_n_workers"] = max_data_loader_n_workers

# Sort the dictionary by keys
config_toml_data = dict(sorted(config_toml_data.items()))

Expand Down Expand Up @@ -861,7 +870,7 @@ def train_model(
# Run the command

executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)

train_state_value = time.time()

return (
Expand Down Expand Up @@ -950,7 +959,7 @@ def dreambooth_tab(

global executor
executor = CommandExecutor(headless=headless)

with gr.Column(), gr.Group():
with gr.Row():
button_print = gr.Button("Print training command")
Expand Down Expand Up @@ -1102,9 +1111,9 @@ def dreambooth_tab(
outputs=[configuration.config_file_name],
show_progress=False,
)

run_state = gr.Textbox(value=train_state_value, visible=False)

run_state.change(
fn=executor.wait_for_training_to_end,
outputs=[executor.button_run, executor.button_stop_training],
Expand All @@ -1118,7 +1127,8 @@ def dreambooth_tab(
)

executor.button_stop_training.click(
executor.kill_command, outputs=[executor.button_run, executor.button_stop_training]
executor.kill_command,
outputs=[executor.button_run, executor.button_stop_training],
)

button_print.click(
Expand Down
9 changes: 9 additions & 0 deletions kohya_gui/finetune_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
scriptdir,
update_my_data,
validate_paths,
validate_args_setting
)
from .class_accelerate_launch import AccelerateLaunch
from .class_configuration_file import ConfigurationFile
Expand Down Expand Up @@ -544,6 +545,14 @@ def train_model(

log.info(f"Start Finetuning...")

log.info(f"Validating lr scheduler arguments...")
if not validate_args_setting(lr_scheduler_args):
return

log.info(f"Validating optimizer arguments...")
if not validate_args_setting(optimizer_args):
return

if train_dir != "" and not os.path.exists(train_dir):
os.mkdir(train_dir)

Expand Down
9 changes: 9 additions & 0 deletions kohya_gui/lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
scriptdir,
update_my_data,
validate_paths,
validate_args_setting
)
from .class_accelerate_launch import AccelerateLaunch
from .class_configuration_file import ConfigurationFile
Expand Down Expand Up @@ -679,6 +680,14 @@ def train_model(

log.info(f"Start training LoRA {LoRA_type} ...")

log.info(f"Validating lr scheduler arguments...")
if not validate_args_setting(lr_scheduler_args):
return

log.info(f"Validating optimizer arguments...")
if not validate_args_setting(optimizer_args):
return

if not validate_paths(
output_dir=output_dir,
pretrained_model_name_or_path=pretrained_model_name_or_path,
Expand Down
9 changes: 9 additions & 0 deletions kohya_gui/textual_inversion_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
scriptdir,
update_my_data,
validate_paths,
validate_args_setting
)
from .class_accelerate_launch import AccelerateLaunch
from .class_configuration_file import ConfigurationFile
Expand Down Expand Up @@ -505,6 +506,14 @@ def train_model(

log.info(f"Start training TI...")

log.info(f"Validating lr scheduler arguments...")
if not validate_args_setting(lr_scheduler_args):
return

log.info(f"Validating optimizer arguments...")
if not validate_args_setting(optimizer_args):
return

if not validate_paths(
output_dir=output_dir,
pretrained_model_name_or_path=pretrained_model_name_or_path,
Expand Down

0 comments on commit 8234e52

Please sign in to comment.