trainer object is now separated into trainer and tester to enable more granular control
vanangamudi committed Jul 18, 2018
1 parent 8f5b6df commit 7c147e1
Showing 1 changed file with 155 additions and 73 deletions.
228 changes: 155 additions & 73 deletions trainer/trainer.py
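
Before the diff, a hedged sketch of how the split classes might be wired together after this commit: the Trainer now drives only the optimization loop, and the Tester's do_every_checkpoint is handed in as the checkpoint hook. The constructor signatures follow the diff below; the ToyFeed, the toy autoencoder loss, the accuracy lambda, and the output/ directory layout are stand-ins invented for illustration.

# hedged usage sketch -- not part of this commit; signatures follow the diff below
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import config                          # the repo's config module (exposes CONFIG)
from trainer.trainer import Trainer, Tester

class ToyFeed:
    # stand-in feed: the classes below only use num_batch, next_batch(), nth_batch(i)
    def __init__(self, data, batch_size=8):
        self.data, self.batch_size = data, batch_size
        self.num_batch = len(data) // batch_size
        self._i = 0
    def nth_batch(self, i):
        return self.data[i * self.batch_size:(i + 1) * self.batch_size]
    def next_batch(self):
        batch = self.nth_batch(self._i)
        self._i = (self._i + 1) % self.num_batch
        return batch

ROOT = 'output'                        # assumed layout, matching the paths in the diff
os.makedirs('{}/results/metrics'.format(ROOT), exist_ok=True)
os.makedirs('{}/weights'.format(ROOT), exist_ok=True)

model = nn.Linear(10, 10)              # toy autoencoder, so loss(output, input_) is well defined
loss_fn = lambda output, input_: F.mse_loss(output, input_)
acc_fn = lambda output, input_: 1.0 - F.mse_loss(output, input_).item()  # assumed signature, mirroring loss_fn

tester = Tester('demo', config, model, ToyFeed(torch.randn(80, 10)),
                loss_function=loss_fn, accuracy_function=acc_fn, directory=ROOT)

trainer = Trainer('demo', config, model, ToyFeed(torch.randn(800, 10)),
                  optimizer=optim.SGD(model.parameters(), lr=0.01),
                  loss_function=loss_fn, accuracy_function=acc_fn, directory=ROOT,
                  epochs=10, checkpoint=5,
                  do_every_checkpoint=tester.do_every_checkpoint)  # the new hook

trainer.train()
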
@@ -1,3 +1,4 @@
+import os
import logging
import copy
from config import CONFIG
@@ -20,7 +21,6 @@

from nltk.corpus import stopwords

-Feeder = namedtuple('Feeder', ['train', 'test'])
class FLAGS:
    CONTINUE_TRAINING = 0
    STOP_TRAINING = 1
@@ -40,99 +40,156 @@ def clear_cache(self):

class Trainer(object):
    def __init__(self, name,
                 config,
                 model,
-                 feeder,
+                 feed,
                 optimizer,
                 loss_function,
                 accuracy_function,
                 directory,
                 f1score_function=None,
-                 epochs=10000,
+                 epochs=1000,
                 checkpoint=1,
-                 *args, **kwargs):
+                 do_every_checkpoint=None
+                 ):

        self.name = name
        self.config = config
        self.ROOT_DIR = directory
-        assert model != None
-        self.model = model
-        self.__build_feeder(feeder, *args, **kwargs)

+        self.log = logging.getLogger('{}.{}'.format(__name__, self.name))

+        self.model = model
+        self.feed = feed

        self.epochs = epochs
        self.checkpoint = checkpoint

        self.accuracy_function = accuracy_function if accuracy_function else self._default_accuracy_function
+        self.do_every_checkpoint = do_every_checkpoint if do_every_checkpoint is not None else lambda epoch: FLAGS.CONTINUE_TRAINING

        self.loss_function = loss_function if loss_function else nn.NLLLoss()
        self.f1score_function = f1score_function

        self.optimizer = optimizer if optimizer else optim.SGD(self.model.parameters(),
                                                               lr=0.01, momentum=0.1)

        self.__build_stats()
-        self.best_model = (0, self.model.state_dict())
-        self.best_model_criteria = self.accuracy
-        self.save_best_model()

        if self.config.CONFIG.cuda:
            self.model.cuda()

    def __build_stats(self):

-        # necessary metrics
-        self.train_loss = EpochAverager(filename = '{}/results/{}.{}'.format(self.ROOT_DIR, 'metrics', 'train_loss'))
-        self.test_loss = EpochAverager(filename = '{}/results/{}.{}'.format(self.ROOT_DIR, 'metrics', 'test_loss'))
-        self.accuracy = EpochAverager(filename = '{}/results/{}.{}'.format(self.ROOT_DIR, 'metrics', 'accuracy'))
-
-        # optional metrics
-        self.tp = EpochAverager(filename = '{}/results/{}.{}'.format(self.ROOT_DIR, 'metrics', 'tp'))
-        self.fp = EpochAverager(filename = '{}/results/{}.{}'.format(self.ROOT_DIR, 'metrics', 'fp'))
-        self.fn = EpochAverager(filename = '{}/results/{}.{}'.format(self.ROOT_DIR, 'metrics', 'fn'))
-        self.tn = EpochAverager(filename = '{}/results/{}.{}'.format(self.ROOT_DIR, 'metrics', 'tn'))
-
-        self.precision = EpochAverager(filename = '{}/results/{}.{}'.format(self.ROOT_DIR, 'metrics', 'precision'))
-        self.recall = EpochAverager(filename = '{}/results/{}.{}'.format(self.ROOT_DIR, 'metrics', 'recall'))
-        self.f1score = EpochAverager(filename = '{}/results/{}.{}'.format(self.ROOT_DIR, 'metrics', 'f1score'))
-
-        self.metrics = [self.train_loss, self.test_loss, self.accuracy, self.precision, self.recall, self.f1score]
-
-    def __build_feeder(self, feeder, *args, **kwargs):
-        assert feeder is not None, 'feeder is None, fatal error'
-        self.feeder = feeder
-
-    def save_best_model(self):
-        log.info('saving the last best model with accuracy {}...'.format(self.best_model[0]))
-        torch.save(self.best_model[1], '{}/weights/{:0.4f}.{}'.format(self.ROOT_DIR, self.best_model[0], 'pth'))
-        torch.save(self.best_model[1], '{}/weights/{}.{}'.format(self.ROOT_DIR, self.name, 'pth'))
+        self.train_loss = EpochAverager(filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'train_loss'))
+        self.metrics = [self.train_loss]

    def train(self):
        for epoch in range(self.epochs):
-            log.critical('memory consumed : {}'.format(memory_consumed()))
+            self.log.critical('memory consumed : {}'.format(memory_consumed()))

-            if self.do_every_checkpoint(epoch) == FLAGS.STOP_TRAINING:
-                log.info('loss trend suggests to stop training')
-                return
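+            # with checkpoint=c, the gate below fires every max(1, c - 1)
+            # epochs, so checkpoint=1 and checkpoint=2 both check every epoch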
+            if epoch % max(1, (self.checkpoint - 1)) == 0:
+                if self.do_every_checkpoint(epoch) == FLAGS.STOP_TRAINING:
+                    self.log.info('loss trend suggests to stop training')
+                    return

            self.model.train()
-            for j in tqdm(range(self.feeder.train.num_batch)):
-                log.debug('{}th batch'.format(j))
+            for j in tqdm(range(self.feed.num_batch)):
+                self.log.debug('{}th batch'.format(j))
                self.optimizer.zero_grad()
-                input_ = self.feeder.train.next_batch()
+                input_ = self.feed.next_batch()
                output = self.model(input_)
                loss = self.loss_function(output, input_)
                self.train_loss.cache(loss.data.item())
                loss.backward()
                self.optimizer.step()


-            log.info('-- {} -- loss: {}'.format(epoch, self.train_loss.epoch_cache))
-            self.train_loss.clear_cache()
+            self.log.info('-- {} -- loss: {}'.format(epoch, self.train_loss.epoch_cache))
+            self.train_loss.clear_cache()

        for m in self.metrics:
            m.write_to_file()

        return True

+class Tester(object):
+    def __init__(self, name,
+                 config,
+                 model,
+                 feed,
+                 loss_function,
+                 accuracy_function,
+                 directory,
+                 f1score_function=None,
+                 best_model=None,
+                 predictor=None,
+                 save_model_weights=True,
+                 ):
+
+        self.name = name
+        self.config = config
+        self.ROOT_DIR = directory
+
+        self.log = logging.getLogger('{}.{}'.format(__name__, self.name))
+
+        self.model = model
+
+        self.feed = feed
+
+        self.predictor = predictor
+
+        self.accuracy_function = accuracy_function if accuracy_function else self._default_accuracy_function
+        self.loss_function = loss_function if loss_function else nn.NLLLoss()
+        self.f1score_function = f1score_function
+
+        self.__build_stats()
+
+        self.save_model_weights = save_model_weights
+        self.best_model = (0.000001, self.model.cpu().state_dict())
+        try:
+            f = '{}/{}_best_model_accuracy.txt'.format(self.ROOT_DIR, self.name)
+            if os.path.isfile(f):
+                self.best_model = (float(open(f).read().strip()), self.model.cpu().state_dict())
+                self.log.info('loaded last best accuracy: {}'.format(self.best_model[0]))
+        except Exception:
+            self.log.exception('no last best model')
+
+        self.best_model_criteria = self.accuracy
+        self.save_best_model()
+
+        if self.config.CONFIG.cuda:
+            self.model.cuda()

+    def __build_stats(self):
+
+        # necessary metrics
+        self.test_loss = EpochAverager(filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'test_loss'))
+        self.accuracy = EpochAverager(filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'accuracy'))
+
+        # optional metrics
+        self.tp = EpochAverager(filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'tp'))
+        self.fp = EpochAverager(filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'fp'))
+        self.fn = EpochAverager(filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'fn'))
+        self.tn = EpochAverager(filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'tn'))
+
+        self.precision = EpochAverager(filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'precision'))
+        self.recall = EpochAverager(filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'recall'))
+        self.f1score = EpochAverager(filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'f1score'))
+
+        self.metrics = [self.test_loss, self.accuracy, self.precision, self.recall, self.f1score]
+
+    def save_best_model(self):
+        with open('{}/{}_best_model_accuracy.txt'.format(self.ROOT_DIR, self.name), 'w') as f:
+            f.write(str(self.best_model[0]))
+
+        if self.save_model_weights:
+            self.log.info('saving the last best model with accuracy {}...'.format(self.best_model[0]))
+            torch.save(self.best_model[1], '{}/weights/{:0.4f}.{}'.format(self.ROOT_DIR, self.best_model[0], 'pth'))
+            torch.save(self.best_model[1], '{}/weights/{}.{}'.format(self.ROOT_DIR, self.name, 'pth'))

    def do_every_checkpoint(self, epoch, early_stopping=True):
        # Tester itself does not set self.checkpoint, so default to running on every call
        if epoch % getattr(self, 'checkpoint', 1) != 0:
            return

        self.model.eval()
-        for j in tqdm(range(self.feeder.test.num_batch)):
-            input_ = self.feeder.train.next_batch()
+        for j in tqdm(range(self.feed.num_batch)):
+            input_ = self.feed.next_batch()
            output = self.model(input_)

            loss = self.loss_function(output, input_)
@@ -151,21 +208,34 @@ def do_every_checkpoint(self, epoch, early_stopping=True):
            self.recall.cache(recall)
            self.f1score.cache(f1score)

-        log.info('-- {} -- loss: {}, accuracy: {}'.format(epoch, self.test_loss.epoch_cache, self.accuracy.epoch_cache))
+        self.log.info('-- {} -- loss: {}'.format(epoch, self.test_loss.epoch_cache))
+        self.log.info('-- {} -- accuracy: {}'.format(epoch, self.accuracy.epoch_cache))
        if self.f1score_function:
-            log.info('-- {} -- tp: {}'.format(epoch, sum(self.tp.epoch_cache)))
-            log.info('-- {} -- fn: {}'.format(epoch, sum(self.fn.epoch_cache)))
-            log.info('-- {} -- fp: {}'.format(epoch, sum(self.fp.epoch_cache)))
-            log.info('-- {} -- tn: {}'.format(epoch, sum(self.tn.epoch_cache)))
+            self.log.info('-- {} -- tp: {}'.format(epoch, sum(self.tp.epoch_cache)))
+            self.log.info('-- {} -- fn: {}'.format(epoch, sum(self.fn.epoch_cache)))
+            self.log.info('-- {} -- fp: {}'.format(epoch, sum(self.fp.epoch_cache)))
+            self.log.info('-- {} -- tn: {}'.format(epoch, sum(self.tn.epoch_cache)))

-            log.info('-- {} -- precision: {}'.format(epoch, self.precision.epoch_cache))
-            log.info('-- {} -- recall: {}'.format(epoch, self.recall.epoch_cache))
-            log.info('-- {} -- f1score: {}'.format(epoch, self.f1score.epoch_cache))
+            self.log.info('-- {} -- precision: {}'.format(epoch, self.precision.epoch_cache))
+            self.log.info('-- {} -- recall: {}'.format(epoch, self.recall.epoch_cache))
+            self.log.info('-- {} -- f1score: {}'.format(epoch, self.f1score.epoch_cache))

        if self.best_model[0] < self.accuracy.epoch_cache.avg:
-            log.info('beat best model...')
-            self.best_model = (self.accuracy.epoch_cache.avg, self.model.state_dict())
+            self.log.info('beat best model...')
+            last_acc = self.best_model[0]
+            self.best_model = (self.accuracy.epoch_cache.avg, self.model.cpu().state_dict())
+            self.save_best_model()
+
+            if self.config.CONFIG.cuda:
+                self.model.cuda()
+
-            if self.predictor and self.best_model[0] > 0.75:
-                log.info('accuracy is greater than 0.75...')
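+            # the gate below replaces the flat 0.75 cutoff: predict when accuracy
+            # clears ACCURACY_THRESHOLD having gained at least a fifth of
+            # ACCURACY_IMPROVEMENT_THRESHOLD, or whenever the raw gain alone
+            # exceeds ACCURACY_IMPROVEMENT_THRESHOLD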
+            if (self.predictor and
+                    ((self.best_model[0] > self.config.CONFIG.ACCURACY_THRESHOLD and (5 * (self.best_model[0] - last_acc) > self.config.CONFIG.ACCURACY_IMPROVEMENT_THRESHOLD))
+                     or (self.best_model[0] - last_acc) > self.config.CONFIG.ACCURACY_IMPROVEMENT_THRESHOLD)):
+
+                self.predictor.run_prediction(self.accuracy.epoch_cache.avg)

        self.test_loss.clear_cache()
        self.accuracy.clear_cache()
Expand Down Expand Up @@ -200,26 +270,38 @@ def loss_trend(self, total_count=10):
    def _default_accuracy_function(self):
        return -1


class Predictor(object):
-    def __init__(self, model=None,
-                 feed = None,
-                 repr_function = None,
+    def __init__(self, name, model,
+                 feed,
+                 repr_function,
+                 directory,
                 *args, **kwargs):

+        self.name = name
        self.model = model
        self.__build_feed(feed, *args, **kwargs)
+        self.ROOT_DIR = directory
+
+        self.log = logging.getLogger('{}.{}'.format(__name__, self.name))

        self.repr_function = repr_function

    def __build_feed(self, feed, *args, **kwargs):
        assert feed is not None, 'feed is None, fatal error'
        self.feed = feed

    def predict(self, batch_index=0):
-        log.debug('batch_index: {}'.format(batch_index))
+        self.log.debug('batch_index: {}'.format(batch_index))
        input_ = self.feed.nth_batch(batch_index)
        self.model.eval()
        output = self.model(input_)
        results = ListTable()
        results.extend( self.repr_function(output, input_) )
        output_ = output
        return output_, results

+    def run_prediction(self, accuracy):
+        dump = open('{}/results/{}_{:0.4f}.csv'.format(self.ROOT_DIR, self.name, accuracy), 'w')
+        self.log.info('running prediction at accuracy {:0.4f}'.format(accuracy))
+        results = ListTable()
+        for ri in tqdm(range(self.feed.num_batch)):
+            output, _results = self.predict(ri)
+            results.extend(_results)
+        dump.write(repr(results))
+        dump.close()
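
A similarly hedged sketch of the Predictor hookup, continuing the stand-ins from the example above: run_prediction() dumps one CSV of repr_function rows per qualifying checkpoint, and the repr_function contract, per predict() above, is to map (output, input_) to an iterable of printable rows.

# hedged sketch -- repr_function and its row format are invented for illustration
from trainer.trainer import Predictor

def repr_function(output, input_):
    # hypothetical rows: each input vector alongside its reconstruction
    return [(inp.tolist(), out.tolist()) for inp, out in zip(input_, output)]

test_feed = ToyFeed(torch.randn(80, 10))
predictor = Predictor('demo', model, test_feed, repr_function, directory=ROOT)

tester = Tester('demo', config, model, test_feed,
                loss_function=loss_fn, accuracy_function=acc_fn, directory=ROOT,
                predictor=predictor)   # run_prediction fires when the accuracy gates pass
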
