
Conversation

mrshenli
Contributor

No description provided.

Customize Process Group Backends Using Cpp Extensions
=====================================================

**Author**: `Feng Tian <https://github.com/ftian1>`__, `Shen Li <https://mrshenli.github.io/>`__
Contributor Author


Hey @ftian1, do you mind if I add you as the first author of this tutorial? BTW, thanks a lot for contributing this feature!

Contributor


No problem :) It's my pleasure to be here.

@netlify

netlify bot commented Jan 20, 2022

✔️ Deploy Preview for pytorch-tutorials-preview ready!

🔨 Explore the source changes: be4f14f

🔍 Inspect the deploy log: https://app.netlify.com/sites/pytorch-tutorials-preview/deploys/61f97d8ed1647e00088a89a1

😎 Browse the preview: https://deploy-preview-1798--pytorch-tutorials-preview.netlify.app

@mrshenli mrshenli force-pushed the c10d_extension branch 4 times, most recently from f5ae99d to 689a59e on January 23, 2022 02:39
@mrshenli mrshenli changed the title [WIP] Add tutorial for ProcessGroup extensions Add tutorial for ProcessGroup extensions Jan 23, 2022
@mrshenli mrshenli changed the title Add tutorial for ProcessGroup extensions Add a tutorial for ProcessGroup extensions Jan 23, 2022
communication algorithms (e.g.,
`Herring <https://www.amazon.science/publications/herring-rethinking-the-parameter-server-at-scale-for-the-cloud>`__,
`Reduction Server <https://cloud.google.com/blog/topics/developers-practitioners/optimize-training-performance-reduction-server-vertex-ai>`__).
Therefore, the distributed package exposed extension APIs to allow customizing
Contributor


exposed -> exposes?

future_(std::move(future)) {}
bool isCompleted() override;
bool isSuccess() const override;
bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
Contributor


Is it fine for tutorial purposes for these to not be implemented anywhere in the tutorial?

Contributor Author


Good point, let me remove those and add a comment to mention that the full implementation is in the repo.
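For readers following along, here is a rough Python-side sketch of how these Work overrides surface to users once the backend is registered; it assumes a process group has already been initialized with the "dummy" backend, and the tensor shape is arbitrary:

import torch
import torch.distributed as dist

# Asynchronous collectives return a Work handle that is backed by the
# C++ overrides quoted above.
tensor = torch.zeros(4)
work = dist.all_reduce(tensor, async_op=True)

work.wait()                 # dispatches to the C++ wait() override
print(work.is_completed())  # dispatches to isCompleted()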

**Author**: `Feng Tian <https://github.com/ftian1>`__, `Shen Li <https://mrshenli.github.io/>`__


Prerequisites:
Contributor


Should we also add "cpp extensions" as a prerequisite?

Contributor Author


Ah, yes, let me add that.

import os

import torch
import dummy_collectives
Contributor


Add a comment specifying that this is what imports the "dummy collectives" cpp extension and makes the "dummy" backend available? I missed it at first and was wondering how the "dummy" name gets recognized, but it seems to be through this.
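For illustration, a minimal sketch of how the import wires things up end to end; the rank, world size, and port values below are placeholders, and the tutorial's actual test script may differ:

import os

import torch
# Importing the extension module runs its registration code; this is what
# makes the "dummy" backend name visible to torch.distributed.
import dummy_collectives

import torch.distributed as dist

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

dist.init_process_group("dummy", rank=0, world_size=1)

x = torch.ones(6)
dist.all_reduce(x)
print(f"allreduce result: {x}")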

Contributor

@rohan-varma rohan-varma left a comment


LGTM, thanks for putting this together! A couple minor suggestions/comments

py::object module = py::module::import("torch.distributed");
py::object register_backend =
    module.attr("Backend").attr("register_backend");
register_backend("dummy", py::cpp_function(createProcessGroupDummy));
Contributor


Mention that this calls torch.distributed.register_backend which is how torch.distributed will recognize it as a valid backend?
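For reference, the C++ snippet above is roughly equivalent to making the following call from Python; create_dummy_process_group is a hypothetical stand-in for the extension's createProcessGroupDummy factory:

import torch.distributed as dist

def create_dummy_process_group(store, rank, size, timeout):
    # Hypothetical factory: would construct and return the custom
    # ProcessGroup subclass for the given store, rank, and world size.
    ...

# Registering the factory under the name "dummy" is what lets
# dist.init_process_group(backend="dummy", ...) resolve the backend.
dist.Backend.register_backend("dummy", create_dummy_process_group)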

(e.g., `TPU <https://cloud.google.com/tpu>`__,
`Trainium <https://aws.amazon.com/machine-learning/trainium/>`__), and emerging
communication algorithms (e.g.,
`Herring <https://www.amazon.science/publications/herring-rethinking-the-parameter-server-at-scale-for-the-cloud>`__,
Contributor


For Herring and Reduction Server, is the best way to implement them in PyTorch through a custom cpp extension, or do we want to encourage users to build on top of the existing torch.distributed collectives, which should be able to enable these algorithms on top of NCCL or Gloo?

Contributor Author


The most efficient way is to do it directly in the communication layer, i.e., through a c10d extension. This is also how Fairring (not sure about Herring, as it doesn't seem to be open source) and Reduction Server are implemented today (Fairring uses a c10d extension, and Reduction Server is an NCCL plugin). Since the goal of those algorithms is to improve communication efficiency, I would assume future users would follow similar paths, unless an algorithm is powerful enough to shine even with an inefficient implementation.

training features, including
`DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`__,
`ZeroRedundancyOptimizer <https://pytorch.org/docs/stable/distributed.optim.html#torch.distributed.optim.ZeroRedundancyOptimizer>`__,
`FullyShardedDataParallel <https://github.com/pytorch/pytorch/blob/master/torch/distributed/_fsdp/fully_sharded_data_parallel.py>`__,.
Contributor


the ",." at the end of the link -> "."

@mrshenli
Contributor Author

Hey @brianjo, the content for this tutorial is ready to be merged. The failure on "pytorch_tutorial_pr_build_manager" looks irrelevant? If so, shall we merge? Thanks!

@brianjo
Contributor

brianjo commented Jan 24, 2022

It needs to pass tests or it will break the build. I'll take a look today. Thanks!

@holly1238 holly1238 merged commit 1db1c1c into pytorch:master Feb 1, 2022