Commit caa6d60
[Feature] Replacing thread_wrapped_func with minimal mp.Process wrapper (dmlc#2905)

* standardizing thread_wrapped_func

* lints

* Update __init__.py
BarclayII authored May 14, 2021
1 parent a90296a commit caa6d60
Showing 23 changed files with 70 additions and 327 deletions.
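
Before this change, the benchmark harness and roughly a dozen examples each carried their own copy of thread_wrapped_func, a decorator that runs a process entry point on a fresh thread so that fork() and OpenMP can coexist (see the pytorch/pytorch#17199 comments removed below). The commit deletes those copies and centralizes the workaround behind dgl.multiprocessing, whose Process class applies the wrapper transparently. A minimal sketch of the idea follows; the decorator body matches the deleted helper, while the Process subclass is illustrative rather than the verbatim DGL source (see python/dgl/multiprocessing in the repository for the real code).

import multiprocessing as mp
import traceback
from _thread import start_new_thread
from functools import wraps

def thread_wrapped_func(func):
    """Run func on a fresh thread so OpenMP state inherited through
    fork() cannot deadlock the child process."""
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = mp.Queue()

        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                trace = traceback.format_exc()
            queue.put((res, exception, trace))

        # The real work happens on the new thread; the caller only
        # blocks on the queue and re-raises any failure with its trace.
        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result
        raise exception.__class__(trace)
    return decorated_function

class Process(mp.Process):
    """Drop-in replacement whose target goes through the wrapper
    (illustrative sketch, not the verbatim DGL implementation)."""
    def __init__(self, group=None, target=None, name=None, args=(),
                 kwargs=None, *, daemon=None):
        if target is not None:
            target = thread_wrapped_func(target)
        super().__init__(group, target, name, args, kwargs or {},
                         daemon=daemon)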
35 changes: 1 addition & 34 deletions benchmarks/benchmarks/utils.py
@@ -15,39 +15,6 @@
 
 from functools import partial, reduce, wraps
 
-import torch.multiprocessing as mp
-from _thread import start_new_thread
-import traceback
-
-
-def thread_wrapped_func(func):
-    """
-    Wraps a process entry point to make it work with OpenMP.
-    """
-
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = mp.Queue()
-
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-
-    return decorated_function
-
-
 def _download(url, path, filename):
     fn = os.path.join(path, filename)
@@ -520,7 +487,7 @@ def _wrapper(func):
         if not filter.check(func):
             # skip if not enabled
             func.benchmark_name = "skip_" + func.__name__
-        return thread_wrapped_func(func)
+        return func
     return _wrapper
 
 #####################################
14 changes: 14 additions & 0 deletions docs/source/api/python/dgl.multiprocessing.rst
@@ -0,0 +1,14 @@
+.. _apimultiprocessing:
+
+dgl.multiprocessing
+===================
+
+This is a minimal wrapper of Python's native :mod:`multiprocessing` module.
+It modifies the :class:`multiprocessing.Process` class to make forking
+work with OpenMP in the DGL core library.
+
+The API usage is exactly the same as the native module, so DGL does not provide
+additional documentation.
+
+In addition, if your backend is PyTorch, this module will also be compatible with
+:mod:`torch.multiprocessing` module.
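
As a quick illustration of the drop-in claim above, a hypothetical launcher in the style of the updated examples below; run() and the worker count are placeholders, not part of the commit.

import dgl.multiprocessing as mp

def run(proc_id, n_gpus):
    # Per-worker entry point; a real example would build the model,
    # wrap it in DistributedDataParallel, and train here.
    print('worker %d of %d started' % (proc_id, n_gpus))

if __name__ == '__main__':
    n_gpus = 4  # placeholder
    procs = []
    for proc_id in range(n_gpus):
        p = mp.Process(target=run, args=(proc_id, n_gpus))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()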
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -44,6 +44,7 @@ Welcome to Deep Graph Library Tutorials and Documentation
    api/python/dgl.ops
    api/python/dgl.optim
    api/python/dgl.sampling
+   api/python/dgl.multiprocessing
    api/python/udf
 
 .. toctree::
7 changes: 2 additions & 5 deletions examples/pytorch/GATNE-T/src/main_sparse_multi_gpus.py
@@ -9,13 +9,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel
 from tqdm.auto import tqdm
 from numpy import random
 from torch.nn.parameter import Parameter
 import dgl
 import dgl.function as fn
+import dgl.multiprocessing as mp
 
 from utils import *
 
@@ -481,10 +481,7 @@ def train_model(network_data):
     else:
         procs = []
         for proc_id in range(n_gpus):
-            p = mp.Process(
-                target=thread_wrapped_func(run),
-                args=(proc_id, n_gpus, args, devices, data),
-            )
+            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
             p.start()
             procs.append(p)
         for p in procs:
33 changes: 0 additions & 33 deletions examples/pytorch/GATNE-T/src/utils.py
@@ -12,39 +12,6 @@
 import multiprocessing
 from functools import partial, reduce, wraps
 
-import torch.multiprocessing as mp
-from _thread import start_new_thread
-import traceback
-
-
-def thread_wrapped_func(func):
-    """
-    Wraps a process entry point to make it work with OpenMP.
-    """
-
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = mp.Queue()
-
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-
-    return decorated_function
-
-
 def parse_args():
     parser = argparse.ArgumentParser()
35 changes: 3 additions & 32 deletions examples/pytorch/gcmc/train_sampling.py
@@ -14,16 +14,14 @@
 import tqdm
 import torch as th
 import torch.nn as nn
-import torch.multiprocessing as mp
 from torch.utils.data import DataLoader
-from torch.multiprocessing import Queue
 from torch.nn.parallel import DistributedDataParallel
-from _thread import start_new_thread
-from functools import wraps
 from data import MovieLens
 from model import GCMCLayer, DenseBiDecoder, BiDecoder
 from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger, to_etype_name
 import dgl
+import dgl.multiprocessing as mp
+from dgl.multiprocessing import Queue
 
 class Net(nn.Module):
     def __init__(self, args, dev_id):
@@ -136,33 +134,6 @@ def evaluate(args, dev_id, net, dataset, dataloader, segment='valid'):
     rmse = np.sqrt(rmse)
     return rmse
 
-# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
-# is necessary to make fork() and openmp work together.
-def thread_wrapped_func(func):
-    """
-    Wraps a process entry point to make it work with OpenMP.
-    """
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = Queue()
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-    return decorated_function
-
 def config():
     parser = argparse.ArgumentParser(description='GCMC')
     parser.add_argument('--seed', default=123, type=int)
@@ -409,7 +380,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
     dataset.train_dec_graph.create_formats_()
     procs = []
     for proc_id in range(n_gpus):
-        p = mp.Process(target=thread_wrapped_func(run), args=(proc_id, n_gpus, args, devices, dataset))
+        p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, dataset))
         p.start()
         procs.append(p)
     for p in procs:
4 changes: 1 addition & 3 deletions examples/pytorch/graphsage/train_cv.py
@@ -4,14 +4,12 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
 import dgl.function as fn
 import dgl.nn.pytorch as dglnn
 import time
 import argparse
 import tqdm
-from _thread import start_new_thread
-from functools import wraps
 from dgl.data import RedditDataset
 from torch.utils.data import DataLoader
 from torch.nn.parallel import DistributedDataParallel
33 changes: 1 addition & 32 deletions examples/pytorch/graphsage/train_cv_multi_gpu.py
@@ -4,16 +4,14 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
 import dgl.function as fn
 import dgl.nn.pytorch as dglnn
 import time
 import argparse
 import tqdm
 import traceback
 import math
-from _thread import start_new_thread
-from functools import wraps
 from dgl.data import RedditDataset
 from torch.utils.data import DataLoader
 from torch.nn.parallel import DistributedDataParallel
@@ -148,34 +146,6 @@ def sample_blocks(self, seeds):
             hist_blocks.insert(0, hist_block)
         return blocks, hist_blocks
 
-# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
-# is necessary to make fork() and openmp work together.
-#
-# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need
-# to standardize worker process creation since our operators are implemented with
-# OpenMP.
-def thread_wrapped_func(func):
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = mp.Queue()
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-    return decorated_function
-
 def compute_acc(pred, labels):
     """
     Compute the accuracy of prediction given the labels.
@@ -245,7 +215,6 @@ def update_history(g, blocks):
         h_new = block.dstdata['h_new'].cpu()
         g.ndata[hist_col][ids] = h_new
 
-@thread_wrapped_func
 def run(proc_id, n_gpus, args, devices, data):
     dropout = 0.2
 
6 changes: 2 additions & 4 deletions examples/pytorch/graphsage/train_sampling_multi_gpu.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
 import dgl.nn.pytorch as dglnn
 import time
 import math
@@ -13,7 +13,6 @@
 import tqdm
 
 from model import SAGE
-from utils import thread_wrapped_func
 from load_graph import load_reddit, inductive_split
 
 def compute_acc(pred, labels):
@@ -217,8 +216,7 @@ def run(proc_id, n_gpus, args, devices, data):
     else:
         procs = []
         for proc_id in range(n_gpus):
-            p = mp.Process(target=thread_wrapped_func(run),
-                           args=(proc_id, n_gpus, args, devices, data))
+            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
             p.start()
             procs.append(p)
         for p in procs:
6 changes: 2 additions & 4 deletions examples/pytorch/graphsage/train_sampling_unsupervised.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
 import dgl.function as fn
 import dgl.nn.pytorch as dglnn
 import time
@@ -13,7 +13,6 @@
 from torch.nn.parallel import DistributedDataParallel
 import tqdm
 
-from utils import thread_wrapped_func
 from model import SAGE, compute_acc_unsupervised as compute_acc
 from negative_sampler import NegativeSampler
 
@@ -191,8 +190,7 @@ def main(args, devices):
     else:
         procs = []
         for proc_id in range(n_gpus):
-            p = mp.Process(target=thread_wrapped_func(run),
-                           args=(proc_id, n_gpus, args, devices, data))
+            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
             p.start()
             procs.append(p)
         for p in procs:
2 changes: 0 additions & 2 deletions examples/pytorch/ogb/cluster-sage/main.py
@@ -10,8 +10,6 @@
 import dgl.nn.pytorch as dglnn
 import time
 import argparse
-from _thread import start_new_thread
-from functools import wraps
 from dgl.data import RedditDataset
 import tqdm
 import traceback
5 changes: 2 additions & 3 deletions examples/pytorch/ogb/deepwalk/deepwalk.py
@@ -1,7 +1,7 @@
 import torch
 import argparse
 import dgl
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
 from torch.utils.data import DataLoader
 import os
 import random
@@ -10,7 +10,7 @@
 
 from reading_data import DeepwalkDataset
 from model import SkipGramModel
-from utils import thread_wrapped_func, shuffle_walks, sum_up_params
+from utils import shuffle_walks, sum_up_params
 
 class DeepwalkTrainer:
     def __init__(self, args):
@@ -110,7 +110,6 @@ def fast_train_mp(self):
         else:
             self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)
 
-    @thread_wrapped_func
     def fast_train_sp(self, rank, gpu_id):
         """ a subprocess for fast_train_mp """
         if self.args.mix:
6 changes: 2 additions & 4 deletions examples/pytorch/ogb/deepwalk/model.py
@@ -4,10 +4,9 @@
 from torch.nn import init
 import random
 import numpy as np
-import torch.multiprocessing as mp
-from torch.multiprocessing import Queue
+import dgl.multiprocessing as mp
+from dgl.multiprocessing import Queue
 
-from utils import thread_wrapped_func
 
 def init_emb2pos_index(walk_length, window_size, batch_size):
     ''' select embedding of positive nodes from a batch of node embeddings
@@ -110,7 +109,6 @@ def adam(grad, state_sum, nodes, lr, device, only_gpu):
 
     return grad
 
-@thread_wrapped_func
 def async_update(num_threads, model, queue):
     """ asynchronous embedding update """
     torch.set_num_threads(num_threads)