Skip to content

Commit

Permalink
Add server_port and inbrowser support
Browse files Browse the repository at this point in the history
- to all gui scripts
  • Loading branch information
bmaltais committed Feb 10, 2023
1 parent 56d171c commit e5f8ba5
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 156 deletions.
78 changes: 43 additions & 35 deletions dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,40 +435,6 @@ def train_model(
save_inference_file(output_dir, v2, v_parameterization, output_name)


def UI(username, password):
css = ''

if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'

interface = gr.Blocks(css=css)

with interface:
with gr.Tab('Dreambooth'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = dreambooth_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)

# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()


def dreambooth_tab(
train_data_dir=gr.Textbox(),
reg_data_dir=gr.Textbox(),
Expand Down Expand Up @@ -735,6 +701,44 @@ def dreambooth_tab(
)


def UI(**kwargs):
css = ''

if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'

interface = gr.Blocks(css=css)

with interface:
with gr.Tab('Dreambooth'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = dreambooth_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)

# Show the interface
launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)

if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
Expand All @@ -744,7 +748,11 @@ def dreambooth_tab(
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")

args = parser.parse_args()

UI(username=args.username, password=args.password)
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
59 changes: 34 additions & 25 deletions finetune_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,30 +431,6 @@ def remove_doublequote(file_path):
return file_path


def UI(username, password):

css = ''

if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'

interface = gr.Blocks(css=css)

with interface:
with gr.Tab('Finetune'):
finetune_tab()
with gr.Tab('Utilities'):
utilities_tab(enable_dreambooth_tab=False)

# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()


def finetune_tab():
dummy_ft_true = gr.Label(value=True, visible=False)
dummy_ft_false = gr.Label(value=False, visible=False)
Expand Down Expand Up @@ -708,6 +684,35 @@ def finetune_tab():
)


def UI(**kwargs):

css = ''

if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'

interface = gr.Blocks(css=css)

with interface:
with gr.Tab('Finetune'):
finetune_tab()
with gr.Tab('Utilities'):
utilities_tab(enable_dreambooth_tab=False)

# Show the interface
launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)


if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
Expand All @@ -717,7 +722,11 @@ def finetune_tab():
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")

args = parser.parse_args()

UI(username=args.username, password=args.password)
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
8 changes: 2 additions & 6 deletions gui.bat
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
@echo off

set VENV_DIR=.\venv
set PYTHON=python

call %VENV_DIR%\Scripts\activate.bat

%PYTHON% kohya_gui.py
call venv\Scripts\activate.bat
python.exe kohya_gui.py %*

pause
2 changes: 1 addition & 1 deletion gui.ps1
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
.\venv\Scripts\activate
python.exe kohya_gui.py
python.exe kohya_gui.py $args
20 changes: 12 additions & 8 deletions kohya_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from lora_gui import lora_tab


def UI(username, password, inbrowser, server_port):

def UI(**kwargs):
css = ''

if os.path.exists('./style.css'):
Expand Down Expand Up @@ -47,13 +46,18 @@ def UI(username, password, inbrowser, server_port):
gradio_merge_lora_tab()

# Show the interface
kwargs = {}
if username:
kwargs["auth"] = (username, password)
launch_kwargs = {}
username = kwargs.get('username')
password = kwargs.get('password')
server_port = kwargs.get('server_port', 0)
inbrowser = kwargs.get('inbrowser', False)
if username and password:
launch_kwargs["auth"] = (username, password)
if server_port > 0:
kwargs["server_port"] = server_port
kwargs["inbrowser"] = inbrowser
interface.launch(**kwargs)
launch_kwargs["server_port"] = server_port
if inbrowser:
launch_kwargs["inbrowser"] = inbrowser
interface.launch(**launch_kwargs)

if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
Expand Down
6 changes: 4 additions & 2 deletions library/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,9 +568,11 @@ def gradio_advanced_training():
label="Dropout caption every n epochs",
value=0
)
caption_dropout_rate = gr.Number(
caption_dropout_rate = gr.Slider(
label="Rate of caption dropout",
value=0
value=0,
minimum=0,
maximum=1
)
with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False)
Expand Down
23 changes: 16 additions & 7 deletions library/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def utilities_tab(
)


def UI(username, password):
def UI(**kwargs):
css = ''

if os.path.exists('./style.css'):
Expand All @@ -50,11 +50,16 @@ def UI(username, password):
utilities_tab()

# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()

launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)


if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
Expand All @@ -65,7 +70,11 @@ def UI(username, password):
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")

args = parser.parse_args()

UI(username=args.username, password=args.password)
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
Loading

0 comments on commit e5f8ba5

Please sign in to comment.