Skip to content

Commit e802a79

Browse files
committed
sd: add a config field to set default image gen options
1 parent fa3b4f5 commit e802a79

File tree

1 file changed

+64
-4
lines changed

1 file changed

+64
-4
lines changed

koboldcpp.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,8 @@ class embeddings_generation_outputs(ctypes.Structure):
396396
("count", ctypes.c_int),
397397
("data", ctypes.c_char_p)]
398398

399+
400+
399401
def getdirpath():
400402
return os.path.dirname(os.path.realpath(__file__))
401403
def getabspath():
@@ -1791,9 +1793,58 @@ def sd_comfyui_tranform_params(genparams):
17911793
print("Warning: ComfyUI Payload Missing!")
17921794
return genparams
17931795

1796+
def sd_process_meta_fields(fields, config):
1797+
# aliases to match sd.cpp command-line options
1798+
aliases = {
1799+
'cfg-scale': 'cfg_scale',
1800+
'guidance': 'distilled_guidance',
1801+
'sampler': 'sampler_name',
1802+
'sampling-method': 'sampler_name',
1803+
'timestep-shift': 'shifted_timestep',
1804+
}
1805+
fields_dict = {aliases.get(k, k): v for k, v in fields}
1806+
# whitelist accepted parameters
1807+
whitelist = ['scheduler', 'shifted_timestep', 'distilled_guidance']
1808+
if config:
1809+
# note the current UI always set these
1810+
whitelist += ['sampler_name', 'cfg_scale']
1811+
fields_dict = {k: v for k, v in fields_dict.items() if k in whitelist}
1812+
return fields_dict
1813+
1814+
# json with top-level dict
1815+
def sd_parse_meta_field(prompt, config=False):
1816+
jfields = {}
1817+
try:
1818+
jfields = json.loads(prompt)
1819+
except json.JSONDecodeError:
1820+
# accept "field":"value",... without {} (also empty strings)
1821+
try:
1822+
jfields = json.loads('{ ' + prompt + ' }')
1823+
except json.JSONDecodeError:
1824+
print("Warning: couldn't parse meta prompt; it should be valid JSON.")
1825+
if not isinstance(jfields, dict):
1826+
jfields = {}
1827+
kv_dict = sd_process_meta_fields(jfields.items(), config)
1828+
return kv_dict
1829+
1830+
17941831
def sd_generate(genparams):
17951832
global maxctx, args, currentusergenkey, totalgens, pendingabortkey, chatcompl_adapter
17961833

1834+
sdgendefaults = sd_parse_meta_field(args.sdgendefaults or '', config=True)
1835+
params = dict()
1836+
defparams = dict()
1837+
for k, v in sdgendefaults.items():
1838+
if k in ['sampler_name', 'scheduler']:
1839+
# these can be explicitely set to 'default'; process later
1840+
# TODO should we consider values like 'clip_skip=-1' as 'default' too?
1841+
defparams[k] = v
1842+
else:
1843+
params[k] = v
1844+
# apply most of the defaults
1845+
params.update(genparams)
1846+
genparams = params
1847+
17971848
default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
17981849
adapter_obj = genparams.get('adapter', default_adapter)
17991850
forced_negprompt = adapter_obj.get("add_sd_negative_prompt", "")
@@ -1827,8 +1878,12 @@ def sd_generate(genparams):
18271878
seed = tryparseint(genparams.get("seed", -1),-1)
18281879
if seed < 0:
18291880
seed = random.randint(100000, 999999)
1830-
sample_method = genparams.get("sampler_name", "default")
1831-
scheduler = genparams.get("scheduler", "default")
1881+
sample_method = (genparams.get("sampler_name") or "default").lower()
1882+
if sample_method == 'default' and 'sampler_name' in defparams:
1883+
sample_method = (defparams.get("sampler_name") or "default").lower()
1884+
scheduler = (genparams.get("scheduler") or "default").lower()
1885+
if scheduler == 'default' and 'scheduler' in defparams:
1886+
scheduler = (defparams.get("scheduler") or "default").lower()
18321887
clip_skip = tryparseint(genparams.get("clip_skip", -1),-1)
18331888
vid_req_frames = tryparseint(genparams.get("frames", 1),1)
18341889
vid_req_frames = 1 if (not vid_req_frames or vid_req_frames < 1) else vid_req_frames
@@ -1871,8 +1926,8 @@ def sd_generate(genparams):
18711926
inputs.width = width
18721927
inputs.height = height
18731928
inputs.seed = seed
1874-
inputs.sample_method = sample_method.lower().encode("UTF-8")
1875-
inputs.scheduler = scheduler.lower().encode("UTF-8")
1929+
inputs.sample_method = sample_method.encode("UTF-8")
1930+
inputs.scheduler = scheduler.encode("UTF-8")
18761931
inputs.clip_skip = clip_skip
18771932
inputs.vid_req_frames = vid_req_frames
18781933
inputs.vid_req_avi = vid_req_avi
@@ -4690,6 +4745,7 @@ def hide_tooltip(event):
46904745
sd_clamped_soft_var = ctk.StringVar(value="0")
46914746
sd_threads_var = ctk.StringVar(value=str(default_threads))
46924747
sd_quant_var = ctk.StringVar(value=sd_quant_choices[0])
4748+
sd_gen_defaults_var = ctk.StringVar()
46934749

