Commit 40be63d (parent 52ccf91): 1 changed file with 277 additions and 0 deletions.
@@ -0,0 +1,277 @@
import os
import sys

sys.path.append('.')

import copy
import logging
from pprint import pprint, pformat

from config import CONFIG

logging.basicConfig(format="%(levelname)-8s:%(filename)s.%(funcName)20s >> %(message)s")
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

from ..debug import memory_consumed
from ..utilz import ListTable, Averager, tqdm, init_hidden
from ..utilz import are_weights_same

from .trainer import EpochAverager, FLAGS
from .trainer import Trainer, Tester

import torch
from torch import optim, nn
from collections import namedtuple

from nltk.corpus import stopwords
class Trainer(Trainer):
    def __init__(self, name,
                 config,
                 model,
                 feed,
                 optimizer,
                 loss_function,
                 directory,
                 teacher_forcing_ratio=0.5,
                 epochs=1000,
                 checkpoint=1,
                 do_every_checkpoint=None,
                 *args,
                 **kwargs):

        self.name = name
        self.config = config
        self.ROOT_DIR = directory

        self.log = logging.getLogger('{}.{}.{}'.format(__name__, self.__class__.__name__, self.name))

        self.model = model
        self.feed = feed

        self.teacher_forcing_ratio = teacher_forcing_ratio

        self.epochs = epochs
        self.checkpoint = min(checkpoint, epochs)

        self.do_every_checkpoint = (do_every_checkpoint if do_every_checkpoint is not None
                                    else lambda x: FLAGS.CONTINUE_TRAINING)

        self.loss_function = loss_function if loss_function else nn.NLLLoss()
        self.optimizer = optimizer if optimizer else optim.SGD(self.model.parameters(),
                                                               lr=0.01, momentum=0.1)

        self.__build_stats()
        self.best_model = (0, self.model.state_dict())

        if self.config.CONFIG.cuda:
            self.model.cuda()

    def train(self):
        for epoch in range(self.epochs):
            self.log.critical('memory consumed : {}'.format(memory_consumed()))

            if epoch % max(1, (self.checkpoint - 1)) == 0:
                if self.do_every_checkpoint(epoch) == FLAGS.STOP_TRAINING:
                    self.log.info('loss trend suggests to stop training')
                    return

            self.model.train()
            for j in tqdm(range(self.feed.num_batch), desc='Trainer.{}'.format(self.name)):
                log.debug('{}th batch'.format(j))

                input_ = self.feed.next_batch()
                idxs, inputs, targets = input_
                sequence = inputs[0].transpose(0, 1)
                _, batch_size = sequence.size()

                state = self.model.initial_hidden(batch_size)
                self.optimizer.zero_grad()  # was missing: clear gradients left over from the previous batch
                loss = 0
                for ti in range(sequence.size(0) - 1):
                    # unpack (output, state) before computing the loss; the original
                    # passed the whole tuple into loss_function
                    output, state = self.model(sequence[ti], state)
                    loss += self.loss_function(ti, output, input_)
                    output = output.max(1)[1]

                loss.backward()
                self.train_loss.cache(loss.data.item())
                #nn.utils.clip_grad_norm(self.encoder_model.parameters(), Config.max_grad_norm)
                #nn.utils.clip_grad_norm(self.decoder_model.parameters(), Config.max_grad_norm)
                self.optimizer.step()

            self.log.info('-- {} -- loss: {}\n'.format(epoch, self.train_loss.epoch_cache))
            self.train_loss.clear_cache()

            for m in self.metrics:
                m.write_to_file()

        return True

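# Illustrative usage sketch (not part of the commit): wiring a Trainer together.
# `dataset_feed` and `seq_model` are hypothetical stand-ins for this repo's feed
# and model objects; only the keyword names and fallbacks are from the code above.
#
#     trainer = Trainer('lm', CONFIG, seq_model,      # model must expose initial_hidden(batch_size)
#                       feed=dataset_feed,
#                       optimizer=None,                # falls back to optim.SGD(lr=0.01, momentum=0.1)
#                       loss_function=None,            # falls back to nn.NLLLoss()
#                       directory='runs/lm')
#     trainer.train()
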
class Tester(Tester):
    def __init__(self, name,
                 config,
                 model,
                 feed,
                 loss_function,
                 accuracy_function,
                 directory,
                 f1score_function=None,
                 best_model=None,
                 predictor=None,
                 save_model_weights=True):

        self.name = name
        self.config = config
        self.ROOT_DIR = directory

        self.log = logging.getLogger('{}.{}.{}'.format(__name__,
                                                       self.__class__.__name__,
                                                       self.name))

        self.model = model
        self.feed = feed
        self.predictor = predictor

        self.accuracy_function = accuracy_function if accuracy_function else self._default_accuracy_function
        self.loss_function = loss_function if loss_function else nn.NLLLoss()
        self.f1score_function = f1score_function

        self.__build_stats()

        self.save_model_weights = save_model_weights
        self.best_model = (0.000001, self.model.cpu().state_dict())
        try:
            f = '{}/{}_best_model_accuracy.txt'.format(self.ROOT_DIR, self.name)
            if os.path.isfile(f):
                self.best_model = (float(open(f).read().strip()), self.model.cpu().state_dict())
                self.log.info('loaded last best accuracy: {}'.format(self.best_model[0]))
        except Exception:
            log.exception('no last best model')

        self.best_model_criteria = self.accuracy
        self.save_best_model()

        if self.config.CONFIG.cuda:
            self.model.cuda()

    def do_every_checkpoint(self, epoch, early_stopping=True):
        self.model.eval()
        for j in tqdm(range(self.feed.num_batch)):
            input_ = self.feed.next_batch()
            idxs, inputs, targets = input_
            sequence = inputs[0].transpose(0, 1)
            _, batch_size = sequence.size()

            state = self.model.initial_hidden(batch_size)  # was missing: state must be initialized per batch
            outputs = []
            loss, accuracy = 0, 0
            for ti in range(sequence.size(0) - 1):
                output, state = self.model(sequence[ti], state)
                loss += self.loss_function(ti, output, input_)
                accuracy += self.accuracy_function(ti, output, input_)  # was `decoder_output`, which is undefined here
                outputs.append(output)

            self.test_loss.cache(loss.item())
            ti = max(ti, 1)  # guard against dividing by zero on single-step sequences
            self.accuracy.cache(accuracy.item() / ti)
            print('====', self.test_loss, self.accuracy)  # was `self.test_accuracy`, which is never defined

        self.log.info('= {} = loss: {}'.format(epoch, self.test_loss.epoch_cache))
        self.log.info('- {} - accuracy: {}'.format(epoch, self.accuracy.epoch_cache))

        last_acc = self.best_model[0]  # keep the previous best for the improvement check below
        if self.best_model[0] < self.accuracy.epoch_cache.avg:
            self.log.info('beat best model...')
            # this Tester tracks a single self.model; the original saved
            # encoder/decoder state_dicts that this class never defines
            self.best_model = (self.accuracy.epoch_cache.avg,
                               self.model.state_dict())
            self.save_best_model()

            if self.config.CONFIG.cuda:
                self.model.cuda()

        if self.predictor and self.best_model[0] > 0.75:
            log.info('accuracy is greater than 0.75...')
            if ((self.best_model[0] >= self.config.CONFIG.ACCURACY_THRESHOLD
                 and (5 * (self.best_model[0] - last_acc) > self.config.CONFIG.ACCURACY_IMPROVEMENT_THRESHOLD))
                    or (self.best_model[0] - last_acc) > self.config.CONFIG.ACCURACY_IMPROVEMENT_THRESHOLD):
                self.predictor.run_prediction(self.accuracy.epoch_cache.avg)

        self.test_loss.clear_cache()
        self.accuracy.clear_cache()

        for m in self.metrics:
            m.write_to_file()

        if early_stopping:
            return self.loss_trend()

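# Illustrative sketch (not part of the commit): how a Tester's checkpoint hook can
# drive the Trainer's early stopping. `train_feed`, `test_feed`, and `predictor`
# are hypothetical instances; only the do_every_checkpoint signature and the
# FLAGS.STOP_TRAINING protocol come from this file.
#
#     tester = Tester('lm', CONFIG, model, feed=test_feed,
#                     loss_function=None, accuracy_function=None,
#                     directory='runs/lm', predictor=predictor)
#     trainer = Trainer('lm', CONFIG, model, feed=train_feed,
#                       optimizer=None, loss_function=None, directory='runs/lm',
#                       do_every_checkpoint=tester.do_every_checkpoint)
#     trainer.train()  # stops once do_every_checkpoint returns FLAGS.STOP_TRAINING
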
class Predictor(object):
    def __init__(self, name, model,
                 feed,
                 repr_function,
                 directory,
                 *args, **kwargs):
        self.name = name
        self.ROOT_DIR = directory

        self.model = model
        self.repr_function = repr_function

        self.log = logging.getLogger('{}.{}.{}'.format(__name__, self.__class__.__name__, self.name))

        self.feed = feed

    def predict(self, batch_index=0, max_decoder_len=10):
        # NOTE: this method assumes a seq2seq setup; `encoder_model`, `decoder_model`,
        # and `process_output` are expected to be supplied by a subclass -- only
        # `self.model` is set in __init__.
        log.debug('batch_index: {}'.format(batch_index))
        self.model.eval()
        decoder_outputs = []
        # use the requested batch; the original fetched it and then overwrote it
        # with next_batch()
        input_ = self.feed.nth_batch(batch_index)
        idxs, inputs, targets = input_
        encoder_output = self.encoder_model(input_)

        results = ListTable()
        decoder_input = self.decoder_model.initial_input(input_, encoder_output)
        for ti in range(max_decoder_len):
            decoder_output = self.decoder_model(input_, encoder_output, decoder_input)
            decoder_output, decoder_input = self.process_output(ti, decoder_output, input_)
            decoder_outputs.append(decoder_output)

        decoder_outputs = torch.stack(decoder_outputs)
        result = self.repr_function(decoder_outputs, input_)
        results.extend(result)
        return decoder_outputs, results

    def run_prediction(self, accuracy):
        self.log.info('running prediction at accuracy: {:0.4f}'.format(accuracy))
        results = ListTable()
        for ri in tqdm(
                range(self.feed.num_batch),
                desc='running prediction at accuracy: {:0.4f}'.format(accuracy)):
            output, _results = self.predict(ri)
            results.extend(_results)

        with open('{}/results/{}_{:0.4f}.csv'.format(self.ROOT_DIR, self.name, accuracy), 'w') as dump:
            dump.write(repr(results))
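
# Illustrative sketch (not part of the commit): dumping predictions once accuracy
# clears the thresholds checked in Tester.do_every_checkpoint. `MySeq2SeqPredictor`
# is a hypothetical subclass supplying the encoder/decoder pieces predict() expects.
#
#     predictor = MySeq2SeqPredictor('lm', model, feed=test_feed,
#                                    repr_function=pretty_rows, directory='runs/lm')
#     outputs, rows = predictor.predict(batch_index=0)
#     predictor.run_prediction(accuracy=0.80)  # writes runs/lm/results/lm_0.8000.csv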