
Commit

fix from reviews + logger in many places
vince62s committed Jul 3, 2018
1 parent a117618 commit c75a06a
Showing 13 changed files with 121 additions and 119 deletions.
29 changes: 14 additions & 15 deletions onmt/inputters/inputter.py
@@ -301,7 +301,8 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
# if len(src_vocab_path) > 0:
if src_vocab_path:
src_vocab = set([])
print('Loading source vocab from %s' % src_vocab_path)
if logger:
logger.info('Loading source vocab from %s' % src_vocab_path)
assert os.path.exists(src_vocab_path), \
'src vocab %s not found!' % src_vocab_path
with open(src_vocab_path) as f:
@@ -315,7 +316,8 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
# if len(tgt_vocab_path) > 0:
if tgt_vocab_path:
tgt_vocab = set([])
print('Loading target vocab from %s' % tgt_vocab_path)
if logger:
logger.info('Loading target vocab from %s' % tgt_vocab_path)
assert os.path.exists(tgt_vocab_path), \
'tgt vocab %s not found!' % tgt_vocab_path
with open(tgt_vocab_path) as f:
@@ -509,7 +511,7 @@ def batch_size_fn(new, count, sofar):
device, is_train)


def lazily_load_dataset(corpus_type, opt):
def lazily_load_dataset(corpus_type, opt, logger):
"""
Dataset generator. Don't do extra stuff here, like printing,
because they will be postponed to the first loading time.
@@ -523,8 +525,8 @@ def lazily_load_dataset(corpus_type, opt):

def _lazy_dataset_loader(pt_file, corpus_type):
dataset = torch.load(pt_file)
print('Loading %s dataset from %s, number of examples: %d' %
(corpus_type, pt_file, len(dataset)))
logger.info('Loading %s dataset from %s, number of examples: %d' %
(corpus_type, pt_file, len(dataset)))
return dataset

# Sort the glob output by file name (by increasing indexes).
@@ -538,9 +540,9 @@ def _lazy_dataset_loader(pt_file, corpus_type):
yield _lazy_dataset_loader(pt, corpus_type)


def _load_fields(dataset, data_type, opt, checkpoint):
def _load_fields(dataset, data_type, opt, checkpoint, logger):
if checkpoint is not None:
print('Loading vocab from checkpoint at %s.' % opt.train_from)
logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
fields = load_fields_from_vocab(
checkpoint['vocab'], data_type)
else:
@@ -550,11 +552,11 @@ def _load_fields(dataset, data_type, opt, checkpoint):
if k in dataset.examples[0].__dict__])

if data_type == 'text':
print(' * vocabulary size. source = %d; target = %d' %
(len(fields['src'].vocab), len(fields['tgt'].vocab)))
logger.info(' * vocabulary size. source = %d; target = %d' %
(len(fields['src'].vocab), len(fields['tgt'].vocab)))
else:
print(' * vocabulary size. target = %d' %
(len(fields['tgt'].vocab)))
logger.info(' * vocabulary size. target = %d' %
(len(fields['tgt'].vocab)))

return fields

@@ -563,7 +565,4 @@ def _collect_report_features(fields):
src_features = collect_features(fields, side='src')
tgt_features = collect_features(fields, side='tgt')

for j, feat in enumerate(src_features):
print(' * src feature %d size = %d' % (j, len(fields[feat].vocab)))
for j, feat in enumerate(tgt_features):
print(' * tgt feature %d size = %d' % (j, len(fields[feat].vocab)))
return src_features, tgt_features
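For orientation (not part of the commit), the file now mixes two logging conventions: `build_vocab` keeps an optional logger and guards each call with `if logger:`, while `lazily_load_dataset` and `_load_fields` take a required `logger` argument. A minimal sketch of the difference follows; the function and path names in it are illustrative assumptions.

```python
# Minimal sketch of the two logging conventions above; names are illustrative only.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def optional_logger_style(src_vocab_path, logger=None):
    # build_vocab-style: stays silent when no logger is passed.
    if logger:
        logger.info('Loading source vocab from %s' % src_vocab_path)

def required_logger_style(corpus_type, logger):
    # lazily_load_dataset/_load_fields-style: every caller must supply a logger.
    logger.info('Loading %s dataset' % corpus_type)

optional_logger_style('data/src.vocab')    # no output
required_logger_style('train', logger)     # always logs
```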
8 changes: 3 additions & 5 deletions onmt/model_builder.py
@@ -2,7 +2,6 @@
This file is for models creation, which consults options
and creates each encoder and decoder accordingly.
"""
from __future__ import print_function

import torch
import torch.nn as nn
@@ -237,11 +236,10 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None):
return model


def build_model(model_opt, opt, fields, checkpoint):
def build_model(model_opt, opt, fields, checkpoint, logger):
""" Build the Model """
print('Building model...')
logger.info('Building model...')
model = build_base_model(model_opt, fields,
use_gpu(opt), checkpoint)
print(model)

logger.info(model)
return model
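The only change in `build_model` is that the progress message and the model summary are logged rather than printed. Below is a hedged, self-contained sketch of that behavior; `get_logger_sketch` is a rough stand-in for `onmt.utils.misc.get_logger`, whose exact format and handlers are an assumption.

```python
# Sketch only: log a model summary instead of printing it, as build_model now does.
import logging
import torch.nn as nn

def get_logger_sketch(log_file=None):
    # Assumed stand-in for onmt.utils.misc.get_logger: console plus optional file.
    logger = logging.getLogger('onmt')
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('[%(asctime)s %(levelname)s] %(message)s'))
    logger.addHandler(handler)
    if log_file is not None:
        logger.addHandler(logging.FileHandler(log_file))
    return logger

logger = get_logger_sketch()
logger.info('Building model...')
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
logger.info(model)  # the nn.Module repr is logged as a multi-line message
```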
4 changes: 3 additions & 1 deletion onmt/opts.py
@@ -241,7 +241,7 @@ def train_opts(parser):
help="Rank the current gpu device.")
group.add_argument('-gpu_backend', default='nccl', nargs='+', type=str,
help="Type of torch distributed backend")
group.add_argument('-gpu_verbose', default=0, type=int,
group.add_argument('-gpu_verbose_level', default=0, type=int,
help="Gives more info on each process per GPU.")

group.add_argument('-seed', type=int, default=-1,
@@ -305,6 +305,8 @@ def train_opts(parser):
uses more memory.""")
group.add_argument('-train_steps', type=int, default=100000,
help='Number of training steps')
group.add_argument('-epochs', type=int, default=0,
help='Deprecated epochs see train_steps')
group.add_argument('-optim', default='sgd',
choices=['sgd', 'adagrad', 'adadelta', 'adam',
'sparseadam'],
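To illustrate the two option changes above (the rename to `-gpu_verbose_level` and the `-epochs` flag kept only as a deprecated alias for `-train_steps`), here is a hedged standalone sketch; the deprecation warning itself is an assumption, not code from this commit.

```python
# Standalone sketch mirroring the renamed and deprecated options above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-gpu_verbose_level', default=0, type=int,
                    help="Gives more info on each process per GPU.")
parser.add_argument('-train_steps', type=int, default=100000,
                    help='Number of training steps')
parser.add_argument('-epochs', type=int, default=0,
                    help='Deprecated epochs see train_steps')

opt = parser.parse_args(['-epochs', '13'])
if opt.epochs > 0:
    # Assumed handling: steer users of the old flag toward -train_steps.
    print('WARNING: -epochs is deprecated, use -train_steps instead.')
```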
10 changes: 6 additions & 4 deletions train_multi.py → onmt/train_multi.py
@@ -8,13 +8,15 @@
import torch

import onmt.opts as opts
import onmt.utils.multi_utils
from train_single import main as single_main
import onmt.utils.distributed
from onmt.utils.misc import get_logger
from onmt.train_single import main as single_main


def main(opt):
""" Spawns 1 process per GPU """
nb_gpu = len(opt.gpuid)
logger = get_logger(opt.log_file)
mp = torch.multiprocessing.get_context('spawn')

# Create a thread to listen for errors in the child processes.
@@ -30,7 +32,7 @@ def main(opt):
procs.append(mp.Process(target=run, args=(
opt, error_queue, ), daemon=True))
procs[i].start()
print(" Starting process pid: %d " % procs[i].pid)
logger.info(" Starting process pid: %d " % procs[i].pid)
error_handler.add_child(procs[i].pid)
for p in procs:
p.join()
@@ -39,7 +41,7 @@
def run(opt, error_queue):
""" run process """
try:
opt.gpu_rank = onmt.utils.multi_utils.multi_init(opt)
opt.gpu_rank = onmt.utils.distributed.multi_init(opt)
single_main(opt)
except KeyboardInterrupt:
pass # killed by parent, do nothing
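For readers unfamiliar with the `main`/`run` pattern above, this is a hedged, self-contained sketch of spawning one process per GPU with an error queue. It is a simplification, not the actual onmt.train_multi code; the worker body is a placeholder where the real code calls `single_main(opt)`.

```python
# Simplified sketch of the spawn-one-process-per-device pattern used above.
import traceback
import torch.multiprocessing

def run(device_id, error_queue):
    """ Placeholder worker; the real code runs single-GPU training here. """
    try:
        print('training on device %d' % device_id)
    except Exception:
        # Mirror the error-queue idea: ship the traceback back to the parent.
        error_queue.put((device_id, traceback.format_exc()))

def main(nb_gpu):
    mp = torch.multiprocessing.get_context('spawn')
    error_queue = mp.SimpleQueue()
    procs = []
    for device_id in range(nb_gpu):
        procs.append(mp.Process(target=run, args=(device_id, error_queue),
                                daemon=True))
        procs[-1].start()
        print(' Starting process pid: %d ' % procs[-1].pid)
    for p in procs:
        p.join()
    if not error_queue.empty():
        print(error_queue.get()[1])

if __name__ == '__main__':
    main(2)
```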
42 changes: 26 additions & 16 deletions train_single.py → onmt/train_single.py
@@ -2,7 +2,6 @@
"""
Training on a single process
"""
from __future__ import print_function
from __future__ import division

import argparse
@@ -18,6 +17,7 @@
from onmt.utils.optimizers import build_optim
from onmt.trainer import build_trainer
from onmt.models import build_model_saver
from onmt.utils.misc import get_logger


def _check_save_model_path(opt):
@@ -29,19 +29,17 @@ def _check_save_model_path(opt):

def _tally_parameters(model):
n_params = sum([p.nelement() for p in model.parameters()])
print('* number of parameters: %d' % n_params)
enc = 0
dec = 0
for name, param in model.named_parameters():
if 'encoder' in name:
enc += param.nelement()
elif 'decoder' or 'generator' in name:
dec += param.nelement()
print('encoder: ', enc)
print('decoder: ', dec)
return n_params, enc, dec


def training_opt_postprocessing(opt):
def training_opt_postprocessing(opt, logger):
if opt.word_vec_size != -1:
opt.src_word_vec_size = opt.word_vec_size
opt.tgt_word_vec_size = opt.word_vec_size
@@ -56,7 +54,7 @@
raise AssertionError("Using SRU requires -gpuid set.")

if torch.cuda.is_available() and not opt.gpuid:
print("WARNING: You have a CUDA device, should run with -gpuid 0")
logger.info("WARNING: You have a CUDA device, should run with -gpuid")

if opt.gpuid:
torch.cuda.set_device(opt.device_id)
@@ -72,11 +70,13 @@


def main(opt):
opt = training_opt_postprocessing(opt)
logger = get_logger(opt.log_file)

opt = training_opt_postprocessing(opt, logger)

# Load checkpoint if we resume from a previous training.
if opt.train_from:
print('Loading checkpoint from %s' % opt.train_from)
logger.info('Loading checkpoint from %s' % opt.train_from)
checkpoint = torch.load(opt.train_from,
map_location=lambda storage, loc: storage)
model_opt = checkpoint['opt']
@@ -86,18 +86,28 @@

# Peek the first dataset to determine the data_type.
# (All datasets have the same data_type).
first_dataset = next(lazily_load_dataset("train", opt))
first_dataset = next(lazily_load_dataset("train", opt, logger))
data_type = first_dataset.data_type

# Load fields generated from preprocess phase.
fields = _load_fields(first_dataset, data_type, opt, checkpoint)
fields = _load_fields(first_dataset, data_type, opt, checkpoint, logger)

# Report src/tgt features.
_collect_report_features(fields)

src_features, tgt_features = _collect_report_features(fields)
for j, feat in enumerate(src_features):
logger.info(' * src feature %d size = %d'
% (j, len(fields[feat].vocab)))
for j, feat in enumerate(tgt_features):
logger.info(' * tgt feature %d size = %d'
% (j, len(fields[feat].vocab)))

# Build model.
model = build_model(model_opt, opt, fields, checkpoint)
_tally_parameters(model)
model = build_model(model_opt, opt, fields, checkpoint, logger)
n_params, enc, dec = _tally_parameters(model)
logger.info('encoder: %d' % enc)
logger.info('decoder: %d' % dec)
logger.info('* number of parameters: %d' % n_params)
_check_save_model_path(opt)

# Build optimizer.
@@ -107,13 +117,13 @@
model_saver = build_model_saver(model_opt, opt, model, fields, optim)

trainer = build_trainer(
opt, model, fields, optim, data_type, model_saver=model_saver)
opt, model, fields, optim, data_type, logger, model_saver=model_saver)

def train_iter_fct(): return build_dataset_iter(
lazily_load_dataset("train", opt), fields, opt)
lazily_load_dataset("train", opt, logger), fields, opt)

def valid_iter_fct(): return build_dataset_iter(
lazily_load_dataset("valid", opt), fields, opt)
lazily_load_dataset("valid", opt, logger), fields, opt)

# Do training.
trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
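One detail worth flagging in `_tally_parameters` above: the committed condition `elif 'decoder' or 'generator' in name:` is always true, because the non-empty string `'decoder'` is truthy on its own, so every non-encoder parameter is counted as decoder. A hedged sketch of what that branch presumably intends:

```python
# Sketch of the presumably intended tally; only the elif condition differs
# from the committed code shown above.
def _tally_parameters(model):
    n_params = sum(p.nelement() for p in model.parameters())
    enc = 0
    dec = 0
    for name, param in model.named_parameters():
        if 'encoder' in name:
            enc += param.nelement()
        elif 'decoder' in name or 'generator' in name:
            dec += param.nelement()
    return n_params, enc, dec
```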