forked from bmaltais/kohya_ss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclass_tensorboard.py
139 lines (117 loc) · 4.69 KB
/
class_tensorboard.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import gradio as gr
import subprocess
import time
import webbrowser
try:
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import tensorflow # Attempt to import tensorflow to check if it is installed
visibility = True
except ImportError:
visibility = False
from easygui import msgbox
from threading import Thread, Event
from .custom_logging import setup_logging
from .common_gui import setup_environment
class TensorboardManager:
    """Gradio widget group that starts and stops a TensorBoard subprocess.

    Creating an instance immediately builds the Start/Stop/Open buttons
    (``gradio_interface`` is called at the end of ``__init__``).

    The module-level ``visibility`` flag (True when ``tensorflow`` could be
    imported) decides whether the Start/Stop buttons or the plain "Open" link
    button are shown.
    """

    DEFAULT_TENSORBOARD_PORT = 6006

    # Seconds to wait for TensorBoard to exit after terminate() before kill().
    _STOP_TIMEOUT = 5

    def __init__(self, logging_dir, headless: bool = False, wait_time=5):
        """Store the configuration and build the buttons.

        Args:
            logging_dir: Gradio component whose value is passed to
                ``start_tensorboard`` as the log directory (it is used as the
                ``inputs`` of the Start button).
            headless: When True, no browser tab and no message box are opened,
                and both Start and Stop buttons stay visible.
            wait_time: Seconds to wait after launching TensorBoard before the
                browser tab is opened.
        """
        self.logging_dir = logging_dir
        self.headless = headless
        self.wait_time = wait_time
        self.tensorboard_proc = None
        # A str when taken from the environment, an int otherwise; it is only
        # ever interpolated into strings, so both work.
        self.tensorboard_port = os.environ.get(
            "TENSORBOARD_PORT", self.DEFAULT_TENSORBOARD_PORT
        )
        self.log = setup_logging()
        self.thread = None
        self.stop_event = Event()
        self.gradio_interface()

    def get_button_states(self, started=False):
        """Return the (start_button, stop_button) visibility updates.

        Start is visible while TensorBoard is not running, Stop while it is;
        in headless mode both are visible. Both are hidden when ``visibility``
        is False.
        """
        return gr.Button(
            visible=visibility and (not started or self.headless)
        ), gr.Button(visible=visibility and (started or self.headless))

    def open_tensorboard_url(self):
        """Open the local TensorBoard page in the default web browser."""
        tensorboard_url = f"http://localhost:{self.tensorboard_port}"
        self.log.info(f"Opening TensorBoard URL in browser: {tensorboard_url}")
        webbrowser.open(tensorboard_url)

    @staticmethod
    def _has_logs(logging_dir):
        """Return True only if *logging_dir* is an existing, non-empty directory."""
        # Guard first: os.listdir(None) would silently list the current
        # working directory instead of failing.
        if not logging_dir:
            return False
        try:
            return bool(os.listdir(logging_dir))
        except (OSError, TypeError, ValueError):
            # Missing path, path is a file, no permission, or malformed value.
            return False

    def _open_browser_after_delay(self):
        """Thread target: wait ``wait_time`` seconds, then open the browser.

        ``Event.wait`` returns True as soon as ``stop_event`` is set, so a
        stop request issued during the delay cancels the browser launch.
        """
        if not self.stop_event.wait(self.wait_time):
            self.open_tensorboard_url()

    def start_tensorboard(self, logging_dir=None):
        """Launch TensorBoard on *logging_dir* and return the button states.

        A TensorBoard process that is already running is stopped first. When
        the directory is missing, empty or not a directory, an error is logged
        (and shown in a message box unless headless) and nothing is started.

        Returns:
            The tuple produced by ``get_button_states``; ``started=True`` only
            if the subprocess was spawned.
        """
        if self.tensorboard_proc is not None:
            self.log.info(
                "Tensorboard is already running. Terminating existing process before starting new one..."
            )
            self.stop_tensorboard()

        if not self._has_logs(logging_dir):
            message = "Error: logging folder does not exist or does not contain logs."
            self.log.error(message)
            # A blocking GUI dialog is pointless (and may fail) without a display.
            if not self.headless:
                msgbox(msg=message)
            return self.get_button_states(started=False)

        # NOTE(review): "--host 0.0.0.0" exposes TensorBoard on every network
        # interface, not just localhost. Kept as-is to preserve behaviour.
        run_cmd = [
            "tensorboard",
            "--logdir",
            logging_dir,
            "--host",
            "0.0.0.0",
            "--port",
            str(self.tensorboard_port),
        ]

        self.log.info(run_cmd)
        self.log.info("Starting TensorBoard on port {}".format(self.tensorboard_port))
        try:
            # setup_environment() is defined elsewhere; presumably it returns
            # the environment mapping for the child process.
            env = setup_environment()
            self.tensorboard_proc = subprocess.Popen(run_cmd, env=env)
        except Exception as e:
            self.log.error("Failed to start Tensorboard: %s", e)
            return self.get_button_states(started=False)

        if not self.headless:
            # The delay runs inside the worker thread so this handler returns
            # immediately and stop_tensorboard() can cancel the pending launch.
            self.stop_event.clear()
            self.thread = Thread(target=self._open_browser_after_delay, daemon=True)
            self.thread.start()
        return self.get_button_states(started=True)

    def stop_tensorboard(self):
        """Stop the TensorBoard process and any pending browser launch.

        Returns:
            The tuple produced by ``get_button_states(started=False)``.
        """
        proc = self.tensorboard_proc
        if proc is not None:
            self.log.info("Stopping tensorboard process...")
            try:
                proc.terminate()
                try:
                    # Reap the child so it does not linger as a zombie.
                    proc.wait(timeout=self._STOP_TIMEOUT)
                except subprocess.TimeoutExpired:
                    self.log.warning(
                        "Tensorboard did not exit after terminate(); killing it."
                    )
                    proc.kill()
                    proc.wait()
                # Only forget the handle once the process is really gone, so a
                # failed stop can be retried.
                self.tensorboard_proc = None
                self.log.info("...process stopped")
            except Exception as e:
                self.log.error("Failed to stop Tensorboard: %s", e)

        if self.thread is not None:
            self.stop_event.set()
            self.thread.join()  # Wait for the thread to finish
            self.thread = None
            self.log.info("Thread terminated successfully.")

        return self.get_button_states(started=False)

    def gradio_interface(self):
        """Create the Start/Stop/Open buttons in one row and wire their events."""
        with gr.Row():
            button_start_tensorboard = gr.Button(
                value="Start tensorboard",
                elem_id="myTensorButton",
                visible=visibility,
            )
            button_stop_tensorboard = gr.Button(
                value="Stop tensorboard",
                visible=visibility and self.headless,
                elem_id="myTensorButtonStop",
            )
            # Shown only when the tensorflow import failed: a plain link to
            # the TensorBoard URL.
            button_open_tensorboard = gr.Button(
                value="Open tensorboard",
                elem_id="myTensorButton",
                visible=not visibility,
                link=f"http://localhost:{self.tensorboard_port}",
            )
            button_start_tensorboard.click(
                self.start_tensorboard,
                inputs=[self.logging_dir],
                outputs=[button_start_tensorboard, button_stop_tensorboard],
                show_progress=False,
            )
            button_stop_tensorboard.click(
                self.stop_tensorboard,
                outputs=[button_start_tensorboard, button_stop_tensorboard],
                show_progress=False,
            )