forked from pytorch/pytorch
add c10d dynamic loading mechanism and unit test (pytorch#28068)
Summary: The original c10d in PyTorch only supports the built-in backends nccl/gloo/mpi. This patch extends c10d so it can dynamically load third-party communication libraries that derive from the ProcessGroup base class. The related RFC is pytorch#27955. With this change, a user only needs to specify the third-party c10d backend name when invoking torch.distributed.init_process_group(), and the proposed logic will try to load the corresponding c10d backend C++ extension automatically. For how to develop a new third-party c10d backend as a C++ extension, please refer to test/cpp_extensions/cpp_c10d_extension.cpp.

Pull Request resolved: pytorch#28068
Differential Revision: D19174838
Pulled By: agolynski
fbshipit-source-id: 3409a504a43ce7260e6f9d1207c00e87471fac62
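For context, the user-facing flow this enables looks like the following minimal Python sketch. The rendezvous settings are illustrative, and the "test" backend name comes from the test extension shown below:

import os
import torch.distributed as dist

# Illustrative single-process rendezvous for the default env:// init method.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

# "test" is not a built-in backend (nccl/gloo/mpi); it must first be
# registered by importing a C++ extension that calls
# Backend.register_backend("test", ...), as in the files below.
dist.init_process_group(backend="test", rank=0, world_size=1)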
1 parent 2a4ca70 · commit 762270c · showing 8 changed files with 361 additions and 4 deletions
test/cpp_extensions/cpp_c10d_extension.cpp (new file)
@@ -0,0 +1,122 @@
#include "cpp_c10d_extension.hpp" | ||
|
||
#include <map> | ||
|
||
namespace c10d { | ||
|
||
ProcessGroupTest::WorkTest::~WorkTest() {} | ||
|
||
bool ProcessGroupTest::WorkTest::isCompleted() { | ||
return true; | ||
} | ||
|
||
bool ProcessGroupTest::WorkTest::isSuccess() const { | ||
return true; | ||
} | ||
|
||
bool ProcessGroupTest::WorkTest::wait() { | ||
return true; | ||
} | ||
|
||
ProcessGroupTest::ProcessGroupTest(int rank, int size) | ||
: ProcessGroup(rank, size) {} | ||
|
||
ProcessGroupTest::~ProcessGroupTest() {} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::broadcast( | ||
std::vector<at::Tensor>& tensors, | ||
const BroadcastOptions& opts) { | ||
return std::make_shared<ProcessGroupTest::WorkTest>(); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allreduce( | ||
std::vector<at::Tensor>& tensors, | ||
const AllreduceOptions& opts) { | ||
return std::make_shared<ProcessGroupTest::WorkTest>(); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allreduce_coalesced( | ||
std::vector<at::Tensor>& tensors, | ||
const AllreduceCoalescedOptions& opts) { | ||
throw std::runtime_error("ProcessGroupTest does not support allreduce_coalesced"); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::reduce( | ||
std::vector<at::Tensor>& tensors, | ||
const ReduceOptions& opts) { | ||
throw std::runtime_error("ProcessGroupTest does not support reduce"); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allgather( | ||
std::vector<std::vector<at::Tensor>>& outputTensors, | ||
std::vector<at::Tensor>& inputTensors, | ||
const AllgatherOptions& opts) { | ||
throw std::runtime_error("ProcessGroupTest does not support allgather"); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allgather_base( | ||
at::Tensor& outputBuffer, | ||
at::Tensor& inputBuffer, | ||
const AllgatherOptions& opts) { | ||
throw std::runtime_error("ProcessGroupTest does not support allgather_base"); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::barrier( | ||
const BarrierOptions& opts) { | ||
throw std::runtime_error("ProcessGroupTest does not support barrier"); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::gather( | ||
std::vector<std::vector<at::Tensor>>& outputTensors, | ||
std::vector<at::Tensor>& inputTensors, | ||
const GatherOptions& opts) { | ||
throw std::runtime_error("ProcessGroupTest does not support gather"); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::scatter( | ||
std::vector<at::Tensor>& outputTensors, | ||
std::vector<std::vector<at::Tensor>>& inputTensors, | ||
const ScatterOptions& opts) { | ||
throw std::runtime_error("ProcessGroupTest does not support scatter"); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::reduce_scatter( | ||
std::vector<at::Tensor>& outputTensors, | ||
std::vector<std::vector<at::Tensor>>& inputTensors, | ||
const ReduceScatterOptions& opts) { | ||
throw std::runtime_error("ProcessGroupTest does not support reduce_scatter"); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::send( | ||
std::vector<at::Tensor>& tensors, | ||
int dstRank, | ||
int tag) { | ||
throw std::runtime_error("ProcessGroupTest does not support send"); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::recv( | ||
std::vector<at::Tensor>& tensors, | ||
int srcRank, | ||
int tag) { | ||
throw std::runtime_error("ProcessGroupTest does not support recv"); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::recvAnysource( | ||
std::vector<at::Tensor>& tensor, | ||
int tag) { | ||
throw std::runtime_error("ProcessGroupTest does not support recvAnysource"); | ||
} | ||
|
||
std::shared_ptr<ProcessGroup> ProcessGroupTest::createProcessGroupTest( | ||
const std::shared_ptr<::c10d::Store>& store, | ||
int rank, | ||
int size, | ||
const std::chrono::duration<float>& timeout) { | ||
return std::make_shared<ProcessGroupTest>(rank, size); | ||
} | ||
|
||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
m.def("createProcessGroupTest", &ProcessGroupTest::createProcessGroupTest); | ||
} | ||
|
||
} // namespace c10d |
test/cpp_extensions/cpp_c10d_extension.hpp (new file)
@@ -0,0 +1,121 @@
#pragma once

#include <torch/extension.h>

#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <pybind11/chrono.h>

#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>

namespace c10d {

//
// ProcessGroupTest implements dummy bindings for c10d.
//

class ProcessGroupTest : public ProcessGroup {
 public:
  class WorkTest : public ProcessGroup::Work {
   public:
    WorkTest() {}

    virtual ~WorkTest();
    bool isCompleted() override;
    bool isSuccess() const override;
    bool wait() override;

   protected:
    friend class ProcessGroupTest;
  };

  explicit ProcessGroupTest(int rank = -1, int size = -1);
  virtual ~ProcessGroupTest();

  std::shared_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  std::shared_ptr<ProcessGroup::Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  std::shared_ptr<ProcessGroup::Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override;

  std::shared_ptr<ProcessGroup::Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  std::shared_ptr<ProcessGroup::Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  std::shared_ptr<ProcessGroup::Work> allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  std::shared_ptr<ProcessGroup::Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  std::shared_ptr<ProcessGroup::Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  std::shared_ptr<ProcessGroup::Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  std::shared_ptr<ProcessGroup::Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  std::shared_ptr<ProcessGroup::Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag);

  std::shared_ptr<ProcessGroup::Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag);

  std::shared_ptr<ProcessGroup::Work> recvAnysource(
      std::vector<at::Tensor>& tensor,
      int tag);

  // Create a new ProcessGroupTest instance
  static std::shared_ptr<ProcessGroup> createProcessGroupTest(
      const std::shared_ptr<::c10d::Store>& store,
      int rank,
      int size,
      const std::chrono::duration<float>& timeout);

  // Runs when the extension's shared library is loaded, registering the
  // backend with torch.distributed (see the Python sketch after this file).
  static void ProcessGroupTestConstructor() __attribute__((constructor)) {
    py::object module = py::module::import("torch.distributed");
    py::object register_backend = module.attr("Backend").attr("register_backend");
    // The first parameter is the backend name the user passes to
    // torch.distributed.init_process_group().
    // Note it can differ from the module name. For example, the module
    // name is "torch_test" but the backend name is "test".
    // The second parameter is the instantiation function.
    register_backend("test", py::cpp_function(createProcessGroupTest));
  }

};

} // namespace c10d
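Because the registration above runs from a shared-library constructor, the backend becomes available as soon as the extension module is imported. A minimal sketch of the test-side flow, assuming PyTorch's JIT C++-extension loader and an illustrative source path:

import torch
import torch.distributed as dist
import torch.utils.cpp_extension

# Building and loading the extension in-process runs the
# __attribute__((constructor)) hook above, which registers the "test"
# backend via Backend.register_backend.
torch.utils.cpp_extension.load(
    name="torch_test",  # module name; note it differs from the backend name
    sources=["test/cpp_extensions/cpp_c10d_extension.cpp"],
)

# (rendezvous env vars as in the earlier sketch)
dist.init_process_group(backend="test", rank=0, world_size=1)
dist.broadcast(torch.zeros(2), src=0)  # dispatches to ProcessGroupTest::broadcast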