4
4
from .common_gui import (
5
5
get_folder_path ,
6
6
get_any_file_path ,
7
- scriptdir ,
8
7
list_files ,
9
8
list_dirs ,
10
9
create_refresh_button ,
@@ -30,7 +29,7 @@ def __init__(
30
29
headless : bool = False ,
31
30
finetuning : bool = False ,
32
31
training_type : str = "" ,
33
- config :dict = {},
32
+ config : dict = {},
34
33
) -> None :
35
34
"""
36
35
Initializes the AdvancedTraining class with given settings.
@@ -50,7 +49,9 @@ def __init__(
50
49
# Determine the current directories for VAE and output, falling back to defaults if not specified.
51
50
self .current_vae_dir = self .config .get ("vae_dir" , "./models/vae" )
52
51
self .current_state_dir = self .config .get ("state_dir" , "./outputs" )
53
- self .current_log_tracker_config_dir = self .config .get ("log_tracker_config_dir" , "./logs" )
52
+ self .current_log_tracker_config_dir = self .config .get (
53
+ "log_tracker_config_dir" , "./logs"
54
+ )
54
55
55
56
# Define the behavior for changing noise offset type.
56
57
def noise_offset_type_change (
@@ -406,10 +407,11 @@ def list_state_dirs(path):
406
407
info = "The name of the specific wandb session" ,
407
408
)
408
409
with gr .Group (), gr .Row ():
410
+
409
411
def list_log_tracker_config_files (path ):
410
412
self .current_log_tracker_config_dir = path if not path == "" else "."
411
413
return list (list_files (path , exts = [".json" ], all = True ))
412
-
414
+
413
415
self .log_tracker_name = gr .Textbox (
414
416
label = "Log tracker name" ,
415
417
value = "" ,
@@ -418,7 +420,8 @@ def list_log_tracker_config_files(path):
418
420
)
419
421
self .log_tracker_config = gr .Dropdown (
420
422
label = "Log tracker config" ,
421
- choices = ["" ] + list_log_tracker_config_files (self .current_log_tracker_config_dir ),
423
+ choices = ["" ]
424
+ + list_log_tracker_config_files (self .current_log_tracker_config_dir ),
422
425
value = "" ,
423
426
info = "Path to tracker config file to use for logging" ,
424
427
interactive = True ,
@@ -427,7 +430,10 @@ def list_log_tracker_config_files(path):
427
430
create_refresh_button (
428
431
self .log_tracker_config ,
429
432
lambda : None ,
430
- lambda : {"choices" : ["" ] + list_log_tracker_config_files (self .current_log_tracker_config_dir )},
433
+ lambda : {
434
+ "choices" : ["" ]
435
+ + list_log_tracker_config_files (self .current_log_tracker_config_dir )
436
+ },
431
437
"open_folder_small" ,
432
438
)
433
439
self .log_tracker_config_button = gr .Button (
@@ -439,7 +445,9 @@ def list_log_tracker_config_files(path):
439
445
show_progress = False ,
440
446
)
441
447
self .log_tracker_config .change (
442
- fn = lambda path : gr .Dropdown (choices = ["" ] + list_log_tracker_config_files (path )),
448
+ fn = lambda path : gr .Dropdown (
449
+ choices = ["" ] + list_log_tracker_config_files (path )
450
+ ),
443
451
inputs = self .log_tracker_config ,
444
452
outputs = self .log_tracker_config ,
445
453
show_progress = False ,
0 commit comments