main.py
#!/usr/bin/env python
# coding: utf-8

# Parts of the code were adapted from:
# https://github.com/pytorch/examples/tree/master/word_language_model

import argparse
import json
import logging
import os
import time

from data_utils import Corpus, create_parameter_grid
from train import LanguageModelTrainer

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description='MatsuLM')
parser.add_argument('--sacred_mongo', type=str, default='',
                    help='MongoDB URL used to save the Sacred experiment parameters and results')
parser.add_argument('--data', type=str, default='',
                    help='Path to the training, validation, and test data (e.g. "data/example")')
args = parser.parse_args()

def hyperparameter_tune_language_model(data_path, sacred_mongo=''):
    parameters = {
        'model': {
            'num_layers': 1,
            'bidirectional': True,
            'embed_size': 100,
            'hidden_size': 256,
            'init_scale': 0,
            'init_bias': 0,
            'dropout': 0,
        },
        'log_interval': 200,
        'cuda': True,
        'seed': 313,
        'weight_decay': 0,
        'optimizer': ["sgd"],
        'seq_length': 35,
        'batch_size': 20,
        'num_epochs': 20,
        'lr': 1,
        'lr_decay_start': 20,
        'lr_decay': 0.8,
        'clip_norm': 5,
        'save_model': True,
        'model_path': ['lstm_model.pt'],
    }
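
    # NOTE (assumption, not confirmed by this file): create_parameter_grid presumably
    # expands the list-valued entries above (e.g. 'optimizer' and 'model_path') into
    # one flat parameter dict per combination, similar in spirit to sklearn's
    # ParameterGrid. Under that assumption, sweeping several learning rates would be:
    #   parameters['lr'] = [0.5, 1.0, 2.0]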

    # Load dataset
    corpus = Corpus()
    train_data = corpus.get_data(os.path.join(data_path, 'train.txt'), parameters['batch_size'])
    valid_data = corpus.get_data(os.path.join(data_path, 'valid.txt'), parameters['batch_size'])
    test_data = corpus.get_data(os.path.join(data_path, 'test.txt'), parameters['batch_size'])
    parameters['model']['vocab_size'] = len(corpus.dictionary)

    all_results = []
    all_parameters = create_parameter_grid(parameters)
    for index, params in enumerate(all_parameters):
        LOGGER.info("\nTuning %s/%s", index + 1, len(all_parameters))
        LOGGER.info("Parameters: %s", json.dumps(params, indent=4, default=str))
        params['model_path'] = f'{index}_{params["model_path"]}'

        start = time.time()
        lm_trainer = LanguageModelTrainer(train_data, valid_data, test_data, params)
        if sacred_mongo:
            from sacred_experiment import start_sacred_experiment
            start_sacred_experiment(lm_trainer, params, sacred_mongo)
        else:
            lm_trainer.train_model()

        LOGGER.info("Results: %s", json.dumps(lm_trainer.get_results(), indent=4, default=str))
        LOGGER.info("Training took: %ss", time.time() - start)
        all_results.append({"parameters": params, "results": lm_trainer.get_results()})

    return all_results
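
# NOTE (illustrative sketch, not part of the original script): if get_results()
# exposed a validation perplexity under a hypothetical key such as 'valid_ppl',
# the best configuration from a sweep could be picked like this:
#   best = min(all_results, key=lambda run: run['results']['valid_ppl'])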

if args.data:
    all_results = hyperparameter_tune_language_model(args.data, sacred_mongo=args.sacred_mongo)
else:
    # all_results = hyperparameter_tune_language_model('data/wikitext-2/', sacred_mongo=args.sacred_mongo)
    all_results = hyperparameter_tune_language_model('data/penn/', sacred_mongo=args.sacred_mongo)

print(all_results)
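
# Example usage (sketch; the MongoDB URL below is a placeholder, and the data
# directory is expected to contain train.txt, valid.txt, and test.txt):
#   python main.py --data data/example
#   python main.py --data data/example --sacred_mongo mongodb://localhost:27017/sacred
# When --data is omitted, the script falls back to 'data/penn/' (see above).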