Skip to content

Commit 4a44be3

Browse files
authored
feat: add preset selection to Gradio UI (session based) (lllyasviel#1570)
* add preset selection uses meta parsing to set presets in user session (UI elements only) * add LoRA handling * use default config as fallback value * add preset refresh on "Refresh All Files" click * add special handling for default_styles and default_aspect_ratio * sort styles after preset change * code cleanup * download missing models from preset * set default refiner to "None" in preset realistic * use state_is_generating for preset selection change * DRY output parameter handling * feat: add argument --disable-preset-selection useful for cloud provisioning to prevent model switches and keep models loaded * feat: keep prompt when not set in preset, use more robust syntax * fix: add default return values when preset download is disabled mashb1t#20 * feat: add translation for preset label * refactor: unify preset loading methods in config * refactor: code cleanup
1 parent 8baafcd commit 4a44be3

File tree

7 files changed

+132
-69
lines changed

7 files changed

+132
-69
lines changed

args_manager.py

+3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from tempfile import gettempdir
55

66
args_parser.parser.add_argument("--share", action='store_true', help="Set whether to share on Gradio.")
7+
78
args_parser.parser.add_argument("--preset", type=str, default=None, help="Apply specified UI preset.")
9+
args_parser.parser.add_argument("--disable-preset-selection", action='store_true',
10+
help="Disables preset selection in Gradio.")
811

912
args_parser.parser.add_argument("--language", type=str, default='default',
1013
help="Translate UI using json files in [language] folder. "

language/en.json

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"* \"Inpaint or Outpaint\" is powered by the sampler \"DPMPP Fooocus Seamless 2M SDE Karras Inpaint Sampler\" (beta)": "* \"Inpaint or Outpaint\" is powered by the sampler \"DPMPP Fooocus Seamless 2M SDE Karras Inpaint Sampler\" (beta)",
3939
"Setting": "Setting",
4040
"Style": "Style",
41+
"Preset": "Preset",
4142
"Performance": "Performance",
4243
"Speed": "Speed",
4344
"Quality": "Quality",

launch.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def ini_args():
9393
print(f"[Cleanup] Failed to delete content of temp dir.")
9494

9595

96-
def download_models():
96+
def download_models(default_model, previous_default_models, checkpoint_downloads, embeddings_downloads, lora_downloads):
9797
for file_name, url in vae_approx_filenames:
9898
load_file_from_url(url=url, model_dir=config.path_vae_approx, file_name=file_name)
9999

@@ -105,30 +105,32 @@ def download_models():
105105

106106
if args.disable_preset_download:
107107
print('Skipped model download.')
108-
return
108+
return default_model, checkpoint_downloads
109109

110110
if not args.always_download_new_model:
111-
if not os.path.exists(os.path.join(config.paths_checkpoints[0], config.default_base_model_name)):
112-
for alternative_model_name in config.previous_default_models:
111+
if not os.path.exists(os.path.join(config.paths_checkpoints[0], default_model)):
112+
for alternative_model_name in previous_default_models:
113113
if os.path.exists(os.path.join(config.paths_checkpoints[0], alternative_model_name)):
114-
print(f'You do not have [{config.default_base_model_name}] but you have [{alternative_model_name}].')
114+
print(f'You do not have [{default_model}] but you have [{alternative_model_name}].')
115115
print(f'Fooocus will use [{alternative_model_name}] to avoid downloading new models, '
116-
f'but you are not using latest models.')
116+
f'but you are not using the latest models.')
117117
print('Use --always-download-new-model to avoid fallback and always get new models.')
118-
config.checkpoint_downloads = {}
119-
config.default_base_model_name = alternative_model_name
118+
checkpoint_downloads = {}
119+
default_model = alternative_model_name
120120
break
121121

122-
for file_name, url in config.checkpoint_downloads.items():
122+
for file_name, url in checkpoint_downloads.items():
123123
load_file_from_url(url=url, model_dir=config.paths_checkpoints[0], file_name=file_name)
124-
for file_name, url in config.embeddings_downloads.items():
124+
for file_name, url in embeddings_downloads.items():
125125
load_file_from_url(url=url, model_dir=config.path_embeddings, file_name=file_name)
126-
for file_name, url in config.lora_downloads.items():
126+
for file_name, url in lora_downloads.items():
127127
load_file_from_url(url=url, model_dir=config.paths_loras[0], file_name=file_name)
128128

129-
return
129+
return default_model, checkpoint_downloads
130130

131131

132-
download_models()
132+
config.default_base_model_name, config.checkpoint_downloads = download_models(
133+
config.default_base_model_name, config.previous_default_models, config.checkpoint_downloads,
134+
config.embeddings_downloads, config.lora_downloads)
133135

134136
from webui import *

modules/config.py

+64-39
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,45 @@ def replace_config(old_key, new_key):
9797

9898
try_load_deprecated_user_path_config()
9999

100-
preset = args_manager.args.preset
101100

102-
if isinstance(preset, str):
103-
preset_path = os.path.abspath(f'./presets/{preset}.json')
104-
try:
105-
if os.path.exists(preset_path):
106-
with open(preset_path, "r", encoding="utf-8") as json_file:
107-
config_dict.update(json.load(json_file))
108-
print(f'Loaded preset: {preset_path}')
109-
else:
110-
raise FileNotFoundError
111-
except Exception as e:
112-
print(f'Load preset [{preset_path}] failed')
113-
print(e)
101+
def get_presets():
102+
preset_folder = 'presets'
103+
presets = ['initial']
104+
if not os.path.exists(preset_folder):
105+
print('No presets found.')
106+
return presets
107+
108+
return presets + [f[:f.index('.json')] for f in os.listdir(preset_folder) if f.endswith('.json')]
109+
110+
111+
def try_get_preset_content(preset):
112+
if isinstance(preset, str):
113+
preset_path = os.path.abspath(f'./presets/{preset}.json')
114+
try:
115+
if os.path.exists(preset_path):
116+
with open(preset_path, "r", encoding="utf-8") as json_file:
117+
json_content = json.load(json_file)
118+
print(f'Loaded preset: {preset_path}')
119+
return json_content
120+
else:
121+
raise FileNotFoundError
122+
except Exception as e:
123+
print(f'Load preset [{preset_path}] failed')
124+
print(e)
125+
return {}
114126

115127

128+
try:
129+
with open(os.path.abspath(f'./presets/default.json'), "r", encoding="utf-8") as json_file:
130+
config_dict.update(json.load(json_file))
131+
except Exception as e:
132+
print(f'Load default preset failed.')
133+
print(e)
134+
135+
available_presets = get_presets()
136+
preset = args_manager.args.preset
137+
config_dict.update(try_get_preset_content(preset))
138+
116139
def get_path_output() -> str:
117140
"""
118141
Checking output path argument and overriding default path.
@@ -241,7 +264,7 @@ def init_temp_path(path: str | None, default_path: str) -> str:
241264
default_value=True,
242265
validator=lambda x: isinstance(x, bool)
243266
)
244-
default_base_model_name = get_config_item_or_set_default(
267+
default_base_model_name = default_model = get_config_item_or_set_default(
245268
key='default_model',
246269
default_value='model.safetensors',
247270
validator=lambda x: isinstance(x, str)
@@ -251,7 +274,7 @@ def init_temp_path(path: str | None, default_path: str) -> str:
251274
default_value=[],
252275
validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x)
253276
)
254-
default_refiner_model_name = get_config_item_or_set_default(
277+
default_refiner_model_name = default_refiner = get_config_item_or_set_default(
255278
key='default_refiner',
256279
default_value='None',
257280
validator=lambda x: isinstance(x, str)
@@ -451,29 +474,30 @@ def init_temp_path(path: str | None, default_path: str) -> str:
451474

452475
config_dict["default_loras"] = default_loras = default_loras[:default_max_lora_number] + [['None', 1.0] for _ in range(default_max_lora_number - len(default_loras))]
453476

454-
possible_preset_keys = [
455-
"default_model",
456-
"default_refiner",
457-
"default_refiner_switch",
458-
"default_loras_min_weight",
459-
"default_loras_max_weight",
460-
"default_loras",
461-
"default_max_lora_number",
462-
"default_cfg_scale",
463-
"default_sample_sharpness",
464-
"default_sampler",
465-
"default_scheduler",
466-
"default_performance",
467-
"default_prompt",
468-
"default_prompt_negative",
469-
"default_styles",
470-
"default_aspect_ratio",
471-
"default_save_metadata_to_images",
472-
"checkpoint_downloads",
473-
"embeddings_downloads",
474-
"lora_downloads",
475-
]
476-
477+
# mapping config to meta parameter
478+
possible_preset_keys = {
479+
"default_model": "base_model",
480+
"default_refiner": "refiner_model",
481+
"default_refiner_switch": "refiner_switch",
482+
"previous_default_models": "previous_default_models",
483+
"default_loras_min_weight": "default_loras_min_weight",
484+
"default_loras_max_weight": "default_loras_max_weight",
485+
"default_loras": "<processed>",
486+
"default_cfg_scale": "guidance_scale",
487+
"default_sample_sharpness": "sharpness",
488+
"default_sampler": "sampler",
489+
"default_scheduler": "scheduler",
490+
"default_overwrite_step": "steps",
491+
"default_performance": "performance",
492+
"default_prompt": "prompt",
493+
"default_prompt_negative": "negative_prompt",
494+
"default_styles": "styles",
495+
"default_aspect_ratio": "resolution",
496+
"default_save_metadata_to_images": "default_save_metadata_to_images",
497+
"checkpoint_downloads": "checkpoint_downloads",
498+
"embeddings_downloads": "embeddings_downloads",
499+
"lora_downloads": "lora_downloads"
500+
}
477501

478502
REWRITE_PRESET = False
479503

@@ -530,10 +554,11 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):
530554

531555

532556
def update_files():
533-
global model_filenames, lora_filenames, wildcard_filenames
557+
global model_filenames, lora_filenames, wildcard_filenames, available_presets
534558
model_filenames = get_model_filenames(paths_checkpoints)
535559
lora_filenames = get_model_filenames(paths_loras)
536560
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
561+
available_presets = get_presets()
537562
return
538563

539564

modules/meta_parser.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,8 @@ def parse_meta_from_preset(preset_content):
210210
height = height[:height.index(" ")]
211211
preset_prepared[meta_key] = (width, height)
212212
else:
213-
preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[
214-
settings_key] is not None else getattr(modules.config, settings_key)
215-
213+
preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[settings_key] is not None else getattr(modules.config, settings_key)
214+
216215
if settings_key == "default_styles" or settings_key == "default_aspect_ratio":
217216
preset_prepared[meta_key] = str(preset_prepared[meta_key])
218217

@@ -570,4 +569,4 @@ def get_exif(metadata: str | None, metadata_scheme: str):
570569
exif[0x0131] = 'Fooocus v' + fooocus_version.version
571570
# 0x927C = MakerNote
572571
exif[0x927C] = metadata_scheme
573-
return exif
572+
return exif

presets/realistic.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"default_model": "realisticStockPhoto_v20.safetensors",
3-
"default_refiner": "",
3+
"default_refiner": "None",
44
"default_refiner_switch": 0.5,
55
"default_loras": [
66
[

webui.py

+45-12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import modules.meta_parser
1616
import args_manager
1717
import copy
18+
import launch
1819

1920
from modules.sdxl_styles import legal_style_names
2021
from modules.private_logger import get_current_html_path
@@ -252,6 +253,11 @@ def trigger_metadata_preview(filepath):
252253

253254
with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column:
254255
with gr.Tab(label='Setting'):
256+
if not args_manager.args.disable_preset_selection:
257+
preset_selection = gr.Radio(label='Preset',
258+
choices=modules.config.available_presets,
259+
value=args_manager.args.preset if args_manager.args.preset else "initial",
260+
interactive=True)
255261
performance_selection = gr.Radio(label='Performance',
256262
choices=flags.Performance.list(),
257263
value=modules.config.default_performance)
@@ -518,13 +524,50 @@ def refresh_files_clicked():
518524
modules.config.update_files()
519525
results = [gr.update(choices=modules.config.model_filenames)]
520526
results += [gr.update(choices=['None'] + modules.config.model_filenames)]
527+
if not args_manager.args.disable_preset_selection:
528+
results += [gr.update(choices=modules.config.available_presets)]
521529
for i in range(modules.config.default_max_lora_number):
522-
results += [gr.update(interactive=True), gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
530+
results += [gr.update(interactive=True),
531+
gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
523532
return results
524533

525-
refresh_files.click(refresh_files_clicked, [], [base_model, refiner_model] + lora_ctrls,
534+
refresh_files_output = [base_model, refiner_model]
535+
if not args_manager.args.disable_preset_selection:
536+
refresh_files_output += [preset_selection]
537+
refresh_files.click(refresh_files_clicked, [], refresh_files_output + lora_ctrls,
526538
queue=False, show_progress=False)
527539

540+
state_is_generating = gr.State(False)
541+
542+
load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections,
543+
performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection,
544+
overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive,
545+
adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model,
546+
refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed,
547+
generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
548+
549+
if not args_manager.args.disable_preset_selection:
550+
def preset_selection_change(preset, is_generating):
551+
preset_content = modules.config.try_get_preset_content(preset) if preset != 'initial' else {}
552+
preset_prepared = modules.meta_parser.parse_meta_from_preset(preset_content)
553+
554+
default_model = preset_prepared.get('base_model')
555+
previous_default_models = preset_prepared.get('previous_default_models', [])
556+
checkpoint_downloads = preset_prepared.get('checkpoint_downloads', {})
557+
embeddings_downloads = preset_prepared.get('embeddings_downloads', {})
558+
lora_downloads = preset_prepared.get('lora_downloads', {})
559+
560+
preset_prepared['base_model'], preset_prepared['lora_downloads'] = launch.download_models(
561+
default_model, previous_default_models, checkpoint_downloads, embeddings_downloads, lora_downloads)
562+
563+
if 'prompt' in preset_prepared and preset_prepared.get('prompt') == '':
564+
del preset_prepared['prompt']
565+
566+
return modules.meta_parser.load_parameter_button_click(json.dumps(preset_prepared), is_generating)
567+
568+
preset_selection.change(preset_selection_change, inputs=[preset_selection, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=True) \
569+
.then(fn=style_sorter.sort_styles, inputs=style_selections, outputs=style_selections, queue=False, show_progress=False) \
570+
528571
performance_selection.change(lambda x: [gr.update(interactive=not flags.Performance.has_restricted_features(x))] * 11 +
529572
[gr.update(visible=not flags.Performance.has_restricted_features(x))] * 1 +
530573
[gr.update(interactive=not flags.Performance.has_restricted_features(x), value=flags.Performance.has_restricted_features(x))] * 1,
@@ -600,8 +643,6 @@ def inpaint_mode_change(mode):
600643

601644
ctrls += ip_ctrls
602645

603-
state_is_generating = gr.State(False)
604-
605646
def parse_meta(raw_prompt_txt, is_generating):
606647
loaded_json = None
607648
if is_json(raw_prompt_txt):
@@ -617,13 +658,6 @@ def parse_meta(raw_prompt_txt, is_generating):
617658

618659
prompt.input(parse_meta, inputs=[prompt, state_is_generating], outputs=[prompt, generate_button, load_parameter_button], queue=False, show_progress=False)
619660

620-
load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections,
621-
performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection,
622-
overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive,
623-
adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model,
624-
refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed,
625-
generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
626-
627661
load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=False)
628662

629663
def trigger_metadata_import(filepath, state_is_generating):
@@ -637,7 +671,6 @@ def trigger_metadata_import(filepath, state_is_generating):
637671

638672
return modules.meta_parser.load_parameter_button_click(parsed_parameters, state_is_generating)
639673

640-
641674
metadata_import_button.click(trigger_metadata_import, inputs=[metadata_input_image, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=True) \
642675
.then(style_sorter.sort_styles, inputs=style_selections, outputs=style_selections, queue=False, show_progress=False)
643676

0 commit comments

Comments
 (0)