MultiTrainer class is removed and some more cleanup
vanangamudi committed Aug 8, 2018
1 parent 2afcf73 commit 410bb16
Showing 2 changed files with 21 additions and 63 deletions.
2 changes: 1 addition & 1 deletion trainer/__init__.py
@@ -1 +1 @@
-from .trainer import Trainer, MultiTrainer, Tester, Predictor
+from .trainer import Trainer, Tester, Predictor
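With MultiTrainer dropped from the package exports, downstream imports need the same trim. A minimal sketch of a hypothetical caller after this commit:

    # Hypothetical caller; after commit 410bb16 only these three names are exported.
    from trainer import Trainer, Tester, Predictor

    # The removed class can no longer be imported:
    # from trainer import MultiTrainer   # ImportError after this commit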
82 changes: 20 additions & 62 deletions trainer/trainer.py
@@ -38,50 +38,7 @@ def cache(self, a):
     def clear_cache(self):
         super(EpochAverager, self).append(self.epoch_cache.avg)
         self.epoch_cache.empty();
-
-class MultiTrainer(object):
-    def __init__(self, name, config, trainers, testers, next_trainer_func=None, primary_trainer=None):
-        self.name = name
-        self.config = config
-        self.trainers = trainers
-        self.testers = testers
-
-        self.log = logging.getLogger('{}.{}.{}'.format(__name__, self.__class__.__name__, self.name))
-
-        self.primary_trainer = None
-        if primary_trainer:
-            self.primary_trainer = self.trainers[primary_trainer]
-
-        if next_trainer_func:
-            self.get_next_trainer = next_trainer_func
-
-    def train(self):
-        if self.primary_trainer:
-            self.log.info('running primary trainer: {}'.format(self.primary_trainer.name))
-            self.primary_trainer.train()
-
-            for i in range(self.primary_trainer.checkpoint * 5):
-                if not self.get_next_trainer().train():
-                    raise Exception
-        else:
-            for name, trainer in self.trainers.items():
-                if not trainer.train():
-                    raise Exception
-
-        return True
-
-    def get_next_trainer(self):
-        accuracies = sorted(
-            [(v.accuracy[-1], k) for k,v in self.testers.items()],
-            key=lambda x: x[0]
-        )
-        pprint(accuracies)
-
-        self.log.info('{:0.4f} ===> {}'.format(*accuracies[0]))
-
-        return self.trainers[accuracies[0][1]]
-

 class Trainer(object):
     def __init__(self, name,
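For reference, the deleted get_next_trainer sorted the testers by their most recent accuracy and returned the trainer with the lowest score, so each extra round went to the weakest model. A minimal sketch of that selection, with plain lists standing in for the Tester objects (names and numbers are hypothetical):

    # Hypothetical accuracy histories standing in for tester.accuracy.
    testers = {'cnn': [0.91, 0.93], 'rnn': [0.85, 0.88]}

    # Same logic as the removed method: sort ascending by last accuracy, take the worst.
    accuracies = sorted((acc[-1], name) for name, acc in testers.items())
    print(accuracies[0])  # (0.88, 'rnn') -- so the 'rnn' trainer would train next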
@@ -206,18 +163,19 @@ def __init__(self, name,
     def __build_stats(self):

         # necessary metrics
-        self.test_loss = EpochAverager(self.config, filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'test_loss'))
-        self.accuracy = EpochAverager(self.config, filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'accuracy'))
+        self.mfile_prefix = '{}/results/metrics/{}'.format(self.ROOT_DIR, self.name)
+        self.test_loss = EpochAverager(self.config, filename = '{}.{}'.format(self.mfile_prefix, 'test_loss'))
+        self.accuracy = EpochAverager(self.config, filename = '{}.{}'.format(self.mfile_prefix, 'accuracy'))

         # optional metrics
-        self.tp = EpochAverager(self.config, filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'tp'))
-        self.fp = EpochAverager(self.config, filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'fp'))
-        self.fn = EpochAverager(self.config, filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'fn'))
-        self.tn = EpochAverager(self.config, filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'tn'))
+        self.tp = EpochAverager(self.config, filename = '{}.{}'.format(self.mfile_prefix, 'tp'))
+        self.fp = EpochAverager(self.config, filename = '{}.{}'.format(self.mfile_prefix, 'fp'))
+        self.fn = EpochAverager(self.config, filename = '{}.{}'.format(self.mfile_prefix, 'fn'))
+        self.tn = EpochAverager(self.config, filename = '{}.{}'.format(self.mfile_prefix, 'tn'))

-        self.precision = EpochAverager(self.config, filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'precision'))
-        self.recall = EpochAverager(self.config, filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'recall'))
-        self.f1score = EpochAverager(self.config, filename = '{}/results/metrics/{}.{}'.format(self.ROOT_DIR, self.name, 'f1score'))
+        self.precision = EpochAverager(self.config, filename = '{}.{}'.format(self.mfile_prefix, 'precision'))
+        self.recall = EpochAverager(self.config, filename = '{}.{}'.format(self.mfile_prefix, 'recall'))
+        self.f1score = EpochAverager(self.config, filename = '{}.{}'.format(self.mfile_prefix, 'f1score'))

         self.metrics = [self.test_loss, self.accuracy, self.precision, self.recall, self.f1score]
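The hunk above computes the '{ROOT_DIR}/results/metrics/{name}' prefix once instead of repeating it in all nine EpochAverager filenames. A quick check that the two-step construction yields the same paths, using hypothetical values for ROOT_DIR and name:

    # Hypothetical values; in Trainer these are self.ROOT_DIR and self.name.
    ROOT_DIR, name = '/tmp/exp', 'mnist'
    mfile_prefix = '{}/results/metrics/{}'.format(ROOT_DIR, name)

    old = '{}/results/metrics/{}.{}'.format(ROOT_DIR, name, 'accuracy')
    new = '{}.{}'.format(mfile_prefix, 'accuracy')
    assert old == new == '/tmp/exp/results/metrics/mnist.accuracy'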

@@ -253,17 +211,17 @@ def do_every_checkpoint(self, epoch, early_stopping=True):
             self.recall.cache(recall)
             self.f1score.cache(f1score)

-        self.log.info('={}=loss:{}'.format(epoch, self.test_loss.epoch_cache))
-        self.log.info('-{}-accuracy:{}'.format(epoch, self.accuracy.epoch_cache))
+        self.log.info('= {} =loss:{}'.format(epoch, self.test_loss.epoch_cache))
+        self.log.info('- {} -accuracy:{}'.format(epoch, self.accuracy.epoch_cache))
         if self.f1score_function:
-            self.log.info('-{}-tp:{}'.format(epoch, sum(self.tp.epoch_cache)))
-            self.log.info('-{}-fn:{}'.format(epoch, sum(self.fn.epoch_cache)))
-            self.log.info('-{}-fp:{}'.format(epoch, sum(self.fp.epoch_cache)))
-            self.log.info('-{}-tn:{}'.format(epoch, sum(self.tn.epoch_cache)))
+            self.log.info('- {} -tp:{}'.format(epoch, sum(self.tp.epoch_cache)))
+            self.log.info('- {} -fn:{}'.format(epoch, sum(self.fn.epoch_cache)))
+            self.log.info('- {} -fp:{}'.format(epoch, sum(self.fp.epoch_cache)))
+            self.log.info('- {} -tn:{}'.format(epoch, sum(self.tn.epoch_cache)))

-            self.log.info('-{}-precision:{}'.format(epoch, self.precision.epoch_cache))
-            self.log.info('-{}-recall:{}'.format(epoch, self.recall.epoch_cache))
-            self.log.info('-{}-f1score:{}\n'.format(epoch, self.f1score.epoch_cache))
+            self.log.info('- {} -precision:{}'.format(epoch, self.precision.epoch_cache))
+            self.log.info('- {} -recall:{}'.format(epoch, self.recall.epoch_cache))
+            self.log.info('- {} -f1score:{}\n'.format(epoch, self.f1score.epoch_cache))

         if self.best_model[0] < self.accuracy.epoch_cache.avg:
             self.log.info('beat best model...')
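This hunk only adds spaces around the epoch marker in the log format strings; the tp/fn/fp/tn counts themselves come from f1score_function upstream. For reference, the standard confusion-matrix definitions behind the three derived metrics, with hypothetical counts:

    # Standard definitions; counts are made up for illustration.
    tp, fp, fn = 90, 10, 30
    precision = tp / (tp + fp)                                  # 0.9
    recall = tp / (tp + fn)                                     # 0.75
    f1score = 2 * precision * recall / (precision + recall)    # ~0.818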
