Skip to content

Commit

Permalink
code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Nerogar committed Oct 13, 2023
1 parent 32b1cea commit bd8fe8a
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 60 deletions.
4 changes: 1 addition & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
/models*
/training_concepts
/training_samples
/training_user_settings
/external
debug.py
train.bat
*.pyc

*.bak

training_user_settings/
71 changes: 38 additions & 33 deletions modules/ui/OptimizerParamsWindow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import json
import math
import os
from tkinter import TclError

import customtkinter as ctk
from modules.util.ui import components

from modules.util.args.TrainArgs import TrainArgs
from modules.util.enum.Optimizer import Optimizer
import math
import json
import os
from modules.util.ui import components


class UserPreferenceUtility:
def __init__(self, file_path="training_user_settings/optimizer_prefs.json"):
Expand All @@ -13,7 +17,6 @@ def __init__(self, file_path="training_user_settings/optimizer_prefs.json"):
if not os.path.exists(directory):
os.mkdir(directory)


def load_preferences(self, optimizer_name):
if os.path.exists(self.file_path):
with open(self.file_path, 'r') as f:
Expand All @@ -36,6 +39,7 @@ def save_preference(self, optimizer_name, key, value):
with open(self.file_path, 'w') as f:
json.dump(prefs, f, indent=4)


class OptimizerParamsWindow(ctk.CTkToplevel):
def __init__(self, parent, ui_state, *args, **kwargs):
self.pref_util = UserPreferenceUtility()
Expand Down Expand Up @@ -202,7 +206,7 @@ def __init__(self, parent, ui_state, *args, **kwargs):
"optimizer_eps": 1e-8,
"optimizer_weight_decay": 0,
"optimizer_decouple": True,
"optimizer_use_bias_correction": False,
"optimizer_use_bias_correction": False,
"optimizer_safeguard_warmup": False,
"optimizer_d0": 1e-6,
"optimizer_d_coef": 1.0,
Expand Down Expand Up @@ -288,15 +292,15 @@ def __init__(self, parent, ui_state, *args, **kwargs):
"optimizer_foreach": False,
"optimizer_maximize": False,
"optimizer_differentiable": False
},
},
"LION": {
"optimizer_beta1": 0.9,
"optimizer_beta2": 0.99,
"optimizer_weight_decay": 0.0,
"optimizer_use_triton": False
},
}

self.grid_rowconfigure(0, weight=1)
self.grid_rowconfigure(1, weight=0)
self.grid_columnconfigure(0, weight=1)
Expand All @@ -309,17 +313,18 @@ def __init__(self, parent, ui_state, *args, **kwargs):
self.frame.grid_columnconfigure(2, minsize=50)
self.frame.grid_columnconfigure(3, weight=0)
self.frame.grid_columnconfigure(4, weight=1)

components.button(self, 1, 0, "ok", self.__ok)
self.button = None
self.main_frame(self.frame)
self.main_frame(self.frame)

def __ok(self):
self.destroy()

def create_dynamic_ui(self, selected_optimizer, master, components, ui_state, defaults=False):

