Add pickling support for WorkerInfo (pytorch#73371)
Summary:
Pull Request resolved: pytorch#73371

This PR allows the pybind11-bound class `WorkerInfo` to be pickled. The class is pickled into a tuple of worker name and rank, in the format `(NAME, ID)`. This allows a `WorkerInfo` to be passed as an argument in RPC calls.
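
For illustration, a minimal sketch of what this enables (not part of this diff; the `echo` helper, the "worker1" name, and the surrounding `init_rpc` setup are assumptions):

    # Assumes rpc.init_rpc(...) has already been run on each participating worker;
    # "worker1" and echo() are illustrative only.
    import torch.distributed.rpc as rpc

    def echo(x):
        return x

    info = rpc.get_worker_info()                       # WorkerInfo for the current worker
    ret = rpc.rpc_sync("worker1", echo, args=(info,))  # WorkerInfo now pickles as (NAME, ID)
    assert ret == info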

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D34458153

Pulled By: H-Huang

fbshipit-source-id: 7b8f99960bdc0e24021e252d8c8138bcb53f698c
(cherry picked from commit 8fb119b)
H-Huang authored and pytorchmergebot committed Feb 28, 2022
1 parent f437ca6 commit 6c8e516
Showing 2 changed files with 33 additions and 8 deletions.
25 changes: 20 additions & 5 deletions torch/csrc/distributed/rpc/init.cpp
@@ -110,11 +110,26 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
          // c10::hash, so we need to use the qualified name
          // py::detail::hash, which unfortunately is in a detail namespace.
          .def(py::detail::hash(py::self)) // NOLINT
-          .def("__repr__", [](const WorkerInfo& workerInfo) {
-            std::ostringstream os;
-            os << workerInfo;
-            return os.str();
-          });
+          .def(
+              "__repr__",
+              [](const WorkerInfo& workerInfo) {
+                std::ostringstream os;
+                os << workerInfo;
+                return os.str();
+              })
+          .def(py::pickle(
+              /* __getstate__ */
+              [](const WorkerInfo& workerInfo) {
+                return py::make_tuple(workerInfo.name_, workerInfo.id_);
+              },
+              /* __setstate__ */
+              [](py::tuple t) {
+                TORCH_CHECK(t.size() == 2, "Invalid WorkerInfo state.");
+
+                WorkerInfo info(
+                    t[0].cast<std::string>(), t[1].cast<worker_id_t>());
+                return info;
+              }));

  auto rpcAgent =
      shared_ptr_class_<RpcAgent>(module, "RpcAgent")
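The `py::pickle` pair above saves a `WorkerInfo` as its `(name, id)` tuple and reconstructs it on load, which is what lets the RPC pickler ship it between workers. A minimal sketch of the resulting round trip (assuming RPC has been initialized so a `WorkerInfo` instance exists):

    import pickle
    import torch.distributed.rpc as rpc

    info = rpc.get_worker_info()                 # current worker's WorkerInfo
    restored = pickle.loads(pickle.dumps(info))  # state travels as the (name, id) tuple
    assert restored.name == info.name and restored.id == info.id
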
16 changes: 13 additions & 3 deletions torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -17,7 +17,7 @@
import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd
-from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info
+from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info, WorkerInfo
from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs, _use_rpc_pickler, _thread_local_var, _wait_all
from torch.distributed.rpc.internal import (
    PythonUDF,
@@ -107,7 +107,7 @@ def __init__(self, world_size):

    def get_worker_infos(self):
        return {
-            rpc.WorkerInfo(name=worker_name(rank), id=rank)
+            WorkerInfo(name=worker_name(rank), id=rank)
            for rank in range(self.world_size)
        }

@@ -277,6 +277,9 @@ def delayed_add(a, b, seconds=0.05):
    return a + b


+def identity(a):
+    return a
+
def no_result():
    print("do nothing")

@@ -1377,7 +1380,6 @@ def test_world_size_one(self):

    @dist_init(setup_rpc=False)
    def test_invalid_names(self):
-        from torch.distributed.rpc import WorkerInfo

        worker_id = 0
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
@@ -1394,6 +1396,14 @@ def test_invalid_names(self):
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            info = WorkerInfo("".join(["a" for i in range(500)]), worker_id)

+    # Test that WorkerInfo can be pickled and sent in RPC call
+    @dist_init
+    def test_worker_info_pickle(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        worker_info = rpc.api.get_worker_info()
+        ret = rpc.rpc_sync(worker_name(dst_rank), identity, args=(worker_info,))
+        self.assertEqual(ret, worker_info)
+
    @dist_init
    def test_add(self):
        n = self.rank + 1
