Commit: new distrib
vince62s committed Sep 17, 2018
1 parent 54d0c30 commit 9f014fe
Showing 5 changed files with 102 additions and 31 deletions.
onmt/opts.py (10 changes: 5 additions & 5 deletions)
@@ -247,11 +247,11 @@ def train_opts(parser):

     # GPU
     group.add_argument('-gpuid', default=[], nargs='+', type=int,
-                       help="Use CUDA on the listed devices.")
-    group.add_argument('-gpu_rank', default=0, nargs='+', type=int,
-                       help="Rank the current gpu device.")
-    group.add_argument('-device_id', default=0, nargs='+', type=int,
-                       help="Rank the current gpu device.")
+                       help="Deprecated see world_size and gpu_ranks.")
+    group.add_argument('-gpu_ranks', default=[], nargs='+', type=int,
+                       help="list of ranks of each process.")
+    group.add_argument('-world_size', default=0, nargs='+', type=int,
+                       help="total number of distributed processes.")
     group.add_argument('-gpu_backend', default='nccl', nargs='+', type=str,
                        help="Type of torch distributed backend")
     group.add_argument('-gpu_verbose_level', default=0, type=int,
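Not part of the commit, but as a quick illustration of the new options above: a standalone argparse sketch with the same flag definitions, parsing an invented command line. Note that, as defined here, -world_size also uses nargs='+', so it parses to a one-element list.

# Illustrative sketch only; the sample argv is an assumption, not from the commit.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('GPU')
group.add_argument('-gpu_ranks', default=[], nargs='+', type=int,
                   help="list of ranks of each process.")
group.add_argument('-world_size', default=0, nargs='+', type=int,
                   help="total number of distributed processes.")

# e.g. two processes on one machine, one per GPU
opt = parser.parse_args('-world_size 2 -gpu_ranks 0 1'.split())
print(opt.world_size, opt.gpu_ranks)  # -> [2] [0, 1]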
onmt/train_single.py (20 changes: 10 additions & 10 deletions)
@@ -39,7 +39,7 @@ def _tally_parameters(model):
     return n_params, enc, dec


-def training_opt_postprocessing(opt):
+def training_opt_postprocessing(opt, device_id):
     if opt.word_vec_size != -1:
         opt.src_word_vec_size = opt.word_vec_size
         opt.tgt_word_vec_size = opt.word_vec_size
@@ -50,11 +50,11 @@ def training_opt_postprocessing(opt):

     opt.brnn = (opt.encoder_type == "brnn")

-    if opt.rnn_type == "SRU" and not opt.gpuid:
-        raise AssertionError("Using SRU requires -gpuid set.")
+    if opt.rnn_type == "SRU" and not opt.gpu_ranks:
+        raise AssertionError("Using SRU requires -gpu_ranks set.")

-    if torch.cuda.is_available() and not opt.gpuid:
-        logger.info("WARNING: You have a CUDA device, should run with -gpuid")
+    if torch.cuda.is_available() and not opt.gpu_ranks:
+        logger.info("WARNING: You have a CUDA device, should run with -gpu_ranks")

     if opt.seed > 0:
         torch.manual_seed(opt.seed)
@@ -65,17 +65,17 @@ def training_opt_postprocessing(opt):
         # unless you tell it to be deterministic
         torch.backends.cudnn.deterministic = True

-    if opt.gpuid:
-        torch.cuda.set_device(opt.device_id)
+    if device_id >= 0:
+        torch.cuda.set_device(device_id)
         if opt.seed > 0:
             # These ensure same initialization in multi gpu mode
             torch.cuda.manual_seed(opt.seed)

     return opt


-def main(opt):
-    opt = training_opt_postprocessing(opt)
+def main(opt, device_id):
+    opt = training_opt_postprocessing(opt, device_id)
     init_logger(opt.log_file)
     # Load checkpoint if we resume from a previous training.
     if opt.train_from:
@@ -120,7 +120,7 @@ def main(opt):
     model_saver = build_model_saver(model_opt, opt, model, fields, optim)

     trainer = build_trainer(
-        opt, model, fields, optim, data_type, model_saver=model_saver)
+        opt, device_id, model, fields, optim, data_type, model_saver=model_saver)

     def train_iter_fct(): return build_dataset_iter(
         lazily_load_dataset("train", opt), fields, opt)
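A small sketch, not from the commit, of the device_id convention that training_opt_postprocessing now relies on: a non-negative device_id selects the CUDA device for the current process, while a negative id (the no-GPU case marked TODO in trainer.py below) leaves CUDA untouched. select_device is a hypothetical helper used only for illustration.

# Hypothetical helper mirroring the device/seed handling above.
import torch

def select_device(device_id, seed=0):
    if device_id >= 0 and torch.cuda.is_available():
        torch.cuda.set_device(device_id)  # pin this process to one GPU
        if seed > 0:
            torch.cuda.manual_seed(seed)  # same initialization across GPU workers
    # device_id < 0: CPU process, nothing to do

# a worker spawned for GPU 0 would call, e.g.:
# select_device(0, seed=1234)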
onmt/trainer.py (7 changes: 4 additions & 3 deletions)
@@ -17,7 +17,7 @@
 from onmt.utils.logging import logger


-def build_trainer(opt, model, fields, optim, data_type, model_saver=None):
+def build_trainer(opt, device_id, model, fields, optim, data_type, model_saver=None):
     """
     Simplify `Trainer` creation based on user `opt`s*
@@ -40,8 +40,9 @@ def build_trainer(opt, model, fields, optim, data_type, model_saver=None):
     shard_size = opt.max_generator_batches
     norm_method = opt.normalization
     grad_accum_count = opt.accum_count
-    n_gpu = len(opt.gpuid)
-    gpu_rank = opt.gpu_rank
+    n_gpu = opt.world_size
+    # TODO if no GPU device_id = -1
+    gpu_rank = opt.gpu_ranks[device_id]
     gpu_verbose_level = opt.gpu_verbose_level

     report_manager = onmt.utils.build_report_manager(opt)
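A worked example, not from the commit, of the new rank lookup above: each spawned process indexes gpu_ranks with its own device_id to recover its distributed rank (values below are invented).

# Invented values for illustration only.
from argparse import Namespace

opt = Namespace(world_size=2, gpu_ranks=[0, 1])

for device_id in range(len(opt.gpu_ranks)):
    n_gpu = opt.world_size               # total number of processes
    gpu_rank = opt.gpu_ranks[device_id]  # device 0 -> rank 0, device 1 -> rank 1
    print(device_id, n_gpu, gpu_rank)    # 0 2 0, then 1 2 1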
onmt/utils/distributed.py (20 changes: 10 additions & 10 deletions)
@@ -13,23 +13,23 @@
 from onmt.utils.logging import logger


-def is_master(opt):
-    return opt.gpu_rank == 0
+def is_master(opt, device_id):
+    return opt.gpu_ranks[device_id] == 0


-def multi_init(opt):
-    if len(opt.gpuid) == 1:
-        raise ValueError('Cannot initialize multiprocess with one gpu only')
+def multi_init(opt, device_id):
+    # if len(opt.gpuid) == 1:
+    #     raise ValueError('Cannot initialize multiprocess with one gpu only')
     dist_init_method = 'tcp://localhost:10000'
-    dist_world_size = len(opt.gpuid)
+    dist_world_size = opt.world_size
     torch.distributed.init_process_group(
         backend=opt.gpu_backend, init_method=dist_init_method,
-        world_size=dist_world_size, rank=opt.gpu_rank)
-    opt.gpu_rank = torch.distributed.get_rank()
-    if not is_master(opt):
+        world_size=dist_world_size, rank=opt.gpu_ranks[device_id])
+    gpu_rank = torch.distributed.get_rank()
+    if not is_master(opt, device_id):
         logger.disabled = True

-    return opt.gpu_rank
+    return gpu_rank


 def all_reduce_and_rescale_tensors(tensors, rescale_denom,
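Also outside the commit: a sketch of how multi_init and is_master are meant to be called, once per spawned process. The worker wrapper and the option values are assumptions for illustration; opt would normally be the parsed training options.

# Hypothetical per-process wrapper; not an OpenNMT function.
import onmt.utils.distributed as distributed

def worker(opt, device_id):
    # every process joins the same group, each with its own rank
    gpu_rank = distributed.multi_init(opt, device_id)
    if distributed.is_master(opt, device_id):
        print("device %d is the master (rank %d)" % (device_id, gpu_rank))
    # non-master processes have their logger disabled by multi_init

# e.g. with opt.world_size == 2, opt.gpu_ranks == [0, 1], opt.gpu_backend == 'nccl';
# the spawn loop in train.py drives run() the same way, one process per entry of gpu_ranks.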
train.py (76 changes: 73 additions & 3 deletions)
@@ -22,9 +22,78 @@ def main(opt):
raise AssertionError("BPTT is not compatible with -accum > 1")

if len(opt.gpuid) > 1:
multi_main(opt)
else:
single_main(opt)
raise AssertionError("gpuid is deprecated see world_size and gpu_ranks")

nb_gpu = len(opt.gpu_ranks)
mp = torch.multiprocessing.get_context('spawn')

# Create a thread to listen for errors in the child processes.
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)

# Train with multiprocessing.
procs = []
# TODO case if no gpu
for i in range(nb_gpu):
gpu_rank = opt.gpu_ranks[i]
device_id = i

procs.append(mp.Process(target=run, args=(
opt, device_id, error_queue, ), daemon=True))
procs[i].start()
logger.info(" Starting process pid: %d " % procs[i].pid)
error_handler.add_child(procs[i].pid)
for p in procs:
p.join()


def run(opt, device_id, error_queue):
""" run process """
try:
gpu_rank = onmt.utils.distributed.multi_init(opt, device_id)
single_main(opt, device_id)
except KeyboardInterrupt:
pass # killed by parent, do nothing
except Exception:
# propagate exception to parent process, keeping original traceback
import traceback
error_queue.put((gpu_rank, traceback.format_exc()))


class ErrorHandler(object):
"""A class that listens for exceptions in children processes and propagates
the tracebacks to the parent process."""

def __init__(self, error_queue):
""" init error handler """
import signal
import threading
self.error_queue = error_queue
self.children_pids = []
self.error_thread = threading.Thread(
target=self.error_listener, daemon=True)
self.error_thread.start()
signal.signal(signal.SIGUSR1, self.signal_handler)

def add_child(self, pid):
""" error handler """
self.children_pids.append(pid)

def error_listener(self):
""" error listener """
(rank, original_trace) = self.error_queue.get()
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)

def signal_handler(self, signalnum, stackframe):
""" signal handler """
for pid in self.children_pids:
os.kill(pid, signal.SIGINT) # kill children processes
(rank, original_trace) = self.error_queue.get()
msg = """\n\n-- Tracebacks above this line can probably
be ignored --\n\n"""
msg += original_trace
raise Exception(msg)


if __name__ == "__main__":
@@ -38,3 +107,4 @@ def main(opt):

     opt = parser.parse_args()
     main(opt)
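Finally, a minimal, self-contained sketch (standard library only, invented names, not the committed code) of the spawn-plus-error-queue pattern that the new main, run, and ErrorHandler implement: workers push their tracebacks into a shared queue instead of dying silently, and the parent surfaces the first failure after joining.

# Minimal illustration of the pattern; worker body and failure are simulated.
import multiprocessing as mp
import traceback

def worker(device_id, error_queue):
    try:
        if device_id == 1:
            raise RuntimeError("simulated failure on worker %d" % device_id)
        print("worker %d finished" % device_id)
    except Exception:
        # ship the traceback back to the parent instead of dying silently
        error_queue.put((device_id, traceback.format_exc()))

if __name__ == "__main__":
    ctx = mp.get_context('spawn')  # same start method train.py uses
    error_queue = ctx.SimpleQueue()
    procs = [ctx.Process(target=worker, args=(i, error_queue), daemon=True)
             for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    if not error_queue.empty():
        rank, trace = error_queue.get()
        print("worker %d failed:\n%s" % (rank, trace))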