#Lookup for the title and tooltip for a key
# Lookup for the title and tooltip for a key
# @formatter:off
KEY_DETAIL_MAP = {
'optimizer_adam_w_mode': {'title': 'Adam W Mode', 'tooltip': 'Whether to use weight decay correction for Adam optimizer.', 'type': 'bool'},
'optimizer_alpha': {'title': 'Alpha', 'tooltip': 'Smoothing parameter for RMSprop and others.', 'type': 'float'},
Expand Down Expand Up @@ -364,12 +369,13 @@ def create_dynamic_ui(self, selected_optimizer, master, components, ui_state, de
'optimizer_warmup_init': {'title': 'Warmup Initialization', 'tooltip': 'Whether to warm-up the optimizer initialization.', 'type': 'bool'},
'optimizer_weight_decay': {'title': 'Weight Decay', 'tooltip': 'Regularization to prevent overfitting.', 'type': 'float'},
}

# @formatter:on

if not self.winfo_exists(): # check if this window isn't open
return

optimizer_keys = list(self.OPTIMIZER_KEY_MAP[selected_optimizer].keys()) # Extract the keys for the selected optimizer
# Extract the keys for the selected optimizer
optimizer_keys = list(self.OPTIMIZER_KEY_MAP[selected_optimizer].keys())
for idx, key in enumerate(optimizer_keys):
arg_info = KEY_DETAIL_MAP[key]

Expand All @@ -378,7 +384,7 @@ def create_dynamic_ui(self, selected_optimizer, master, components, ui_state, de
type = arg_info['type']

row = math.floor(idx / 2) + 1
col = 3 * (idx % 2)
col = 3 * (idx % 2)

components.label(master, row, col, title, tooltip=tooltip)
override_value = None
Expand All @@ -390,32 +396,33 @@ def create_dynamic_ui(self, selected_optimizer, master, components, ui_state, de
override_value = self.OPTIMIZER_KEY_MAP[selected_optimizer][key]

if type != 'bool':
entry_widget = components.entry(master, row, col+1, ui_state, key, override_value=override_value)
entry_widget = components.entry(master, row, col + 1, ui_state, key, override_value=override_value)
entry_widget.bind("<FocusOut>", lambda event, opt=selected_optimizer, k=key: self.update_user_pref(opt, k, ui_state.vars[k].get()))
else:
switch_widget = components.switch(master, row, col+1, ui_state, key, override_value=override_value)
switch_widget = components.switch(master, row, col + 1, ui_state, key, override_value=override_value)
switch_widget.configure(command=lambda opt=selected_optimizer, k=key: self.update_user_pref(opt, k, ui_state.vars[k].get()))

def update_user_pref(self, optimizer, key, value):
self.pref_util.save_preference(optimizer, key, value)



def main_frame(self, master):

# Optimizer
components.label(master, 0, 0, "Optimizer", tooltip="The type of optimizer")
components.label(master, 0, 0, "Optimizer",
tooltip="The type of optimizer")
components.options(master, 0, 1, [str(x) for x in list(Optimizer)], self.ui_state, "optimizer")

# Defaults Button
components.label(master, 0, 0, "Optimizer Defaults", tooltip="Load default settings for the selected optimizer")
components.button(self.frame, 0, 4, "Load Defaults", self.load_defaults, tooltip="Load default settings for the selected optimizer")
components.label(master, 0, 0, "Optimizer Defaults",
tooltip="Load default settings for the selected optimizer")
components.button(self.frame, 0, 4, "Load Defaults", self.load_defaults,
tooltip="Load default settings for the selected optimizer")

selected_optimizer = self.ui_state.vars['optimizer'].get()

self.ui_state.vars['optimizer'].trace_add('write', self.on_optimizer_change)
self.create_dynamic_ui(selected_optimizer, master, components, self.ui_state)

def on_optimizer_change(self, *args):
selected_optimizer = self.ui_state.vars['optimizer'].get()
user_prefs = self.pref_util.load_preferences(selected_optimizer)
Expand All @@ -428,28 +435,26 @@ def on_optimizer_change(self, *args):

if value_to_set is None:
value_to_set = "None"

self.ui_state.vars[key].set(value_to_set)

if not self.winfo_exists(): # check if this window isn't open
return
self.clear_dynamic_ui(self.frame)
self.create_dynamic_ui(selected_optimizer, self.frame, components, self.ui_state)

def load_defaults(self):
if not self.winfo_exists(): # check if this window isn't open
return
selected_optimizer = self.ui_state.vars['optimizer'].get()
self.clear_dynamic_ui(self.frame)
self.create_dynamic_ui(selected_optimizer, self.frame, components, self.ui_state, defaults=True)


def clear_dynamic_ui(self, master):
try:
for widget in master.winfo_children():
grid_info = widget.grid_info()
if int(grid_info["row"]) >= 1:
widget.destroy()
except _tkinter.TclError as e:
except TclError as e:
pass

Loading

0 comments on commit bd8fe8a

Please sign in to comment.