diff --git a/test/run_test.sh b/test/run_test.sh
index e8cbdd0634e55..fd778412e1e33 100755
--- a/test/run_test.sh
+++ b/test/run_test.sh
@@ -75,6 +75,11 @@ if [[ "$TEST_DISTRIBUTED" -eq 1 ]]; then
     BACKEND=tcp WORLD_SIZE=3 $PYCMD ./test_distributed.py
     distributed_tear_down
 
+    echo "Running distributed tests for the Gloo backend"
+    distributed_set_up
+    BACKEND=gloo WORLD_SIZE=3 $PYCMD ./test_distributed.py
+    distributed_tear_down
+
     echo "Running distributed tests for the MPI backend"
     distributed_set_up
     BACKEND=mpi mpiexec -n 3 $PYCMD ./test_distributed.py
diff --git a/test/test_distributed.py b/test/test_distributed.py
index bed0b8dcf9b53..dcefaff7f0006 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -136,6 +136,8 @@ def test_send_recv(self):
         self._barrier()
 
     # SEND RECV ANY SOURCE
+    @unittest.skipIf(BACKEND == 'gloo',
+                     "Gloo does not support send/recv from any source")
     def test_send_recv_any_source(self):
         rank = dist.get_rank()
         tensor = _build_tensor(10, rank)
@@ -229,12 +231,14 @@ def _test_reduce_helper(self, group, group_id, rank, op, master_value, worker_va
         self._barrier()
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
     def test_reduce_sum(self):
         group, group_id, rank = self._init_global_test()
         self._test_reduce_helper(
             group, group_id, rank, dist.reduce_op.SUM,
             2, 10, 2 + (10 * (len(group) - 1))
         )
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
     def test_reduce_product(self):
         group, group_id, rank = self._init_global_test()
         self._test_reduce_helper(
@@ -242,24 +246,28 @@ def test_reduce_product(self):
             2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
         )
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
     def test_reduce_min(self):
         group, group_id, rank = self._init_global_test()
         self._test_reduce_helper(
             group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
         )
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
     def test_reduce_max(self):
         group, group_id, rank = self._init_global_test()
         self._test_reduce_helper(
             group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10
         )
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
     def test_reduce_group_sum(self):
         group, group_id, rank = self._init_group_test()
         self._test_reduce_helper(
             group, group_id, rank, dist.reduce_op.SUM,
             2, 10, 2 + (10 * (len(group) - 1))
         )
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
     def test_reduce_group_product(self):
         group, group_id, rank = self._init_group_test()
         self._test_reduce_helper(
@@ -267,12 +275,14 @@ def test_reduce_group_product(self):
             2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
         )
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
     def test_reduce_group_min(self):
         group, group_id, rank = self._init_group_test()
         self._test_reduce_helper(
             group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
         )
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
     def test_reduce_group_max(self):
         group, group_id, rank = self._init_group_test()
         self._test_reduce_helper(
@@ -358,10 +368,12 @@ def _test_scatter_helper(self, group, group_id, rank):
 
         self._barrier()
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support scatter")
     def test_scatter(self):
         group, group_id, rank = self._init_global_test()
         self._test_scatter_helper(group, group_id, rank)
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support scatter")
     def test_scatter_group(self):
         group, group_id, rank = self._init_group_test()
         self._test_scatter_helper(group, group_id, rank)
@@ -382,10 +394,12 @@ def _test_gather_helper(self, group, group_id, rank):
 
         self._barrier()
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support gather")
     def test_gather(self):
         group, group_id, rank = self._init_global_test()
         self._test_gather_helper(group, group_id, rank)
 
+    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support gather")
     def test_gather_group(self):
         group, group_id, rank = self._init_group_test()
         self._test_gather_helper(group, group_id, rank)
@@ -437,10 +451,10 @@ def test_barrier_group(self):
         group, group_id, rank = self._init_group_test()
         self._test_barrier_helper(group, group_id, rank)
 
-if BACKEND == 'tcp':
+if BACKEND == 'tcp' or BACKEND == 'gloo':
     WORLD_SIZE = os.environ['WORLD_SIZE']
 
-    class TestTCP(TestCase, _DistTestBase):
+    class TestTCPOrGloo(TestCase, _DistTestBase):
         MANAGER_PROCESS_RANK = -1
         JOIN_TIMEOUT = 5
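The gating above is resolved at import time: `BACKEND` is read from the environment once, so `run_test.sh` can drive the same test module once per backend, and collectives a backend does not implement are reported as skips rather than failures. A minimal, runnable sketch of the pattern (the class and test names are hypothetical; only the decorator usage mirrors the diff):

```python
import os
import unittest

# Read once at import time, as test_distributed.py does.
BACKEND = os.environ.get('BACKEND', 'tcp')

class BackendGatedTest(unittest.TestCase):
    # The skip decision is baked in when the class is created,
    # before the test runner ever sees the method.
    @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
    def test_reduce(self):
        self.assertTrue(True)  # stand-in for a real collective check

if __name__ == '__main__':
    unittest.main()
```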
diff --git a/torch/lib/THD/base/data_channels/DataChannelGloo.cpp b/torch/lib/THD/base/data_channels/DataChannelGloo.cpp
index 81bb9c706b195..20e7812105c7b 100644
--- a/torch/lib/THD/base/data_channels/DataChannelGloo.cpp
+++ b/torch/lib/THD/base/data_channels/DataChannelGloo.cpp
@@ -105,11 +105,16 @@ bool DataChannelGloo::init() {
   return true;
 }
 
-DataChannelGloo::store_type DataChannelGloo::getStore() {
-  // This `id` \/ has to be consistent on all instances but also
-  // has to be different for each operation.
-  static std::uint64_t id = 0; // TODO: that's not a good solution
-  return ::gloo::rendezvous::PrefixStore(std::to_string(id++), *_store);
+
+template<DataOperation op, typename... Args>
+DataChannelGloo::store_type DataChannelGloo::getStore(THDGroup group_id, Args... args) {
+  std::string unique_prefix = std::to_string(op) + "-" + std::to_string(group_id);
+  std::vector<std::string> v = {std::to_string(args)...};
+  for (auto it = v.begin(); it != v.end(); ++it) {
+    unique_prefix += "-" + *it;
+  }
+
+  return ::gloo::rendezvous::PrefixStore(unique_prefix, *_store);
 }
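The removed implementation handed out prefixes from a static counter, which only stays consistent if every process performs exactly the same sequence of operations; any divergence makes two processes rendezvous under different prefixes. The replacement derives the prefix deterministically from the operation, the group id, and any extra arguments. A small Python model of the keying scheme (the operation names and prefix layout are illustrative assumptions, not a THD API):

```python
def store_prefix(op, group_id, *args):
    """Deterministic rendezvous prefix: every process computes the same
    string for the same (op, group, args), regardless of what other
    operations it has executed before."""
    return "-".join(str(part) for part in (op, group_id) + args)

# Two processes all-reducing on group 0 agree on the key:
assert store_prefix("ALL_REDUCE", 0) == "ALL_REDUCE-0"
# Extra arguments (e.g. a source rank) extend the prefix:
assert store_prefix("BROADCAST", 1, 2) == "BROADCAST-1-2"
```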
@@ -126,13 +131,10 @@ rank_type DataChannelGloo::getNumProcesses() {
 
 template<typename T>
 void DataChannelGloo::allGatherT(std::vector<thpp::Tensor*>& output, thpp::Tensor& input, THDGroup group_id) {
-  auto store = getStore();
-  RETURN_IF_NOT_IN_GROUP(_)
-
   std::uint64_t tensor_bytes = input.elementSize() * input.numel();
   std::uint64_t all_tensor_bytes = tensor_bytes * output.size();
   auto ret = GlooCache::get().getAlgorithm(
-    group_id, _groups.at(group_id), store,
+    group_id, _groups.at(group_id), getStore(group_id),
     tensor_bytes, all_tensor_bytes, input.numel());
 
   std::memcpy(std::get<1>(ret).get(), input.data(), tensor_bytes);
@@ -147,6 +149,7 @@ void DataChannelGloo::allGatherT(std::vector<thpp::Tensor*>& output,
 
 void DataChannelGloo::allGather(std::vector<thpp::Tensor*>& output,
                                 thpp::Tensor& input, THDGroup group_id) {
+  RETURN_IF_NOT_IN_GROUP(_)
   GENERATE_ALL_TYPES(input.type(), allGatherT, output, input, group_id)
 }
@@ -170,12 +173,9 @@ void DataChannelGloo::scatter(std::vector<thpp::Tensor*>& input,
 
 template<typename T>
 void DataChannelGloo::allReduceT(thpp::Tensor& t, THDReduceOp operation, THDGroup group_id) {
-  auto store = getStore();
-  RETURN_IF_NOT_IN_GROUP(_)
-
   std::uint64_t tensor_bytes = t.elementSize() * t.numel();
   auto ret = GlooCache::get().getAlgorithm(
-    group_id, _groups.at(group_id), store,
+    group_id, _groups.at(group_id), getStore(group_id),
     tensor_bytes, t.numel(), operation);
 
   std::memcpy(std::get<1>(ret).get(), t.data(), tensor_bytes);
@@ -185,6 +185,7 @@ void DataChannelGloo::allReduceT(thpp::Tensor& t, THDReduceOp operation,
 
 void DataChannelGloo::allReduce(thpp::Tensor& data, THDReduceOp operation,
                                 THDGroup group_id) {
+  RETURN_IF_NOT_IN_GROUP(_)
   GENERATE_ALL_TYPES(data.type(), allReduceT, data, operation, group_id)
 }
@@ -199,12 +200,11 @@ void DataChannelGloo::reduce(thpp::Tensor& data, THDReduceOp operation,
 
 template<typename T>
 void DataChannelGloo::broadcastT(thpp::Tensor& data, rank_type src_rank, THDGroup group_id) {
-  auto store = getStore();
   RETURN_IF_NOT_IN_GROUP(group_rank)
 
   std::uint64_t tensor_bytes = data.elementSize() * data.numel();
   auto ret = GlooCache::get().getAlgorithm(
-    group_id, _groups.at(group_id), store,
+    group_id, _groups.at(group_id), getStore(group_id),
     tensor_bytes, data.numel(), src_rank);
 
   if (group_rank == src_rank)
@@ -279,11 +279,10 @@ auto DataChannelGloo::ireceive(thpp::Tensor& data, rank_type src_rank) -> Reques
 
 void DataChannelGloo::barrier(THDGroup group_id) {
-  auto store = getStore();
   RETURN_IF_NOT_IN_GROUP(_)
 
   auto ret = GlooCache::get().getAlgorithm(
-    group_id, _groups.at(group_id), store);
+    group_id, _groups.at(group_id), getStore(group_id));
   std::get<0>(ret)->run();
 }
@@ -297,11 +296,10 @@ THDGroup DataChannelGloo::newGroup(const std::vector<rank_type>& ranks) {
 
 void DataChannelGloo::_send(const Scalar& data, rank_type dst_rank) {
-  auto store = getStore();
   std::unique_ptr<Scalar> data_copy(data.clone());
   auto ctx = GlooCache::get().getSharedContext(
     _groups.at(THDGroupWORLD),
-    store
+    getStore(THDGroupWORLD)
   );
   auto& pair = ctx->getPair(dst_rank);
   pair->createSendBuffer(ctx->nextSlot(), data_copy->data(), data_copy->elementSize())->waitSend();
@@ -309,10 +307,9 @@ void DataChannelGloo::_send(const Scalar& data, rank_type dst_rank) {
 
 void DataChannelGloo::_send(thpp::Tensor& data, rank_type dst_rank) {
-  auto store = getStore();
   auto ctx = GlooCache::get().getSharedContext(
     _groups.at(THDGroupWORLD),
-    store
+    getStore(THDGroupWORLD)
   );
   auto& pair = ctx->getPair(dst_rank);
   uint64_t tensor_bytes = data.elementSize() * data.numel();
@@ -321,10 +318,9 @@ void DataChannelGloo::_send(thpp::Tensor& data, rank_type dst_rank) {
 
 void DataChannelGloo::_receive(Scalar& data, rank_type src_rank) {
-  auto store = getStore();
   auto ctx = GlooCache::get().getSharedContext(
     _groups.at(THDGroupWORLD),
-    store
+    getStore(THDGroupWORLD)
   );
   auto& pair = ctx->getPair(src_rank);
   pair->createRecvBuffer(ctx->nextSlot(), data.data(), data.elementSize())->waitRecv();
@@ -332,10 +328,9 @@ void DataChannelGloo::_receive(Scalar& data, rank_type src_rank) {
 
 void DataChannelGloo::_receive(thpp::Tensor& data, rank_type src_rank) {
-  auto store = getStore();
   auto ctx = GlooCache::get().getSharedContext(
     _groups.at(THDGroupWORLD),
-    store
+    getStore(THDGroupWORLD)
   );
   auto& pair = ctx->getPair(src_rank);
   uint64_t tensor_bytes = data.elementSize() * data.numel();
diff --git a/torch/lib/THD/base/data_channels/DataChannelGloo.hpp b/torch/lib/THD/base/data_channels/DataChannelGloo.hpp
index 4e66de7987aeb..f932c3d81bab5 100644
--- a/torch/lib/THD/base/data_channels/DataChannelGloo.hpp
+++ b/torch/lib/THD/base/data_channels/DataChannelGloo.hpp
@@ -36,7 +36,7 @@ struct hash
       hash()(std::get<0>(k)) ^
       hash()(std::get<1>(k)) ^
-      hash()(std::get<2>(k)) ^
+      hash()(std::get<2>(k)) ^
       hash()(std::get<3>(k)) ^
       hash()(std::get<4>(k))
     );
@@ -108,7 +108,8 @@ struct DataChannelGloo : DataChannel {
   void broadcastT(thpp::Tensor& data, rank_type src_rank,
                   THDGroup group_id = THDGroupWORLD);
 
-  store_type getStore();
+  template<DataOperation op, typename... Args>
+  store_type getStore(THDGroup group_id, Args... args);
 
   void _send(const Scalar& data, rank_type dst_id);
   void _send(thpp::Tensor& data, rank_type dst_id);
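With these changes the Gloo backend is driven exactly like the TCP one. A hedged end-to-end sketch against the THD-era `torch.distributed` API that `test_distributed.py` exercises (the rendezvous environment variables are assumed to be supplied by the harness, as in `run_test.sh`):

```python
import torch
import torch.distributed as dist

# Launched as in run_test.sh: BACKEND=gloo WORLD_SIZE=3, plus rendezvous
# settings (e.g. MASTER_ADDR/MASTER_PORT) provided by the environment.
dist.init_process_group(backend='gloo')

t = torch.ones(10) * (dist.get_rank() + 1)
dist.all_reduce(t, op=dist.reduce_op.SUM)  # all-reduce is supported

# dist.reduce, dist.scatter, dist.gather, and receives from any source are
# exactly the operations skipped above; the Gloo channel does not implement them.
```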