Customize Process Group Backends Using Cpp Extensions
=====================================================

**Author**: `Feng Tian <https://github.com/ftian1>`__, `Shen Li <https://mrshenli.github.io/>`__


Prerequisites:

- `PyTorch Distributed Overview <../beginner/dist_overview.html>`__
- `PyTorch Collective Communication Package <https://pytorch.org/docs/stable/distributed.html>`__
- `Writing Distributed Applications with PyTorch <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__

This tutorial demonstrates how to implement a custom ``ProcessGroup``
backend and plug it into the
`PyTorch distributed package <https://pytorch.org/docs/stable/distributed.html>`__ using
`cpp extensions <https://pytorch.org/docs/stable/cpp_extension.html>`__. This is helpful when you need a specialized software
stack for your hardware, or when you would like to experiment with new
collective communication algorithms.


Basics
------

PyTorch collective communications power several widely adopted distributed
training features, including
`DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`__,
`ZeroRedundancyOptimizer <https://pytorch.org/docs/stable/distributed.optim.html#torch.distributed.optim.ZeroRedundancyOptimizer>`__,
and `FullyShardedDataParallel <https://github.com/pytorch/pytorch/blob/master/torch/distributed/_fsdp/fully_sharded_data_parallel.py>`__.
In order to make the same collective communication API work with
different communication backends, the distributed package abstracts collective
communication operations into a
`ProcessGroup <https://github.com/pytorch/pytorch/blob/release/1.10/torch/csrc/distributed/c10d/ProcessGroup.hpp>`__
class. Different backends can
then be implemented as subclasses of ``ProcessGroup`` using preferred
third-party libraries. PyTorch distributed comes with three default backends,
``ProcessGroupNCCL``, ``ProcessGroupGloo``, and ``ProcessGroupMPI``. However,
beyond these three backends, there are also other communication libraries
(e.g., `UCC <https://github.com/openucx/ucc>`__,
`OneCCL <https://github.com/oneapi-src/oneCCL>`__), different types of hardware
(e.g., `TPU <https://cloud.google.com/tpu>`__,
`Trainium <https://aws.amazon.com/machine-learning/trainium/>`__), and emerging
communication algorithms (e.g.,
`Herring <https://www.amazon.science/publications/herring-rethinking-the-parameter-server-at-scale-for-the-cloud>`__,
`Reduction Server <https://cloud.google.com/blog/topics/developers-practitioners/optimize-training-performance-reduction-server-vertex-ai>`__).
Therefore, the distributed package exposes extension APIs to allow customizing
collective communication backends.
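
For example, the minimal single-process snippet below (using the built-in
``gloo`` backend; the port number is an arbitrary choice) exercises the same
``all_reduce`` API that any backend, built-in or custom, must implement:

.. code-block:: python

    import torch
    import torch.distributed as dist

    # single-process group, initialized over an arbitrary local TCP port
    dist.init_process_group(
        "gloo", init_method="tcp://localhost:29500", rank=0, world_size=1
    )

    t = torch.ones(4)
    dist.all_reduce(t)  # the identical call works under "nccl" or "mpi"
    print(t)

    dist.destroy_process_group()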

The four steps below show how to implement a dummy collective communication
backend and use it in Python application code. Please note that this tutorial
focuses on demonstrating the extension APIs, instead of developing a functioning
communication backend. Hence, the ``dummy`` backend just covers a subset of the
APIs (``all_reduce`` and ``all_gather``), and simply sets the values of tensors
to 0.


Step 1: Implement a Subclass of ``ProcessGroup``
------------------------------------------------

The first step is to implement a ``ProcessGroup`` subclass that overrides
target collective communication APIs and runs the custom communication
algorithm. The extension also needs to implement a ``ProcessGroup::Work``
subclass, which serves as a future of communication results and allows
asynchronous execution in application code. If the extension uses third-party
libraries, it can include the headers and call into the library APIs from the
``ProcessGroupDummy`` subclass. The two code blocks below show the
implementation of ``dummy.hpp`` and ``dummy.cpp``. See the
`dummy collectives <https://github.com/mrshenli/dummy_collectives>`__
repository for more details.

.. code-block:: cpp

    // file name: dummy.hpp
    #include <torch/python.h>

    #include <c10d/ProcessGroup.hpp>
    #include <c10d/Store.hpp>
    #include <c10d/Types.hpp>
    #include <c10d/Utils.hpp>

    #include <pybind11/chrono.h>

    namespace c10d {

    class ProcessGroupDummy : public ProcessGroup {
      public:
        class WorkDummy : public ProcessGroup::Work {
          public:
            WorkDummy(
                OpType opType,
                c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
                : ProcessGroup::Work(
                      -1, // rank, only used by recvAnySource, irrelevant in this demo
                      opType),
                  future_(std::move(future)) {}

            bool isCompleted() override;
            bool isSuccess() const override;
            bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
            c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

          private:
            c10::intrusive_ptr<c10::ivalue::Future> future_;
        };

        ProcessGroupDummy(int rank, int size);

        c10::intrusive_ptr<ProcessGroup::Work> allgather(
            std::vector<std::vector<at::Tensor>>& outputTensors,
            std::vector<at::Tensor>& inputTensors,
            const AllgatherOptions& opts = AllgatherOptions()) override;

        c10::intrusive_ptr<ProcessGroup::Work> allreduce(
            std::vector<at::Tensor>& tensors,
            const AllreduceOptions& opts = AllreduceOptions()) override;

        // The collective communication APIs without a custom implementation
        // will error out if invoked by application code.
    };

    } // namespace c10d

.. code-block:: cpp

    // file name: dummy.cpp
    #include "dummy.hpp"

    namespace c10d {

    // This is a dummy allgather that sets all output tensors to zero
    // Modify the implementation to conduct real communication asynchronously
    c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupDummy::allgather(
            std::vector<std::vector<at::Tensor>>& outputTensors,
            std::vector<at::Tensor>& inputTensors,
            const AllgatherOptions& /* unused */) {
        for (auto& outputTensorVec : outputTensors) {
            for (auto& outputTensor : outputTensorVec) {
                outputTensor.zero_();
            }
        }

        auto future = c10::make_intrusive<c10::ivalue::Future>(
            c10::ListType::create(c10::ListType::create(c10::TensorType::get())));
        future->markCompleted(c10::IValue(outputTensors));
        return c10::make_intrusive<WorkDummy>(OpType::ALLGATHER, std::move(future));
    }

    // This is a dummy allreduce that sets all tensors to zero
    // Modify the implementation to conduct real communication asynchronously
    c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupDummy::allreduce(
            std::vector<at::Tensor>& tensors,
            const AllreduceOptions& /* unused */) {
        for (auto& tensor : tensors) {
            tensor.zero_();
        }

        auto future = c10::make_intrusive<c10::ivalue::Future>(
            c10::ListType::create(c10::TensorType::get()));
        future->markCompleted(c10::IValue(tensors));
        return c10::make_intrusive<WorkDummy>(OpType::ALLREDUCE, std::move(future));
    }

    } // namespace c10d
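
The ``WorkDummy`` methods declared in ``dummy.hpp`` need definitions as well.
Because the future handed to ``WorkDummy`` is already marked completed, a
minimal sketch can report success immediately; a real backend would instead
track the state of the in-flight operation here:

.. code-block:: cpp

    // file name: dummy.cpp (continued, inside namespace c10d)

    // The dummy future is completed before WorkDummy is constructed,
    // so these methods can return immediately.
    bool ProcessGroupDummy::WorkDummy::isCompleted() {
        return true;
    }

    bool ProcessGroupDummy::WorkDummy::isSuccess() const {
        return true;
    }

    // Returning true indicates the work finished within the timeout
    bool ProcessGroupDummy::WorkDummy::wait(std::chrono::milliseconds /* unused */) {
        return true;
    }

    c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupDummy::WorkDummy::getFuture() {
        return future_;
    }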
|
Step 2: Expose The Extension Python APIs
----------------------------------------

The backend constructors are called
`from the Python side <https://github.com/pytorch/pytorch/blob/v1.9.0/torch/distributed/distributed_c10d.py#L643-L650>`__,
so the extension also needs to expose the constructor APIs to Python. This can
be done by adding the following methods. In this example, ``store`` and
``timeout`` are not passed to the ``ProcessGroupDummy`` instance, as those are
not needed in this dummy implementation. However, real extensions should
consider supporting the ``timeout`` argument, as sketched after the two code
blocks below.

.. code-block:: cpp

    class ProcessGroupDummy : public ProcessGroup {
        static c10::intrusive_ptr<ProcessGroup> createProcessGroupDummy(
            const c10::intrusive_ptr<::c10d::Store>& store,
            int rank,
            int size,
            const std::chrono::duration<float>& timeout);

        static void ProcessGroupDummyConstructor() __attribute__((constructor)) {
            py::object module = py::module::import("torch.distributed");
            py::object register_backend =
                module.attr("Backend").attr("register_backend");
            register_backend("dummy", py::cpp_function(createProcessGroupDummy));
        }
    };

.. code-block:: cpp

    c10::intrusive_ptr<ProcessGroup> ProcessGroupDummy::createProcessGroupDummy(
            const c10::intrusive_ptr<::c10d::Store>& /* unused */,
            int rank,
            int size,
            const std::chrono::duration<float>& /* unused */) {
        return c10::make_intrusive<ProcessGroupDummy>(rank, size);
    }

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        m.def("createProcessGroupDummy", &ProcessGroupDummy::createProcessGroupDummy);
    }
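
As a rough illustration of the ``timeout`` note above, a real backend's factory
would typically forward the ``store`` and ``timeout`` to its process group
instead of discarding them. The ``ProcessGroupReal`` class and its
four-argument constructor below are assumptions for illustration only, not part
of the dummy backend:

.. code-block:: cpp

    // Hypothetical sketch: keep the store for rendezvous and use the
    // timeout as the default deadline for blocking waits. Assumes a
    // ProcessGroupReal class whose constructor accepts both.
    c10::intrusive_ptr<ProcessGroup> ProcessGroupReal::createProcessGroupReal(
            const c10::intrusive_ptr<::c10d::Store>& store,
            int rank,
            int size,
            const std::chrono::duration<float>& timeout) {
        return c10::make_intrusive<ProcessGroupReal>(store, rank, size, timeout);
    }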
|
Step 3: Build The Custom Extension
----------------------------------

Now the extension source code files are ready. We can then use
`cpp extensions <https://pytorch.org/docs/stable/cpp_extension.html>`__
to build it. Create a ``setup.py`` file that prepares the paths and commands,
then call ``python setup.py install`` to install the extension.

If the extension depends on third-party libraries, you can also specify
``library_dirs`` and ``libraries`` in the cpp extension APIs, as sketched
after the ``setup.py`` example below. See the
`torch ucc <https://github.com/openucx/torch-ucc>`__
project as a real-world example.

.. code-block:: python

    # file name: setup.py
    import os
    import torch
    from setuptools import setup
    from torch.utils import cpp_extension

    sources = ["src/dummy.cpp"]
    include_dirs = [f"{os.path.dirname(os.path.abspath(__file__))}/include/"]

    # build a CUDA-aware extension when CUDA is available, a CPU-only one otherwise
    if torch.cuda.is_available():
        module = cpp_extension.CUDAExtension(
            name="dummy_collectives",
            sources=sources,
            include_dirs=include_dirs,
        )
    else:
        module = cpp_extension.CppExtension(
            name="dummy_collectives",
            sources=sources,
            include_dirs=include_dirs,
        )

    setup(
        name="Dummy-Collectives",
        version="0.0.1",
        ext_modules=[module],
        cmdclass={"build_ext": cpp_extension.BuildExtension},
    )
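
For example, a hypothetical ``setup.py`` fragment linking against a third-party
communication library could pass the extra arguments like this (all paths and
the ``mycomm`` library name are placeholders):

.. code-block:: python

    # Hypothetical fragment: link the extension against a third-party library.
    module = cpp_extension.CppExtension(
        name="dummy_collectives",
        sources=["src/dummy.cpp"],
        include_dirs=["include/", "/path/to/mycomm/include/"],
        library_dirs=["/path/to/mycomm/lib/"],  # where libmycomm.so lives
        libraries=["mycomm"],                   # adds -lmycomm at link time
    )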
|
Step 4: Use The Extension in Application
----------------------------------------

After installation, you can conveniently use the ``dummy`` backend when calling
`init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`_.

.. code-block:: python

    import os

    import torch
    # importing dummy_collectives makes the "dummy" backend available
    import dummy_collectives

    import torch.distributed as dist

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'

    dist.init_process_group("dummy", rank=0, world_size=1)

    x = torch.ones(6)
    dist.all_reduce(x)

    # moving the tensor to GPU requires a CUDA-enabled build of the extension
    y = x.cuda()
    dist.all_reduce(y)

    print(f"cpu allreduce: {x}")
    print(f"cuda allreduce: {y}")

    # collectives without a custom implementation raise a RuntimeError
    try:
        dist.broadcast(x, 0)
    except RuntimeError:
        print("got RuntimeError as broadcast is not supported")
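
The script above runs as a single process. To exercise the backend with
multiple processes, one option is an elastic launcher such as ``torchrun``
(or ``python -m torch.distributed.run``), which sets ``RANK``, ``WORLD_SIZE``,
``MASTER_ADDR``, and ``MASTER_PORT`` in the environment. A minimal sketch,
assuming the script is saved as ``example.py``:

.. code-block:: python

    # file name: example.py (hypothetical multi-process variant)
    # launch with, e.g.: torchrun --nproc_per_node=2 example.py
    import torch
    import dummy_collectives  # noqa: F401, registers the "dummy" backend

    import torch.distributed as dist

    # rank and world size are read from the launcher-provided environment
    dist.init_process_group("dummy")

    x = torch.ones(6)
    dist.all_reduce(x)
    print(f"rank {dist.get_rank()}: {x}")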