merge macnet-babi and tamil-lm setup
vanangamudi committed Jan 9, 2019
2 parents af7a6f7 + a3737f3 commit eeb561f
Showing 2 changed files with 45 additions and 31 deletions.
68 changes: 41 additions & 27 deletions trainer/lm.py
@@ -20,6 +20,7 @@
 
 from .trainer import EpochAverager, FLAGS
 from .trainer import Trainer, Tester
+from ..utilz import Var, LongVar
 
 
 import torch
@@ -88,17 +89,16 @@ def train(self):
 
         self.model.train()
         for j in tqdm(range(self.feed.num_batch), desc='Trainer.{}'.format(self.name)):
-            log.debug('{}th batch'.format(j))
-
             input_ = self.feed.next_batch()
             idxs, inputs, targets = input_
             sequence = inputs[0].transpose(0,1)
             _, batch_size = sequence.size()
 
             state = self.model.initial_hidden(batch_size)
             loss = 0
-            for ti in range(sequence.size(0) - 1):
-                output = self.model(sequence[ti], state)
+            output = sequence[0]
+            for ti in range(1, sequence.size(0) - 1):
+                output = self.model(output, state)
                 loss += self.loss_function(ti, output, input_)
                 output, state = output
                 output = output.max(1)[1]
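The rewritten `train()` loop drops teacher forcing: instead of feeding the ground-truth token `sequence[ti]` at every step, it seeds on the first token and feeds the model's own argmax prediction back in. A minimal self-contained sketch of that pattern, where `TinyLM`, the sizes, and the target indexing are illustrative assumptions rather than this repo's actual classes:

```python
# Sketch of the free-running loop: the model consumes its own argmax
# prediction instead of the ground-truth token sequence[ti].
# TinyLM, the sizes, and the target indexing are illustrative assumptions.
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.embed   = nn.Embedding(vocab_size, hidden_size)
        self.cell    = nn.GRUCell(hidden_size, hidden_size)
        self.project = nn.Linear(hidden_size, vocab_size)

    def forward(self, tokens, state):
        state = self.cell(self.embed(tokens), state)
        return self.project(state), state           # (logits, next state)

seq_len, batch_size = 12, 4
model, loss_fn = TinyLM(), nn.CrossEntropyLoss()
sequence = torch.randint(0, 100, (seq_len, batch_size))   # (time, batch)
state = torch.zeros(batch_size, 32)

loss = 0
output = sequence[0]                        # seed with the first token
for ti in range(1, seq_len - 1):
    logits, state = model(output, state)
    loss = loss + loss_fn(logits, sequence[ti])   # predict the current token
    output = logits.max(1)[1]               # feed back the model's own guess
loss.backward()
```

Because the argmax feedback is non-differentiable, gradients reach earlier steps only through the recurrent state.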
@@ -175,22 +175,28 @@ def __init__(self, name,
     def do_every_checkpoint(self, epoch, early_stopping=True):
 
         self.model.eval()
-        for j in tqdm(range(self.feed.num_batch)):
+        for j in tqdm(range(self.feed.num_batch), desc='Tester.{}'.format(self.name)):
             input_ = self.feed.next_batch()
             idxs, inputs, targets = input_
             sequence = inputs[0].transpose(0,1)
             _, batch_size = sequence.size()
 
             state = self.model.initial_hidden(batch_size)
-            loss, accuracy = Var(self.config, [0]), Var(self.config, [0])
-            for ti in range(sequence.size(0) - 1):
-                output, state = self.model(sequence[ti], state)
+            output = sequence[0]
+            outputs = []
+            loss, accuracy = 0, 0
+            for ti in range(1, sequence.size(0) - 1):
+                output = self.model(output, state)
                 loss += self.loss_function(ti, output, input_)
-                accuracy += self.accuracy_function(ti, decoder_output, input_)
+                accuracy += self.accuracy_function(ti, output, input_)
+                output, state = output
+                output = output.max(1)[1]
+                outputs.append(output)
 
             self.test_loss.cache(loss.item())
+            if ti == 0: ti = 1
             self.accuracy.cache(accuracy.item()/ti)
-            print('====', self.test_loss, self.test_accuracy)
+            #print('====', self.test_loss, self.accuracy)
 
         self.log.info('= {} =loss:{}'.format(epoch, self.test_loss.epoch_cache))
         self.log.info('- {} -accuracy:{}'.format(epoch, self.accuracy.epoch_cache))
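The tester runs the same free-running loop, now accumulating loss and accuracy as plain counters; accuracy is averaged by the final step index, which for `range(1, n - 1)` equals the number of steps taken, and the new `if ti == 0: ti = 1` guard keeps that division safe when the sequence is too short for the loop to run. A rough sketch of the accounting, with `accuracy_step` as a hypothetical stand-in for `self.accuracy_function`:

```python
# Sketch of the checkpoint-time accounting: per-step accuracy summed over
# the loop, then averaged by the last step index (== number of steps).
# accuracy_step is a hypothetical stand-in for self.accuracy_function.
import torch

def accuracy_step(logits, targets):
    # fraction of batch entries whose argmax matches the target token
    return (logits.max(1)[1] == targets).float().mean()

seq_len, batch_size, vocab = 8, 4, 100
sequence = torch.randint(0, vocab, (seq_len, batch_size))

loss, accuracy = 0, 0
ti = 0                       # defined even if the loop below never runs
for ti in range(1, seq_len - 1):
    logits = torch.randn(batch_size, vocab)   # stand-in for the model output
    accuracy += accuracy_step(logits, sequence[ti])

if ti == 0:                  # degenerate batch: avoid dividing by zero
    ti = 1
print('mean accuracy:', float(accuracy) / ti)
```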
@@ -199,8 +205,8 @@ def do_every_checkpoint(self, epoch, early_stopping=True):
                 self.log.info('beat best model...')
                 last_acc = self.best_model[0]
                 self.best_model = (self.accuracy.epoch_cache.avg,
-                                   (self.encoder_model.state_dict(),
-                                    self.decoder_model.state_dict())
+                                   (self.model.state_dict())
+
                 )
                 self.save_best_model()
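With the merged single-model setup, the checkpoint tuple shrinks from `(accuracy, (encoder_state, decoder_state))` to `(accuracy, model_state)`. A sketch of how such a tuple round-trips through `torch.save`; the `nn.Linear` stand-in and the file name are assumptions, not this repo's model or checkpoint path:

```python
# Sketch of saving/restoring the (accuracy, state_dict) pair now that one
# model replaces the encoder/decoder pair. The path is hypothetical.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                   # stand-in for the language model
best_model = (0.82, model.state_dict())   # (best accuracy so far, weights)
torch.save(best_model, 'best_model.pth')

best_acc, weights = torch.load('best_model.pth')
model.load_state_dict(weights)            # restore the best weights
```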

@@ -209,8 +215,13 @@
 
         if self.predictor and self.best_model[0] > 0.75:
             log.info('accuracy is greater than 0.75...')
-            if ((self.best_model[0] >= self.config.CONFIG.ACCURACY_THRESHOLD and (5 * (self.best_model[0] - last_acc) > self.config.CONFIG.ACCURACY_IMPROVEMENT_THRESHOLD))
-                or (self.best_model[0] - last_acc) > self.config.CONFIG.ACCURACY_IMPROVEMENT_THRESHOLD):
+            if ((
+                    self.best_model[0] >= self.config.CONFIG.ACCURACY_THRESHOLD and
+                    ( 5*(self.best_model[0] - last_acc) >
+                      self.config.CONFIG.ACCURACY_IMPROVEMENT_THRESHOLD))
+                or (self.best_model[0] - last_acc)
+                > self.config.CONFIG.ACCURACY_IMPROVEMENT_THRESHOLD
+            ):
 
                 self.predictor.run_prediction(self.accuracy.epoch_cache.avg)
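Pulled into a standalone predicate, the reformatted condition reads: run predictions when the best accuracy clears the absolute threshold and five times the improvement clears the improvement margin, or when the raw improvement alone clears that margin. A sketch with made-up threshold values:

```python
# The prediction-trigger predicate, extracted for readability.
# Both threshold values here are made up for illustration.
ACCURACY_THRESHOLD = 0.90
ACCURACY_IMPROVEMENT_THRESHOLD = 0.05

def should_run_prediction(best_acc, last_acc):
    improvement = best_acc - last_acc
    return ((best_acc >= ACCURACY_THRESHOLD and
             5 * improvement > ACCURACY_IMPROVEMENT_THRESHOLD)
            or improvement > ACCURACY_IMPROVEMENT_THRESHOLD)

print(should_run_prediction(0.92, 0.90))    # True: past threshold, 5*0.02 > 0.05
print(should_run_prediction(0.80, 0.70))    # True: raw improvement 0.10 > 0.05
print(should_run_prediction(0.85, 0.845))   # False: neither branch holds
```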

@@ -245,24 +256,27 @@ def __init__(self, name, model,
     def predict(self, batch_index=0, max_decoder_len=10):
         log.debug('batch_index: {}'.format(batch_index))
         idxs, i, *__ = self.feed.nth_batch(batch_index)
+        outputs = []
         self.model.eval()
-        decoder_outputs = []
         input_ = self.feed.next_batch()
         idxs, inputs, targets = input_
-        encoder_output = self.encoder_model(input_)
+        sequence = inputs[0].transpose(0,1)
+        _, batch_size = sequence.size()
+
+        state = self.model.initial_hidden(batch_size)
+        loss = 0
+
+        output = sequence[0]
+        for ti in range(1, sequence.size(0) - 1):
+            output = self.model(output, state)
+            output, state = output
+            output = output.max(1)[1]
+            outputs.append(output)
 
         results = ListTable()
-        decoder_input = self.decoder_model.initial_input(input_, encoder_output)
-        for ti in range(max_decoder_len):
-            decoder_output = self.decoder_model(input_, encoder_output, decoder_input)
-            decoder_output, decoder_input = self.process_output(ti, decoder_output, input_)
-            decoder_outputs.append(decoder_output)
-
-        decoder_outputs = torch.stack(decoder_outputs)
-        result = self.repr_function(decoder_outputs, input_)
+        outputs = torch.stack(outputs)
+        result = self.repr_function(outputs, input_)
         results.extend(result)
-        return decoder_outputs, results
+        return outputs, results
 
     def run_prediction(self, accuracy):
         dump = open('{}/results/{}_{:0.4f}.csv'.format(self.ROOT_DIR, self.name, accuracy), 'w')
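`predict()` now drops the encoder/decoder plumbing entirely (`max_decoder_len` stays in the signature but is no longer used) and reuses the greedy loop, stacking the per-step argmax token ids. A sketch of the resulting output shape, with `torch.randn` standing in for the model call:

```python
# Sketch of what predict() now produces: per-step argmax token ids
# stacked into a (steps, batch) tensor of predictions.
import torch

seq_len, batch_size, vocab = 10, 4, 100
sequence = torch.randint(0, vocab, (seq_len, batch_size))

outputs = []
output = sequence[0]                          # seed token, as in training
for ti in range(1, seq_len - 1):
    logits = torch.randn(batch_size, vocab)   # stand-in for model(output, state)
    output = logits.max(1)[1]                 # greedy token choice
    outputs.append(output)

outputs = torch.stack(outputs)                # (seq_len - 2, batch_size)
print(outputs.size())                         # torch.Size([8, 4])
```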
8 changes: 4 additions & 4 deletions utilz.py
@@ -148,10 +148,10 @@ def are_weights_same(model1, model2):
             return False
     return True
 
-def LongVar(array, requires_grad=False):
-    return Var(array, requires_grad).long()
+def LongVar(config, array, requires_grad=False):
+    return Var(config, array, requires_grad).long()
 
-def Var(array, requires_grad=False):
+def Var(config, array, requires_grad=False):
     ret = Variable(torch.Tensor(array), requires_grad=requires_grad)
     if config.CONFIG.cuda:
         ret = ret.cuda()
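`Var` and `LongVar` now take `config` explicitly rather than picking it up from an enclosing scope, which is what forces the new `from ..utilz import Var, LongVar` import and the extra argument at every call site. A usage sketch; the `SimpleNamespace` config and the `return ret` tail (elided by the diff) are assumptions:

```python
# Usage sketch of the new explicit-config signatures. The namespace mimics
# the config.CONFIG.cuda flag, and the return line is assumed: the diff
# elides the tail of Var().
from types import SimpleNamespace
import torch
from torch.autograd import Variable

def Var(config, array, requires_grad=False):
    ret = Variable(torch.Tensor(array), requires_grad=requires_grad)
    if config.CONFIG.cuda:
        ret = ret.cuda()
    return ret

def LongVar(config, array, requires_grad=False):
    return Var(config, array, requires_grad).long()

config = SimpleNamespace(CONFIG=SimpleNamespace(cuda=torch.cuda.is_available()))
print(Var(config, [0.0]))        # FloatTensor, moved to GPU when cuda is on
print(LongVar(config, [1, 2]))   # LongTensor
```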
@@ -160,7 +160,7 @@ def Var(array, requires_grad=False):
 
 def init_hidden(batch_size, cell):
     layers = 1
-    if isinstance(cell, (nn.LSTM, nn.GRU, nn.LSTMCell, nn.GRUCell)):
+    if isinstance(cell, (nn.LSTMCell, nn.GRUCell)):
         layers = cell.num_layers
         if cell.bidirectional:
             layers = layers * 2
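For reference, `num_layers` and `bidirectional` are attributes of the full `nn.LSTM`/`nn.GRU` modules, not of the single-step `nn.LSTMCell`/`nn.GRUCell` classes, so a defensive variant of this helper can read them with `getattr` defaults. A sketch under that assumption; the `(layers, batch, hidden)` zero state is also an assumption about what the elided tail returns:

```python
# Defensive sketch of the hidden-state sizing this helper performs.
# LSTMCell/GRUCell carry no num_layers or bidirectional attributes, so
# getattr defaults are used; the returned zero-state shape is assumed.
import torch
import torch.nn as nn

def init_hidden(batch_size, cell):
    layers = getattr(cell, 'num_layers', 1)
    if getattr(cell, 'bidirectional', False):
        layers = layers * 2
    return torch.zeros(layers, batch_size, cell.hidden_size)

print(init_hidden(4, nn.GRUCell(8, 16)).size())          # [1, 4, 16]
print(init_hidden(4, nn.GRU(8, 16, num_layers=2,
                            bidirectional=True)).size()) # [4, 4, 16]
```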