46944750
whisper_model_var = ctk.StringVar()
46954751
tts_model_var = ctk.StringVar()
@@ -5465,6 +5521,7 @@ def toggletaesd(a,b,c):
54655521
makecheckbox(images_tab, "Model CPU Offload", sd_offload_cpu_var, 50,padx=8, tooltiptxt="Offload image weights in RAM to save VRAM, swap into VRAM when needed.")
54665522
makecheckbox(images_tab, "VAE on CPU", sd_vae_cpu_var, 50,padx=160, tooltiptxt="Force VAE to CPU only for image generation.")
54675523
makecheckbox(images_tab, "CLIP on GPU", sd_clip_gpu_var, 50,padx=280, tooltiptxt="Put CLIP and T5 to GPU for image generation. Otherwise, CLIP will use CPU.")
5524+
makelabelentry(images_tab, "Default Params:", sd_gen_defaults_var, 52, 280, padx=110, singleline=True, tooltip='Default image generation parameters when not specified by the UI or API.\nSpecified as JSON fields: {"KEY1":"VALUE1", "KEY2":"VALUE2"...}')
54685525

54695526
# audio tab
54705527
audio_tab = tabcontent["Audio"]
@@ -5738,6 +5795,7 @@ def export_vars():
57385795
args.sdloramult = float(sd_loramult_var.get())
57395796
else:
57405797
args.sdlora = ""
5798+
args.sdgendefaults = sd_gen_defaults_var.get()
57415799

57425800
if whisper_model_var.get() != "":
57435801
args.whispermodel = whisper_model_var.get()
@@ -5964,6 +6022,7 @@ def import_vars(dict):
59646022

59656023
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
59666024
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
6025+
sd_gen_defaults_var.set(dict.get("sdgendefaults", ""))
59676026

59686027
whisper_model_var.set(dict["whispermodel"] if ("whispermodel" in dict and dict["whispermodel"]) else "")
59696028

@@ -7797,6 +7856,7 @@ def range_checker(arg: str):
77977856
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify an image generation LORA safetensors model to be applied.", default="")
77987857
sdparsergroup.add_argument("--sdloramult", metavar=('[amount]'), help="Multiplier for the image LORA model to be applied.", type=float, default=1.0)
77997858
sdparsergroup.add_argument("--sdtiledvae", metavar=('[maxres]'), help="Adjust the automatic VAE tiling trigger for images above this size. 0 disables vae tiling.", type=int, default=default_vae_tile_threshold)
7859+
sdparsergroup.add_argument("--sdgendefaults", metavar=('{"parameter":"value",...}'), help="Sets default parameters for image generation, as a JSON string.", default="")
78007860
whisperparsergroup = parser.add_argument_group('Whisper Transcription Commands')
78017861
whisperparsergroup.add_argument("--whispermodel", metavar=('[filename]'), help="Specify a Whisper .bin model to enable Speech-To-Text transcription.", default="")
78027862

0 commit comments

Comments
 (0)