forked from bmaltais/kohya_ss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkohya_gui.py
148 lines (130 loc) · 4.63 KB
/
kohya_gui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
import os
import argparse
from kohya_gui.class_gui_config import KohyaSSGUIConfig
from dreambooth_gui import dreambooth_tab
from finetune_gui import finetune_tab
from textual_inversion_gui import ti_tab
from kohya_gui.utilities import utilities_tab
from lora_gui import lora_tab
from kohya_gui.class_lora_tab import LoRATools
from kohya_gui.custom_logging import setup_logging
from kohya_gui.localization_ext import add_javascript
# Set up logging: module-level logger shared by UI() and its helpers below.
log = setup_logging()
def _read_text_file(path):
    """Return the UTF-8 text of *path*, or ``None`` when the file does not exist.

    Used for the optional ``style.css``, ``.release`` and ``README.md`` files;
    any of them may be absent.
    """
    try:
        with open(path, "r", encoding="utf8") as file:
            return file.read()
    except FileNotFoundError:
        return None


def _build_launch_kwargs(options):
    """Translate the keyword options accepted by ``UI`` into ``launch()`` kwargs.

    ``server_name`` is always forwarded (it may be ``None``) and ``debug`` is
    always set to ``True``; every other option is forwarded only when it is
    actually set, so gradio keeps its own defaults otherwise.
    """
    launch_kwargs = {"server_name": options.get("listen")}

    username = options.get("username")
    password = options.get("password")
    if username and password:
        launch_kwargs["auth"] = (username, password)
    elif username or password:
        # Half-configured credentials would otherwise disable auth silently.
        log.warning(
            "Both username and password are required to enable "
            "authentication; launching without authentication."
        )

    # ``or 0`` also tolerates an explicit ``server_port=None``.
    server_port = options.get("server_port") or 0
    if server_port > 0:
        launch_kwargs["server_port"] = server_port

    inbrowser = options.get("inbrowser", False)
    if inbrowser:
        launch_kwargs["inbrowser"] = inbrowser

    share = options.get("share", False)
    if share:
        launch_kwargs["share"] = share

    launch_kwargs["debug"] = True
    return launch_kwargs


def UI(**kwargs):
    """Build the kohya_ss Gradio interface and launch the web server.

    Keyword options (all optional):
        language: forwarded to ``add_javascript`` (localization).
        headless: passed to every tab builder (default ``False``).
        username, password: enable authentication only when both are non-empty.
        server_port: port to bind; values <= 0 leave gradio's default.
        inbrowser, share: forwarded to ``launch`` when truthy.
        listen: host/IP forwarded to ``launch`` as ``server_name``.

    Optional files read relative to the current working directory:
    ``style.css`` (custom CSS), ``.release`` (version string) and
    ``README.md`` (README tab contents). A missing file falls back to an
    empty value instead of aborting startup.
    """
    add_javascript(kwargs.get("language"))
    headless = kwargs.get("headless", False)
    log.info(f"headless: {headless}")

    css = ""
    style = _read_text_file("./style.css")
    if style is not None:
        log.debug("Load CSS...")
        css += style + "\n"

    # Both files are optional: default to "" so the names are always bound.
    # Strip the release string so a trailing newline does not leak into the
    # window title or the version badge.
    release = (_read_text_file("./.release") or "").strip()
    readme = _read_text_file("./README.md") or ""

    interface = gr.Blocks(
        css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
    )
    config = KohyaSSGUIConfig()

    with interface:
        with gr.Tab("Dreambooth"):
            # These inputs are handed to utilities_tab below (presumably so
            # its copy-info button can read them).
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = dreambooth_tab(headless=headless, config=config)
        with gr.Tab("LoRA"):
            lora_tab(headless=headless, config=config)
        with gr.Tab("Textual Inversion"):
            ti_tab(headless=headless, config=config)
        with gr.Tab("Finetuning"):
            finetune_tab(headless=headless, config=config)
        with gr.Tab("Utilities"):
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                enable_copy_info_button=True,
                headless=headless,
            )
            # LoRA tools are a sub-tab of Utilities so they do not collide
            # with the top-level "LoRA" training tab declared above.
            with gr.Tab("LoRA"):
                _ = LoRATools(headless=headless)
        with gr.Tab("About"):
            gr.Markdown(f"kohya_ss GUI release {release}")
            with gr.Tab("README"):
                gr.Markdown(readme)

        html_str = f"""
        <html>
            <body>
                <div class="ver-class">{release}</div>
            </body>
        </html>
        """
        gr.HTML(html_str)

    # Show the interface
    interface.launch(**_build_launch_kwargs(kwargs))
if __name__ == "__main__":
    # Command-line entry point: collect server/UI options and hand them to UI().
    cli = argparse.ArgumentParser()

    # Where and how the Gradio server listens.
    cli.add_argument("--listen", type=str, default="127.0.0.1",
                     help="IP to listen on for connections to Gradio")
    cli.add_argument("--username", type=str, default="",
                     help="Username for authentication")
    cli.add_argument("--password", type=str, default="",
                     help="Password for authentication")
    cli.add_argument("--server_port", type=int, default=0,
                     help="Port to run the server listener on")

    # Boolean switches.
    cli.add_argument("--inbrowser", action="store_true",
                     help="Open in browser")
    cli.add_argument("--share", action="store_true",
                     help="Share the gradio UI")
    cli.add_argument("--headless", action="store_true",
                     help="Is the server headless")

    cli.add_argument("--language", type=str, default=None,
                     help="Set custom language")
    # Parsed but not forwarded to UI() here; presumably consumed by the
    # launcher scripts that set up the environment.
    cli.add_argument("--use-ipex", action="store_true",
                     help="Use IPEX environment")

    opts = cli.parse_args()
    UI(
        listen=opts.listen,
        server_port=opts.server_port,
        username=opts.username,
        password=opts.password,
        inbrowser=opts.inbrowser,
        share=opts.share,
        headless=opts.headless,
        language=opts.language,
    )