Moving code from train.py and introducing ModelSaver and ReportManager classes
pltrdy committed May 17, 2018
1 parent 6128d9a commit 2c36bc6
Showing 12 changed files with 545 additions and 371 deletions.
10 changes: 2 additions & 8 deletions onmt/__init__.py
@@ -1,9 +1,3 @@
""" Main entry point of the ONMT library """
from onmt import decoders
from onmt import encoders
from onmt import inputters
from onmt import models
from onmt import modules
#import onmt.opts
from onmt.trainer import Trainer, Statistics
#from onmt.encoders.transformer import TransformerEncoder
import onmt.model_builder
from onmt.trainer import Trainer
1 change: 1 addition & 0 deletions onmt/models/__init__.py
@@ -1,4 +1,5 @@
"""Module defining models."""
from onmt.models.model_saver import build_model_saver, ModelSaver
from onmt.models.model import NMTModel
from onmt.models.SRU import check_sru_requirement
CAN_USE_SRU = check_sru_requirement()
70 changes: 70 additions & 0 deletions onmt/models/model_saver.py
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn

import onmt.inputters


def build_model_saver(model_opt, opt, model, fields, optim):
model_saver = ModelSaver(opt.save_model,
model,
model_opt,
fields,
optim,
opt.start_checkpoint_at)
return model_saver


class ModelSaverBase(object):
def __init__(self, base_path, model, model_opt, fields, optim, start_checkpoint_at=0):
self.base_path = base_path
self.model = model
self.model_opt = model_opt
self.fields = fields
self.optim = optim
self.start_checkpoint_at = start_checkpoint_at

def maybe_save(self, epoch, valid_stats):
if epoch >= self.start_checkpoint_at:
self._save(epoch, valid_stats)

def force_save(self, epoch, valid_stats):
self._save(epoch, valid_stats)

def _save(self, epoch, valid_stats):
""" Save a resumable checkpoint.
Args:
epoch (int): epoch number
valid_stats : statistics of last validation run
"""
raise NotImplementedError()


class ModelSaver(ModelSaverBase):
def __init__(self, base_path, model, model_opt, fields, optim, start_checkpoint_at=0):
super(ModelSaver, self).__init__(
base_path, model, model_opt, fields, optim, start_checkpoint_at=start_checkpoint_at)

def _save(self, epoch, valid_stats):
real_model = (self.model.module
if isinstance(self.model, nn.DataParallel)
else self.model)
real_generator = (real_model.generator.module
if isinstance(real_model.generator, nn.DataParallel)
else real_model.generator)

model_state_dict = real_model.state_dict()
model_state_dict = {k: v for k, v in model_state_dict.items()
if 'generator' not in k}
generator_state_dict = real_generator.state_dict()
checkpoint = {
'model': model_state_dict,
'generator': generator_state_dict,
'vocab': onmt.inputters.save_fields_to_vocab(self.fields),
'opt': self.model_opt,
'epoch': epoch,
'optim': self.optim,
}
torch.save(checkpoint,
'%s_acc_%.2f_ppl_%.2f_e%d.pt'
% (self.base_path, valid_stats.accuracy(),
valid_stats.ppl(), epoch))
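
A minimal sketch of how the new saver is meant to be driven from a training script, assuming only the names introduced above (build_model_saver, maybe_save, force_save). The validate callback and the model/fields/optim/opt arguments are placeholders standing in for the usual OpenNMT-py setup; they are not part of this commit, and validate is assumed to return an onmt.utils.Statistics object (maybe_save uses its accuracy() and ppl() when naming the checkpoint file).

from onmt.models import build_model_saver

def checkpoint_every_epoch(model_opt, opt, model, fields, optim,
                           start_epoch, end_epoch, validate):
    # build_model_saver reads opt.save_model and opt.start_checkpoint_at.
    saver = build_model_saver(model_opt, opt, model, fields, optim)
    for epoch in range(start_epoch, end_epoch + 1):
        valid_stats = validate(epoch)        # hypothetical validation hook
        # maybe_save only writes once epoch >= opt.start_checkpoint_at.
        saver.maybe_save(epoch, valid_stats)
    # force_save ignores start_checkpoint_at, e.g. for one final checkpoint.
    saver.force_save(end_epoch, valid_stats)
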
4 changes: 2 additions & 2 deletions onmt/modules/copy_generator.py
@@ -6,7 +6,7 @@
import onmt
import onmt.inputters as inputters
from onmt.utils.misc import aeq

from onmt.utils import loss

class CopyGenerator(nn.Module):
"""Generator module that additionally considers copying
@@ -138,7 +138,7 @@ def __call__(self, scores, align, target):
return loss


class CopyGeneratorLossCompute(onmt.utils.loss.LossComputeBase):
class CopyGeneratorLossCompute(loss.LossComputeBase):
"""
Copy Generator Loss Computation.
"""
165 changes: 75 additions & 90 deletions onmt/trainer.py
@@ -17,84 +17,26 @@

import onmt.inputters as inputters
from onmt.inputters.inputter import build_dataset_iter, lazily_load_dataset, _load_fields, _collect_report_features
from onmt.utils.loss import make_loss_compute
from onmt.utils.misc import use_gpu
import onmt.utils

class Statistics(object):
"""
Accumulator for loss statistics.
Currently calculates:

* accuracy
* perplexity
* elapsed time
"""
def __init__(self, loss=0, n_words=0, n_correct=0):
self.loss = loss
self.n_words = n_words
self.n_correct = n_correct
self.n_src_words = 0
self.start_time = time.time()

def update(self, stat):
""" update statistics """
self.loss += stat.loss
self.n_words += stat.n_words
self.n_correct += stat.n_correct

def accuracy(self):
""" compute accuracy """
return 100 * (self.n_correct / self.n_words)

def xent(self):
""" compute cross entropy """
return self.loss / self.n_words

def ppl(self):
""" compute perplexity """
return math.exp(min(self.loss / self.n_words, 100))

def elapsed_time(self):
""" compute elapsed time """
return time.time() - self.start_time

def output(self, epoch, batch, n_batches, start):
"""Write out statistics to stdout.
def build_trainer(opt, model, fields, optim, data_type, model_saver=None):
train_loss = onmt.utils.loss.build_loss_compute(
model, fields["tgt"].vocab, opt)
valid_loss = onmt.utils.loss.build_loss_compute(
model, fields["tgt"].vocab, opt, train=False)

Args:
epoch (int): current epoch
batch (int): current batch
n_batch (int): total batches
start (int): start time of epoch.
"""
t = self.elapsed_time()
print(("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; xent: %6.2f; " +
"%3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed") %
(epoch, batch, n_batches,
self.accuracy(),
self.ppl(),
self.xent(),
self.n_src_words / (t + 1e-5),
self.n_words / (t + 1e-5),
time.time() - start))
sys.stdout.flush()

def log(self, prefix, experiment, learning_rate):
""" log statistics """
t = self.elapsed_time()
experiment.add_scalar_value(prefix + "_ppl", self.ppl())
experiment.add_scalar_value(prefix + "_accuracy", self.accuracy())
experiment.add_scalar_value(prefix + "_tgtper", self.n_words / t)
experiment.add_scalar_value(prefix + "_lr", learning_rate)

def log_tensorboard(self, prefix, writer, learning_rate, step):
""" display statistics to tensorboard """
t = self.elapsed_time()
writer.add_scalar(prefix + "/xent", self.xent(), step)
writer.add_scalar(prefix + "/ppl", self.ppl(), step)
writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
writer.add_scalar(prefix + "/lr", learning_rate, step)
trunc_size = opt.truncated_decoder # Badly named...
shard_size = opt.max_generator_batches
norm_method = opt.normalization
grad_accum_count = opt.accum_count

report_manager = onmt.utils.build_report_manager(opt)
trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
trunc_size, shard_size, data_type,
norm_method, grad_accum_count, report_manager,
model_saver=model_saver)
return trainer


class Trainer(object):
@@ -115,11 +57,14 @@ class Trainer(object):
data_type(string): type of the source input: [text|img|audio]
norm_method(string): normalization methods: [sents|tokens]
grad_accum_count(int): accumulate gradients this many times.
report_manager(:obj:`onmt.utils.ReportMgrBase`):
the object that creates reports, or None
"""

def __init__(self, model, train_loss, valid_loss, optim,
trunc_size=0, shard_size=32, data_type='text',
norm_method="sents", grad_accum_count=1):
norm_method="sents", grad_accum_count=1, report_manager=None,
model_saver=None):
# Basic attributes.
self.model = model
self.train_loss = train_loss
@@ -130,7 +75,8 @@ def __init__(self, model, train_loss, valid_loss, optim,
self.data_type = data_type
self.norm_method = norm_method
self.grad_accum_count = grad_accum_count
self.progress_step = 0
self.report_manager = report_manager
self.model_saver = model_saver

assert grad_accum_count > 0
if grad_accum_count > 1:
@@ -141,18 +87,46 @@ def __init__(self, model, train_loss, valid_loss, optim,
# Set model in training mode.
self.model.train()

def train(self, train_iter, epoch, report_func=None):
def train(self, train_iter_fct, valid_iter_fct, start_epoch, end_epoch):

print('\nStart training...')
print(' * number of epochs: %d, starting from Epoch %d' %
(end_epoch + 1 - start_epoch, start_epoch))
# print(' * batch size: %d' % batch_size)
for epoch in range(start_epoch, end_epoch + 1):
print('')

# 1. Train for one epoch on the training set.
train_iter = train_iter_fct()
train_stats = self.train_epoch(train_iter, epoch)
self.report_manager.report_epoch(
self.optim.lr, epoch, train_stats=train_stats)

# 2. Validate on the validation set.
valid_iter = valid_iter_fct()
valid_stats = self.validate(valid_iter)
self.report_manager.report_epoch(
self.optim.lr, epoch, valid_stats=valid_stats)

# 3. Update the learning rate
self.epoch_step(valid_stats.ppl(), epoch)

# 4. Drop a checkpoint if needed.
self.maybe_drop_checkpoint(epoch, valid_stats)

def train_epoch(self, train_iter, epoch):
""" Train next epoch.
Args:
train_iter: training data iterator
epoch(int): the epoch number
report_func(fn): function for logging
Returns:
stats (:obj:`onmt.Statistics`): epoch loss statistics
stats (:obj:`onmt.utils.Statistics`): epoch loss statistics
"""
total_stats = Statistics()
report_stats = Statistics()
total_stats = onmt.utils.Statistics()
report_stats = onmt.utils.Statistics()
self.report_manager.start_time = total_stats.start_time

idx = 0
true_batchs = []
accum = 0
@@ -184,13 +158,10 @@ def train(self, train_iter, epoch, report_func=None):
true_batchs, total_stats,
report_stats, normalization)

if report_func is not None:
report_stats = report_func(
epoch, idx, num_batches,
self.progress_step,
total_stats.start_time, self.optim._lr,
report_stats)
self.progress_step += 1
report_stats = self.report_training(
epoch, idx, num_batches,
self.optim.lr,
report_stats)

true_batchs = []
accum = 0
Expand All @@ -214,7 +185,7 @@ def validate(self, valid_iter):
# Set model in validating mode.
self.model.eval()

stats = Statistics()
stats = onmt.utils.Statistics()

for batch in valid_iter:
cur_dataset = valid_iter.get_cur_dataset()
@@ -330,3 +301,17 @@ def _gradient_accumulation(self, true_batchs, total_stats,

if self.grad_accum_count > 1:
self.optim.step()

def report_training(self, epoch, batch, num_batches, learning_rate, report_stats):
if self.report_manager is not None:
return self.report_manager.report_training(
epoch, batch, num_batches, learning_rate, report_stats)
# no report manager: hand report_stats back unchanged so the caller
# does not overwrite it with None
return report_stats

def report_epoch(self, lr, epoch, train_stats=None, valid_stats=None):
if self.report_manager is not None:
return self.report_manager.report_epoch(
lr, epoch, train_stats=train_stats, valid_stats=valid_stats)

def maybe_drop_checkpoint(self, epoch, valid_stats):
if self.model_saver is not None:
self.model_saver.maybe_save(epoch, valid_stats)
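
With the epoch loop now living inside Trainer.train, the call site that train.py is expected to keep shrinks to roughly the following sketch. It assumes only the signatures visible in this diff (build_trainer, Trainer.train, build_model_saver); the iterator-factory signatures for build_dataset_iter and lazily_load_dataset and the opt.start_epoch / opt.epochs fields are assumptions based on the existing OpenNMT-py options, not shown here.

from onmt.inputters.inputter import build_dataset_iter, lazily_load_dataset
from onmt.models import build_model_saver
from onmt.trainer import build_trainer

def train_model(opt, model_opt, model, fields, optim, data_type):
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    trainer = build_trainer(opt, model, fields, optim, data_type,
                            model_saver=model_saver)

    # Trainer.train expects factories because a fresh iterator is built
    # for every epoch (and for every validation pass).
    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields, opt)

    trainer.train(train_iter_fct, valid_iter_fct, opt.start_epoch, opt.epochs)
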
6 changes: 3 additions & 3 deletions onmt/utils/__init__.py
@@ -1,4 +1,4 @@
"""Module defining various utilities."""
from onmt.utils.misc import aeq
from onmt.utils.loss import NMTLossCompute
from onmt.utils.optimizers import Optimizer
from onmt.utils.misc import aeq, use_gpu
from onmt.utils.report_manager import ReportMgr, build_report_manager
from onmt.utils.statistics import Statistics
10 changes: 5 additions & 5 deletions onmt/utils/cnn_factory.py
@@ -4,9 +4,8 @@
import torch
import torch.nn as nn
import torch.nn.init as init
#import torch.nn.functional as F

from onmt.modules.weight_norm import WeightNormConv2d
import onmt

SCALE_WEIGHT = 0.5 ** 0.5

@@ -18,11 +17,12 @@ def shape_transform(x):

class GatedConv(nn.Module):
""" Gated convolution for CNN class """

def __init__(self, input_size, width=3, dropout=0.2, nopad=False):
super(GatedConv, self).__init__()
self.conv = WeightNormConv2d(input_size, 2 * input_size,
kernel_size=(width, 1), stride=(1, 1),
padding=(width // 2 * (1 - nopad), 0))
self.conv = onmt.modules.WeightNormConv2d(
input_size, 2 * input_size, kernel_size=(width, 1), stride=(1, 1),
padding=(width // 2 * (1 - nopad), 0))
init.xavier_uniform(self.conv.weight, gain=(4 * (1 - dropout))**0.5)
self.dropout = nn.Dropout(dropout)

5 more changed files not shown.
