# class_accelerate_launch.py (forked from bmaltais/kohya_ss)

import gradio as gr
import os
import shlex

from .class_gui_config import KohyaSSGUIConfig


class AccelerateLaunch:
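    """Gradio UI block that collects ``accelerate launch`` settings.

    Builds the dropdowns, numbers, checkboxes, and textboxes for mixed precision,
    process/machine counts, dynamo options, and GPU selection, reading default
    values from the supplied ``KohyaSSGUIConfig``.
    """
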
def __init__(
self,
config: KohyaSSGUIConfig = {},
) -> None:
self.config = config
with gr.Accordion("Resource Selection", open=True):
with gr.Row():
self.mixed_precision = gr.Dropdown(
label="Mixed precision",
choices=["no", "fp16", "bf16", "fp8"],
value=self.config.get("accelerate_launch.mixed_precision", "fp16"),
info="Whether or not to use mixed precision training.",
)
self.num_processes = gr.Number(
label="Number of processes",
value=self.config.get("accelerate_launch.num_processes", 1),
# precision=0,
step=1,
minimum=1,
info="The total number of processes to be launched in parallel.",
)
self.num_machines = gr.Number(
label="Number of machines",
value=self.config.get("accelerate_launch.num_machines", 1),
# precision=0,
step=1,
minimum=1,
info="The total number of machines used in this training.",
)
self.num_cpu_threads_per_process = gr.Slider(
minimum=1,
maximum=os.cpu_count(),
step=1,
label="Number of CPU threads per core",
value=self.config.get(
"accelerate_launch.num_cpu_threads_per_process", 2
),
info="The number of CPU threads per process.",
)
with gr.Row():
self.dynamo_backend = gr.Dropdown(
label="Dynamo backend",
choices=[
"no",
"eager",
"aot_eager",
"inductor",
"aot_ts_nvfuser",
"nvprims_nvfuser",
"cudagraphs",
"ofi",
"fx2trt",
"onnxrt",
"tensorrt",
"ipex",
"tvm",
],
value=self.config.get("accelerate_launch.dynamo_backend", "no"),
info="The backend to use for the dynamo JIT compiler.",
)
self.dynamo_mode = gr.Dropdown(
label="Dynamo mode",
choices=[
"default",
"reduce-overhead",
"max-autotune",
],
value=self.config.get("accelerate_launch.dynamo_mode", "default"),
info="Choose a mode to optimize your training with dynamo.",
)
self.dynamo_use_fullgraph = gr.Checkbox(
label="Dynamo use fullgraph",
value=self.config.get("accelerate_launch.dynamo_use_fullgraph", False),
info="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs",
)
self.dynamo_use_dynamic = gr.Checkbox(
label="Dynamo use dynamic",
value=self.config.get("accelerate_launch.dynamo_use_dynamic", False),
info="Whether to enable dynamic shape tracing.",
)
with gr.Accordion("Hardware Selection", open=True):
with gr.Row():
self.multi_gpu = gr.Checkbox(
label="Multi GPU",
value=self.config.get("accelerate_launch.multi_gpu", False),
info="Whether or not this should launch a distributed GPU training.",
)
with gr.Accordion("Distributed GPUs", open=True):
with gr.Row():
self.gpu_ids = gr.Textbox(
label="GPU IDs",
value=self.config.get("accelerate_launch.gpu_ids", ""),
placeholder="example: 0,1",
info=" What GPUs (by id) should be used for training on this machine as a comma-separated list",
)
self.main_process_port = gr.Number(
label="Main process port",
value=self.config.get("accelerate_launch.main_process_port", 0),
# precision=1,
step=1,
minimum=0,
maximum=65535,
info="The port to use to communicate with the machine of rank 0.",
)
with gr.Row():
self.extra_accelerate_launch_args = gr.Textbox(
label="Extra accelerate launch arguments",
value=self.config.get(
"accelerate_launch.extra_accelerate_launch_args", ""
),
placeholder="example: --same_network --machine_rank 4",
info="List of extra parameters to pass to accelerate launch",
)

    @staticmethod
    def run_cmd(run_cmd: list, **kwargs):
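        # Append to `run_cmd` the accelerate-launch flags that correspond to the
        # supplied keyword arguments; options that are unset, empty, or zero are skipped.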
if "dynamo_backend" in kwargs and kwargs.get("dynamo_backend"):
run_cmd.append("--dynamo_backend")
run_cmd.append(kwargs["dynamo_backend"])
if "dynamo_mode" in kwargs and kwargs.get("dynamo_mode"):
run_cmd.append("--dynamo_mode")
run_cmd.append(kwargs["dynamo_mode"])
if "dynamo_use_fullgraph" in kwargs and kwargs.get("dynamo_use_fullgraph"):
run_cmd.append("--dynamo_use_fullgraph")
if "dynamo_use_dynamic" in kwargs and kwargs.get("dynamo_use_dynamic"):
run_cmd.append("--dynamo_use_dynamic")
if "extra_accelerate_launch_args" in kwargs and kwargs["extra_accelerate_launch_args"] != "":
extra_accelerate_launch_args = kwargs["extra_accelerate_launch_args"].replace('"', "")
for arg in extra_accelerate_launch_args.split():
run_cmd.append(shlex.quote(arg))
if "gpu_ids" in kwargs and kwargs.get("gpu_ids") != "":
run_cmd.append("--gpu_ids")
run_cmd.append(shlex.quote(kwargs["gpu_ids"]))
if "main_process_port" in kwargs and kwargs.get("main_process_port", 0) > 0:
run_cmd.append("--main_process_port")
run_cmd.append(str(int(kwargs["main_process_port"])))
if "mixed_precision" in kwargs and kwargs.get("mixed_precision"):
run_cmd.append("--mixed_precision")
run_cmd.append(shlex.quote(kwargs["mixed_precision"]))
if "multi_gpu" in kwargs and kwargs.get("multi_gpu"):
run_cmd.append("--multi_gpu")
if "num_processes" in kwargs and int(kwargs.get("num_processes", 0)) > 0:
run_cmd.append("--num_processes")
run_cmd.append(str(int(kwargs["num_processes"])))
if "num_machines" in kwargs and int(kwargs.get("num_machines", 0)) > 0:
run_cmd.append("--num_machines")
run_cmd.append(str(int(kwargs["num_machines"])))
if (
"num_cpu_threads_per_process" in kwargs
and int(kwargs.get("num_cpu_threads_per_process", 0)) > 0
):
run_cmd.append("--num_cpu_threads_per_process")
run_cmd.append(str(int(kwargs["num_cpu_threads_per_process"])))
return run_cmd
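
# Usage sketch (not part of the original file): illustrative values showing how
# run_cmd() turns GUI settings into `accelerate launch` arguments.
#
#     cmd = AccelerateLaunch.run_cmd(
#         ["accelerate", "launch"],
#         mixed_precision="fp16",
#         num_processes=2,
#         multi_gpu=True,
#         gpu_ids="0,1",
#     )
#     print(" ".join(cmd))
#     # accelerate launch --gpu_ids 0,1 --mixed_precision fp16 --multi_gpu --num_processes 2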