# main.py
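# Entry point: parses hyperparameters, prints them as a table, trains the
# selected model (optionally under Weights & Biases), and saves the checkpoint.
#
# Example invocation (illustrative only -- the actual flag names depend on how
# utils.parser.parameter_parser defines its argparse arguments; the flags below
# simply mirror the parameter names used in this script):
#
#   python main.py --dataset <dataset> --model <model> --hidden_dim1 64 \
#       --hidden_dim2 8 --optimizer adam --lr 0.005 --epochs 200 \
#       --num_heads 8 --heads 8-1 --wandb_project_name None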
import os
import sys
import warnings
from pathlib import Path

from prettytable import PrettyTable

# Make the project root importable before pulling in the local utils package.
BASE_DIR = Path(__file__).resolve(strict=True).parent.parent
sys.path.append(str(BASE_DIR) + "/")
warnings.filterwarnings("ignore")

from utils.train import TorchTrain
from utils.wandb_train import TorchTrain as WanDBTorchTrain
from utils.wandb_train import wandb_model_pipeline
from utils.parser import parameter_parser

import torch
import wandb


def KeyboardInterruption(function):
    """Run `function`, exiting cleanly if the user presses Ctrl+C."""
    try:
        function()
    except KeyboardInterrupt:
        print("")
        print("Interrupted")
        print("exiting ...")
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)


def main():
    parser = parameter_parser()
    parameters = {
        "epochs": parser.epochs,
        "optimizer": parser.optimizer,
        "dataset": parser.dataset,
        "lr": parser.lr,
        "model": parser.model,
        "hidden_dim1": parser.hidden_dim1,
        "hidden_dim2": parser.hidden_dim2,
        "num_heads": parser.num_heads,
        "heads": parser.heads,
        "wandb_project_name": parser.wandb_project_name,
    }

    # Print a summary of the hyperparameters before training starts.
    print()
    table = PrettyTable()
    table.field_names = ['Hyper Parameters and variables', 'Values used']
    table.add_rows([
        ['dataset', parameters['dataset']],
        ['model', parameters['model']],
        ['hidden_dim1', parameters['hidden_dim1']],
        ['hidden_dim2', parameters['hidden_dim2']],
        ['optimizer', parameters['optimizer']],
        ['learning rate', parameters['lr']],
        ['epochs', parameters['epochs']],
        ['num heads', parameters['num_heads']],
        ['heads', parameters['heads'].split('-')],
        ['wandb-project-name', parameters['wandb_project_name']],
    ])
    print(table)
    print()
    # Train with a plain PyTorch loop, or through a Weights & Biases run if a
    # project name was given.
    if parameters['wandb_project_name'] == "None":
        torch_train = TorchTrain(parameters)
        model = torch_train.train()
    else:
        project_name = parameters['wandb_project_name']
        wandb.init(project=project_name)
        model = wandb_model_pipeline(project=project_name, parameters=parameters)

    # Save the trained model under ./saved_models, creating the folder if needed.
    path_to_save_models = os.path.join(os.getcwd(), "saved_models")
    if not os.path.isdir(path_to_save_models):
        os.mkdir(path_to_save_models)
    model_name = f"{parser.model}_{parser.dataset}_{parser.hidden_dim2}_{parser.num_heads}.pth"
    path_to_save_models = os.path.join(path_to_save_models, model_name)
    torch.save(model, path_to_save_models)
    print(f"Saved the model as {model_name} successfully!")


if __name__ == '__main__':
    KeyboardInterruption(function=main)
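
# A minimal sketch of reloading the saved checkpoint later (illustrative:
# torch.save() above pickles the full model object, so torch.load() needs the
# model classes from utils to be importable; newer PyTorch versions may also
# require passing weights_only=False for full-object checkpoints):
#
#   model = torch.load("saved_models/<model_name>.pth")
#   model.eval()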