add c10d dynamic loading mechanism and unit test (pytorch#28068)
Summary:
Previously, PyTorch c10d supported only the built-in backends (NCCL, Gloo, and MPI). This patch
extends c10d so that third-party communication libraries deriving from the ProcessGroup base
class can be loaded dynamically.

The related RFC is pytorch#27955.

With this change, a user only needs to specify a third-party c10d backend name when invoking
torch.distributed.init_process_group(); the corresponding c10d backend C++ extension is loaded
automatically. For how to develop a new third-party c10d backend as a C++ extension, see
test/cpp_extensions/cpp_c10d_extension.cpp.
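
As a quick illustration of the flow this enables, here is a minimal sketch modeled on the test
added in this PR. It assumes a PyTorch source checkout (so the example source path resolves),
an available Ninja build, and a single process; the master address and port are arbitrary
placeholders:

    import os

    import torch.distributed as dist
    import torch.utils.cpp_extension

    # Building the example extension also loads it; at load time the extension
    # registers the "test" backend via its constructor hook (see
    # cpp_c10d_extension.hpp below).
    torch.utils.cpp_extension.load(
        name="torch_test",
        sources=["test/cpp_extensions/cpp_c10d_extension.cpp"],
    )

    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # any free port

    # init_process_group() takes the registered backend name, not the module name.
    dist.init_process_group(backend="test", init_method="env://", world_size=1, rank=0)
    print(dist.get_rank(), dist.get_world_size())
    dist.destroy_process_group()

A multi-process run follows the same pattern, with each process passing its own rank and a
shared world_size.
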
Pull Request resolved: pytorch#28068

Differential Revision: D19174838

Pulled By: agolynski

fbshipit-source-id: 3409a504a43ce7260e6f9d1207c00e87471fac62
ftian1 authored and facebook-github-bot committed Apr 2, 2020
1 parent 2a4ca70 commit 762270c
Showing 8 changed files with 361 additions and 4 deletions.
22 changes: 21 additions & 1 deletion docs/source/distributed.rst
@@ -10,7 +10,7 @@ Distributed communication package - torch.distributed
Backends
--------

-``torch.distributed`` supports three backends, each with
+``torch.distributed`` supports three built-in backends, each with
different capabilities. The table below shows which functions are available
for use with CPU / CUDA tensors.
MPI supports CUDA only if the implementation used to build PyTorch supports it.
@@ -408,6 +408,26 @@ of 16

.. _distributed-launch:

Third-party backends
--------------------

Besides the built-in GLOO/MPI/NCCL backends, PyTorch distributed supports third-party backends
through a run-time registration mechanism.
For a reference on how to develop a third-party backend as a C++ extension,
please refer to `Tutorials - Custom C++ and CUDA Extensions <https://pytorch.org/tutorials/advanced/cpp_extension.html>`_
and `test/cpp_extensions/cpp_c10d_extension.cpp`.
The capabilities of third-party backends are determined by their own implementations.

The new backend derives from `c10d.ProcessGroup` and registers the backend name and the
instantiating interface through :func:`torch.distributed.Backend.register_backend` when
imported.

After manually importing the backend module and invoking :func:`torch.distributed.init_process_group`
with the corresponding backend name, the ``torch.distributed`` package runs on the new backend.
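
For illustration, the registration as seen from Python might look like the following sketch.
Here ``torch_test`` is an assumed name for a built extension module that exposes the factory
shown in `test/cpp_extensions/cpp_c10d_extension.cpp`; that bundled extension already registers
itself at load time through its constructor hook, so the explicit call below only illustrates
the interface::

    import torch.distributed as dist
    import torch_test  # assumed name for the built C++ extension module

    # Register the name users will pass to init_process_group(). The creation
    # function takes the store, rank, world size, and timeout, as
    # ProcessGroupTest::createProcessGroupTest does.
    dist.Backend.register_backend("test", torch_test.createProcessGroupTest)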

.. warning::
    Third-party backend support is experimental and subject to change.

Launch utility
--------------

1 change: 1 addition & 0 deletions setup.py
@@ -801,6 +801,7 @@ def print_box(msg):
'include/c10/cuda/impl/*.h',
'include/c10/hip/*.h',
'include/c10/hip/impl/*.h',
'include/c10d/*.hpp',
'include/caffe2/**/*.h',
'include/torch/*.h',
'include/torch/csrc/*.h',
122 changes: 122 additions & 0 deletions test/cpp_extensions/cpp_c10d_extension.cpp
@@ -0,0 +1,122 @@
#include "cpp_c10d_extension.hpp"

#include <map>

namespace c10d {

ProcessGroupTest::WorkTest::~WorkTest() {}

bool ProcessGroupTest::WorkTest::isCompleted() {
return true;
}

bool ProcessGroupTest::WorkTest::isSuccess() const {
return true;
}

bool ProcessGroupTest::WorkTest::wait() {
return true;
}

ProcessGroupTest::ProcessGroupTest(int rank, int size)
: ProcessGroup(rank, size) {}

ProcessGroupTest::~ProcessGroupTest() {}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts) {
return std::make_shared<ProcessGroupTest::WorkTest>();
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
return std::make_shared<ProcessGroupTest::WorkTest>();
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support allreduce_coalesced");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support reduce");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support allgather");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support allgather_base");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::barrier(
const BarrierOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support barrier");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support gather");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support scatter");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support reduce_scatter");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) {
throw std::runtime_error("ProcessGroupTest does not support send");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) {
throw std::runtime_error("ProcessGroupTest does not support recv");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::recvAnysource(
std::vector<at::Tensor>& tensor,
int tag) {
throw std::runtime_error("ProcessGroupTest does not support recvAnysource");
}

std::shared_ptr<ProcessGroup> ProcessGroupTest::createProcessGroupTest(
const std::shared_ptr<::c10d::Store>& store,
int rank,
int size,
const std::chrono::duration<float>& timeout) {
return std::make_shared<ProcessGroupTest>(rank, size);
}

// Expose the factory to Python. Backend registration itself happens in
// ProcessGroupTestConstructor() (see cpp_c10d_extension.hpp), which runs when
// this extension's shared library is loaded.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("createProcessGroupTest", &ProcessGroupTest::createProcessGroupTest);
}

} // namespace c10d
121 changes: 121 additions & 0 deletions test/cpp_extensions/cpp_c10d_extension.hpp
@@ -0,0 +1,121 @@
#pragma once

#include <torch/extension.h>

#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <pybind11/chrono.h>

#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>

namespace c10d {

//
// ProcessGroupTest implements dummy bindings for c10d.
//

class ProcessGroupTest : public ProcessGroup {
public:
class WorkTest : public ProcessGroup::Work {
public:
WorkTest() {}

virtual ~WorkTest();
bool isCompleted() override;
bool isSuccess() const override;
bool wait() override;

protected:
friend class ProcessGroupTest;
};

explicit ProcessGroupTest(int rank = -1, int size = -1);
virtual ~ProcessGroupTest();

std::shared_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& data,
const BroadcastOptions& opts = BroadcastOptions()) override;

std::shared_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;

std::shared_ptr<ProcessGroup::Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override;

std::shared_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;

std::shared_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;

std::shared_ptr<ProcessGroup::Work> allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;

std::shared_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;

std::shared_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;

std::shared_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;

std::shared_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

std::shared_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag);

std::shared_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag);

std::shared_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensor,
int tag);

// Create a new ProcessGroupTest instance
static std::shared_ptr<ProcessGroup> createProcessGroupTest(
const std::shared_ptr<::c10d::Store>& store,
int rank,
int size,
const std::chrono::duration<float>& timeout);

static void ProcessGroupTestConstructor() __attribute__((constructor)) {
py::object module = py::module::import("torch.distributed");
py::object register_backend = module.attr("Backend").attr("register_backend");
// The first parameter is the backend name the user passes when invoking
// torch.distributed.init_process_group().
// Note that it may differ from the module name; for example, the module
// name is "torch_test" while the backend name is "test".
// The second parameter is the instantiation function.
register_backend("test", py::cpp_function(createProcessGroupTest));
}

};

} // namespace c10d
63 changes: 62 additions & 1 deletion test/distributed/test_distributed.py
@@ -17,7 +17,8 @@
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, find_free_port
+from torch.distributed.distributed_c10d import _get_default_group
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
from torch.testing._internal.common_distributed import simple_sparse_reduce_tests, skip_if_rocm
@@ -31,6 +32,12 @@

skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

CPP_EXTENSIONS_WARNING = """
Ninja (https://ninja-build.org) must be available to run C++ extensions tests,
but it could not be found. Install ninja with `pip install ninja`
or `conda install ninja`.
"""

BACKEND = os.environ["BACKEND"]
TEMP_DIR = os.environ["TEMP_DIR"]
INIT_METHOD = os.getenv("INIT_METHOD", "env://")
@@ -150,6 +157,21 @@ def wrapper(*args, **kwargs):
return wrapper


def skip_if_no_ninja(func):

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            import torch.utils.cpp_extension
            torch.utils.cpp_extension.verify_ninja_availability()
        except RuntimeError:
            print(CPP_EXTENSIONS_WARNING)
            return 0

        return func(*args, **kwargs)

    return wrapper

def require_backend(backends):
    if BACKEND not in backends:
        return unittest.skip("Test requires backend to be one of %s" % backends)
@@ -2272,6 +2294,45 @@ def _join_and_reduce(self, fn):
    class TestMPI(TestCase, _DistTestBase):
        pass

elif BACKEND == "test":
    class TestBackendDynamicLoad(TestCase):
        def setUp(self):
            super(TestBackendDynamicLoad, self).setUp()

        def _load_test_backend(self):
            temp_dir = tempfile.mkdtemp()
            src = "{}/../cpp_extensions/cpp_c10d_extension.cpp".format(os.path.abspath(os.path.dirname(__file__)))
            # Building the extension also loads it; loading registers the "test"
            # backend through the constructor hook in cpp_c10d_extension.hpp.
            extension = torch.utils.cpp_extension.load(
                name="torch_test",
                sources=[src],
                build_directory=temp_dir
            )

        @skip_if_no_ninja
        def test_backend_apis(self):
            self._load_test_backend()

            os.environ['WORLD_SIZE'] = '1'
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = str(find_free_port())
            os.environ['RANK'] = '0'

            dist.init_process_group(backend='test', init_method='env://', world_size=1, rank=0)
            self.assertEqual(dist.get_rank(), 0)
            self.assertEqual(dist.get_world_size(), 1)

            process_group = _get_default_group()
            work = process_group.allreduce([torch.rand(1), torch.rand(1)])
            self.assertTrue(work.wait())
            self.assertTrue(work.is_completed())
            self.assertTrue(work.is_success())

            work = process_group.broadcast([torch.rand(1)])
            self.assertTrue(work.wait())
            self.assertTrue(work.is_completed())
            self.assertTrue(work.is_success())

            dist.destroy_process_group()

if __name__ == "__main__":
assert (
3 changes: 3 additions & 0 deletions test/run_test.py
@@ -162,6 +162,9 @@


if dist.is_available():
    DISTRIBUTED_TESTS_CONFIG['test'] = {
        'WORLD_SIZE': '1'
    }
    if not TEST_WITH_ROCM and dist.is_mpi_available():
        DISTRIBUTED_TESTS_CONFIG['mpi'] = {
            'WORLD_SIZE': '3',
1 change: 1 addition & 0 deletions test/test_determination.py
@@ -92,6 +92,7 @@ def test_torch_file(self):
self.assertEqual(
self.determined_tests(["torch/utils/cpp_extension.py"]),
[
"distributed/test_distributed",
"test_cpp_extensions_aot_ninja",
"test_cpp_extensions_aot_no_ninja",
"test_determination",