@@ -30,8 +30,7 @@ def __init__(
30
30
headless : bool = False ,
31
31
finetuning : bool = False ,
32
32
training_type : str = "" ,
33
- default_vae_dir : str = "" ,
34
- default_output_dir : str = "" ,
33
+ config :dict = {},
35
34
) -> None :
36
35
"""
37
36
Initializes the AdvancedTraining class with given settings.
@@ -46,16 +45,12 @@ def __init__(
46
45
self .headless = headless
47
46
self .finetuning = finetuning
48
47
self .training_type = training_type
48
+ self .config = config
49
49
50
50
# Determine the current directories for VAE and output, falling back to defaults if not specified.
51
- current_vae_dir = (
52
- default_vae_dir if default_vae_dir else os .path .join (scriptdir , "vae" )
53
- )
54
- current_state_dir = (
55
- default_output_dir
56
- if default_output_dir
57
- else os .path .join (scriptdir , "outputs" )
58
- )
51
+ self .current_vae_dir = self .config .get ("vae_dir" , "./models/vae" )
52
+ self .current_state_dir = self .config .get ("state_dir" , "./outputs" )
53
+ self .current_log_tracker_config_dir = self .config .get ("log_tracker_config_dir" , "./logs" )
59
54
60
55
# Define the behavior for changing noise offset type.
61
56
def noise_offset_type_change (
@@ -95,21 +90,20 @@ def noise_offset_type_change(
95
90
self .prior_loss_weight = gr .Number (label = "Prior loss weight" , value = 1.0 )
96
91
97
92
def list_vae_files (path ):
98
- nonlocal current_vae_dir
99
- current_vae_dir = path
93
+ self .current_vae_dir = path if not path == "" else "."
100
94
return list (list_files (path , exts = [".ckpt" , ".safetensors" ], all = True ))
101
95
102
96
self .vae = gr .Dropdown (
103
- label = "VAE (Optional. path to checkpoint of vae to replace for training)" ,
97
+ label = "VAE (Optional: Path to checkpoint of vae for training)" ,
104
98
interactive = True ,
105
- choices = ["" ] + list_vae_files (current_vae_dir ),
99
+ choices = ["" ] + list_vae_files (self . current_vae_dir ),
106
100
value = "" ,
107
101
allow_custom_value = True ,
108
102
)
109
103
create_refresh_button (
110
104
self .vae ,
111
105
lambda : None ,
112
- lambda : {"choices" : list_vae_files (current_vae_dir )},
106
+ lambda : {"choices" : [ "" ] + list_vae_files (self . current_vae_dir )},
113
107
"open_folder_small" ,
114
108
)
115
109
self .vae_button = gr .Button (
@@ -222,11 +216,6 @@ def full_options_update(full_fp16, full_bf16):
222
216
label = "Memory efficient attention" , value = False
223
217
)
224
218
with gr .Row ():
225
- # This use_8bit_adam element should be removed in a future release as it is no longer used
226
- # use_8bit_adam = gr.Checkbox(
227
- # label='Use 8bit adam', value=False, visible=False
228
- # )
229
- # self.xformers = gr.Checkbox(label='Use xformers', value=True, info='Use xformers for CrossAttention')
230
219
self .xformers = gr .Dropdown (
231
220
label = "CrossAttention" ,
232
221
choices = ["none" , "sdpa" , "xformers" ],
@@ -348,21 +337,20 @@ def full_options_update(full_fp16, full_bf16):
348
337
self .save_state = gr .Checkbox (label = "Save training state" , value = False )
349
338
350
339
def list_state_dirs (path ):
351
- nonlocal current_state_dir
352
- current_state_dir = path
340
+ self .current_state_dir = path if not path == "" else "."
353
341
return list (list_dirs (path ))
354
342
355
343
self .resume = gr .Dropdown (
356
344
label = 'Resume from saved training state (path to "last-state" state folder)' ,
357
- choices = ["" ] + list_state_dirs (current_state_dir ),
345
+ choices = ["" ] + list_state_dirs (self . current_state_dir ),
358
346
value = "" ,
359
347
interactive = True ,
360
348
allow_custom_value = True ,
361
349
)
362
350
create_refresh_button (
363
351
self .resume ,
364
352
lambda : None ,
365
- lambda : {"choices" : list_state_dirs (current_state_dir )},
353
+ lambda : {"choices" : [ "" ] + list_state_dirs (self . current_state_dir )},
366
354
"open_folder_small" ,
367
355
)
368
356
self .resume_button = gr .Button (
@@ -418,6 +406,10 @@ def list_state_dirs(path):
418
406
info = "The name of the specific wandb session" ,
419
407
)
420
408
with gr .Group (), gr .Row ():
409
+ def list_log_tracker_config_files (path ):
410
+ self .current_log_tracker_config_dir = path if not path == "" else "."
411
+ return list (list_files (path , exts = [".json" ], all = True ))
412
+
421
413
self .log_tracker_name = gr .Textbox (
422
414
label = "Log tracker name" ,
423
415
value = "" ,
@@ -426,7 +418,7 @@ def list_state_dirs(path):
426
418
)
427
419
self .log_tracker_config = gr .Dropdown (
428
420
label = "Log tracker config" ,
429
- choices = ["" ] + list_state_dirs ( current_state_dir ),
421
+ choices = ["" ] + list_log_tracker_config_files ( self . current_log_tracker_config_dir ),
430
422
value = "" ,
431
423
info = "Path to tracker config file to use for logging" ,
432
424
interactive = True ,
@@ -435,7 +427,7 @@ def list_state_dirs(path):
435
427
create_refresh_button (
436
428
self .log_tracker_config ,
437
429
lambda : None ,
438
- lambda : {"choices" : list_state_dirs ( current_state_dir )},
430
+ lambda : {"choices" : [ "" ] + list_log_tracker_config_files ( self . current_log_tracker_config_dir )},
439
431
"open_folder_small" ,
440
432
)
441
433
self .log_tracker_config_button = gr .Button (
@@ -447,7 +439,7 @@ def list_state_dirs(path):
447
439
show_progress = False ,
448
440
)
449
441
self .log_tracker_config .change (
450
- fn = lambda path : gr .Dropdown (choices = ["" ] + list_state_dirs (path )),
442
+ fn = lambda path : gr .Dropdown (choices = ["" ] + list_log_tracker_config_files (path )),
451
443
inputs = self .log_tracker_config ,
452
444
outputs = self .log_tracker_config ,
453
445
show_progress = False ,
0 commit comments