forked from bmaltais/kohya_ss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathverify_lora_gui.py
142 lines (119 loc) · 3.95 KB
/
verify_lora_gui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import gradio as gr
import subprocess
import os
import sys
from .common_gui import (
get_file_path,
scriptdir,
list_files,
create_refresh_button, setup_environment
)
from .custom_logging import setup_logging
# Module-wide logger obtained from the project's logging helper.
log = setup_logging()

# Interpreter used as argv[0] when launching the sd-scripts checker.
PYTHON = sys.executable

# Emoji glyphs for small tool buttons (folder_symbol labels the
# file-browse button of the "Verify LoRA" tab).
folder_symbol = "\U0001F4C2"  # 📂
refresh_symbol = "\U0001F504"  # 🔄
save_style_symbol = "\U0001F4BE"  # 💾
document_symbol = "\U0001F4C4"  # 📄
def verify_lora(lora_model):
    """Run sd-scripts' ``check_lora_weights.py`` on a LoRA model file.

    Args:
        lora_model: Path of the LoRA model to check, as taken from the
            "LoRA model" dropdown. May be ``""`` or ``None`` when nothing
            is selected.

    Returns:
        A ``(stdout, stderr)`` tuple of text, one value per output textbox
        wired to the "Verify" button. When the input is rejected or the
        checker cannot be launched, stdout is ``""`` and stderr carries the
        reason, so the UI handler always receives exactly two values.
    """
    # Reject an empty selection; a cleared dropdown may yield "" or None.
    if not lora_model:
        msg = "Invalid LoRA model file"
        log.info(msg)
        return ("", msg)

    # The checker script operates on an existing file on disk.
    if not os.path.isfile(lora_model):
        msg = "The provided LoRA model is not a file"
        log.info(msg)
        return ("", msg)

    run_cmd = [
        PYTHON,
        os.path.join(scriptdir, "sd-scripts", "networks", "check_lora_weights.py"),
        lora_model,
    ]

    # Display-only rendering of the argv. The command itself is passed to
    # subprocess as a list (no shell), so paths with spaces need no quoting.
    command_to_run = " ".join(run_cmd)
    log.info(f"Executing command: {command_to_run}")

    # Environment for the child process, built by the project's helper
    # (presumably it sets PYTHONPATH for sd-scripts; confirm in common_gui).
    env = setup_environment()

    try:
        # Blocks until the checker exits, capturing both streams.
        result = subprocess.run(
            run_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            check=False,
        )
    except OSError as exc:
        # E.g. the interpreter path is missing or not executable.
        msg = f"Failed to run LoRA verification: {exc}"
        log.error(msg)
        return ("", msg)

    # errors="replace" keeps a stray non-UTF-8 byte (for example from a
    # Windows console code page) from crashing the UI handler.
    return (
        result.stdout.decode("utf-8", errors="replace"),
        result.stderr.decode("utf-8", errors="replace"),
    )
###
# Gradio UI
###
def gradio_verify_lora_tab(headless=False):
    """Build the "Verify LoRA" tab and wire its controls to ``verify_lora``.

    Args:
        headless: When True, the file-browse button is hidden.
    """
    # Location the dropdown was last populated from; list_models() updates
    # it so the refresh button re-lists the most recently used location.
    current_model_dir = os.path.join(scriptdir, "outputs")

    def list_models(path):
        """Remember *path* and return the .pt/.safetensors entries list_files finds for it."""
        # NOTE(review): the dropdown's change handler passes the selected
        # file path here, not a directory; presumably list_files resolves a
        # file path to its parent directory. Confirm in common_gui.
        nonlocal current_model_dir
        current_model_dir = path
        return list(list_files(path, exts=[".pt", ".safetensors"], all=True))

    with gr.Tab("Verify LoRA"):
        gr.Markdown(
            "This utility can verify a LoRA network to make sure it is properly trained."
        )

        # Hidden values handed to get_file_path to filter the file dialog.
        lora_ext = gr.Textbox(value="*.pt *.safetensors", visible=False)
        lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)

        with gr.Group(), gr.Row():
            lora_model = gr.Dropdown(
                label="LoRA model (path to the LoRA model to verify)",
                interactive=True,
                choices=[""] + list_models(current_model_dir),
                value="",
                allow_custom_value=True,
            )
            # Keep the leading "" choice on refresh, consistent with the
            # initial choices and the change handler below.
            create_refresh_button(
                lora_model,
                lambda: None,
                lambda: {"choices": [""] + list_models(current_model_dir)},
                "open_folder_small",
            )
            button_lora_model_file = gr.Button(
                folder_symbol,
                elem_id="open_folder_small",
                elem_classes=["tool"],
                visible=(not headless),
            )
            button_lora_model_file.click(
                get_file_path,
                inputs=[lora_model, lora_ext, lora_ext_name],
                outputs=lora_model,
                show_progress=False,
            )
            verify_button = gr.Button("Verify", variant="primary")

            # Re-list the choices based on whatever path is now selected.
            lora_model.change(
                fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
                inputs=lora_model,
                outputs=lora_model,
                show_progress=False,
            )

        lora_model_verif_output = gr.Textbox(
            label="Output",
            placeholder="Verification output",
            interactive=False,
            lines=1,
            max_lines=10,
        )
        lora_model_verif_error = gr.Textbox(
            label="Error",
            placeholder="Verification error",
            interactive=False,
            lines=1,
            max_lines=10,
        )

        # verify_lora returns (stdout, stderr), one value per textbox.
        verify_button.click(
            verify_lora,
            inputs=[
                lora_model,
            ],
            outputs=[lora_model_verif_output, lora_model_verif_error],
            show_progress=False,
        )