Skip to content

Commit

Permalink
Add support to load a config without opening the UI to get the file name
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Mar 11, 2023
1 parent d1962d7 commit a65555e
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 18 deletions.
19 changes: 16 additions & 3 deletions dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def save_configuration(


def open_configuration(
ask_for_file,
file_path,
pretrained_model_name_or_path,
v2,
Expand Down Expand Up @@ -213,9 +214,13 @@ def open_configuration(
):
# Get list of function parameters and values
parameters = list(locals().items())

ask_for_file = True if ask_for_file.get('label') == 'True' else False

original_file_path = file_path
file_path = get_file_path(file_path)

if ask_for_file:
file_path = get_file_path(file_path)

if not file_path == '' and not file_path == None:
# load variables from JSON file
Expand All @@ -231,7 +236,7 @@ def open_configuration(
values = [file_path]
for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']:
if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value))
return tuple(values)

Expand Down Expand Up @@ -506,6 +511,7 @@ def dreambooth_tab(
button_save_config,
button_save_as_config,
config_file_name,
button_load_config,
) = gradio_config()

(
Expand Down Expand Up @@ -775,7 +781,14 @@ def dreambooth_tab(

button_open_config.click(
open_configuration,
inputs=[config_file_name] + settings_list,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)

button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
Expand Down
31 changes: 22 additions & 9 deletions finetune_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def save_configuration(
return file_path


def open_config_file(
def open_configuration(
ask_for_file,
file_path,
pretrained_model_name_or_path,
v2,
Expand Down Expand Up @@ -217,9 +218,13 @@ def open_config_file(
):
# Get list of function parameters and values
parameters = list(locals().items())

ask_for_file = True if ask_for_file.get('label') == 'True' else False

original_file_path = file_path
file_path = get_file_path(file_path)

if ask_for_file:
file_path = get_file_path(file_path)

if not file_path == '' and not file_path == None:
# load variables from JSON file
Expand All @@ -235,7 +240,7 @@ def open_config_file(
values = [file_path]
for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']:
if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value))
return tuple(values)

Expand Down Expand Up @@ -492,15 +497,16 @@ def remove_doublequote(file_path):


def finetune_tab():
dummy_ft_true = gr.Label(value=True, visible=False)
dummy_ft_false = gr.Label(value=False, visible=False)
dummy_db_true = gr.Label(value=True, visible=False)
dummy_db_false = gr.Label(value=False, visible=False)
gr.Markdown('Train a custom model using kohya finetune python code...')

(
button_open_config,
button_save_config,
button_save_as_config,
config_file_name,
button_load_config,
) = gradio_config()

(
Expand Down Expand Up @@ -770,22 +776,29 @@ def finetune_tab():
button_run.click(train_model, inputs=settings_list)

button_open_config.click(
open_config_file,
inputs=[config_file_name] + settings_list,
open_configuration,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)

button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)

button_save_config.click(
save_configuration,
inputs=[dummy_ft_false, config_file_name] + settings_list,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name],
show_progress=False,
)

button_save_as_config.click(
save_configuration,
inputs=[dummy_ft_true, config_file_name] + settings_list,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name],
show_progress=False,
)
Expand Down
3 changes: 3 additions & 0 deletions library/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,14 @@ def gradio_config():
placeholder="type the configuration file path or use the 'Open' button above to select it...",
interactive=True,
)
button_load_config = gr.Button('Load 💾', elem_id='open_folder')
config_file_name.change(remove_doublequote, inputs=[config_file_name], outputs=[config_file_name])
return (
button_open_config,
button_save_config,
button_save_as_config,
config_file_name,
button_load_config,
)


Expand Down
19 changes: 16 additions & 3 deletions lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def save_configuration(


def open_configuration(
ask_for_file,
file_path,
pretrained_model_name_or_path,
v2,
Expand Down Expand Up @@ -239,9 +240,13 @@ def open_configuration(
):
# Get list of function parameters and values
parameters = list(locals().items())

ask_for_file = True if ask_for_file.get('label') == 'True' else False

original_file_path = file_path
file_path = get_file_path(file_path)

if ask_for_file:
file_path = get_file_path(file_path)

if not file_path == '' and not file_path == None:
# load variables from JSON file
Expand All @@ -257,7 +262,7 @@ def open_configuration(
values = [file_path]
for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']:
if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value))

# This next section is about making the LoCon parameters visible if LoRA_type = 'Standard'
Expand Down Expand Up @@ -610,6 +615,7 @@ def lora_tab(
button_save_config,
button_save_as_config,
config_file_name,
button_load_config,
) = gradio_config()

(
Expand Down Expand Up @@ -974,7 +980,14 @@ def LoRA_type_change(LoRA_type):

button_open_config.click(
open_configuration,
inputs=[config_file_name] + settings_list,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list + [LoCon_row],
show_progress=False,
)

button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list + [LoCon_row],
show_progress=False,
)
Expand Down
19 changes: 16 additions & 3 deletions textual_inversion_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def save_configuration(


def open_configuration(
ask_for_file,
file_path,
pretrained_model_name_or_path,
v2,
Expand Down Expand Up @@ -225,9 +226,13 @@ def open_configuration(
):
# Get list of function parameters and values
parameters = list(locals().items())

ask_for_file = True if ask_for_file.get('label') == 'True' else False

original_file_path = file_path
file_path = get_file_path(file_path)

if ask_for_file:
file_path = get_file_path(file_path)

if not file_path == '' and not file_path == None:
# load variables from JSON file
Expand All @@ -243,7 +248,7 @@ def open_configuration(
values = [file_path]
for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']:
if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value))
return tuple(values)

Expand Down Expand Up @@ -548,6 +553,7 @@ def ti_tab(
button_save_config,
button_save_as_config,
config_file_name,
button_load_config,
) = gradio_config()

(
Expand Down Expand Up @@ -865,7 +871,14 @@ def ti_tab(

button_open_config.click(
open_configuration,
inputs=[config_file_name] + settings_list,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)

button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
Expand Down

0 comments on commit a65555e

Please sign in to comment.