Add python tests; Remove broken prefix store creation
VirrageS authored and apaszke committed May 1, 2017
1 parent 6888c61 commit 2b340e7
Showing 4 changed files with 44 additions and 29 deletions.
5 changes: 5 additions & 0 deletions test/run_test.sh
@@ -75,6 +75,11 @@ if [[ "$TEST_DISTRIBUTED" -eq 1 ]]; then
BACKEND=tcp WORLD_SIZE=3 $PYCMD ./test_distributed.py
distributed_tear_down

echo "Running distributed tests for the Gloo backend"
distributed_set_up
BACKEND=gloo WORLD_SIZE=3 $PYCMD ./test_distributed.py
distributed_tear_down

echo "Running distributed tests for the MPI backend"
distributed_set_up
BACKEND=mpi mpiexec -n 3 $PYCMD ./test_distributed.py
18 changes: 16 additions & 2 deletions test/test_distributed.py
@@ -136,6 +136,8 @@ def test_send_recv(self):
self._barrier()

# SEND RECV ANY SOURCE
@unittest.skipIf(BACKEND == 'gloo',
"Gloo does not support send/recv from any source")
def test_send_recv_any_source(self):
rank = dist.get_rank()
tensor = _build_tensor(10, rank)
@@ -229,50 +231,58 @@ def _test_reduce_helper(self, group, group_id, rank, op, master_value, worker_va

self._barrier()

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
def test_reduce_sum(self):
group, group_id, rank = self._init_global_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.SUM, 2, 10, 2 + (10 * (len(group) - 1))
)

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
def test_reduce_product(self):
group, group_id, rank = self._init_global_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.PRODUCT,
2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
)

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
def test_reduce_min(self):
group, group_id, rank = self._init_global_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
)

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
def test_reduce_max(self):
group, group_id, rank = self._init_global_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10
)

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
def test_reduce_group_sum(self):
group, group_id, rank = self._init_group_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.SUM, 2, 10, 2 + (10 * (len(group) - 1))
)

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
def test_reduce_group_product(self):
group, group_id, rank = self._init_group_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.PRODUCT,
2, 10, reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2)
)

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
def test_reduce_group_min(self):
group, group_id, rank = self._init_group_test()
self._test_reduce_helper(
group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
)

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support reduce")
def test_reduce_group_max(self):
group, group_id, rank = self._init_group_test()
self._test_reduce_helper(
@@ -358,10 +368,12 @@ def _test_scatter_helper(self, group, group_id, rank):

self._barrier()

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support scatter")
def test_scatter(self):
group, group_id, rank = self._init_global_test()
self._test_scatter_helper(group, group_id, rank)

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support scatter")
def test_scatter_group(self):
group, group_id, rank = self._init_group_test()
self._test_scatter_helper(group, group_id, rank)
@@ -382,10 +394,12 @@ def _test_gather_helper(self, group, group_id, rank):

self._barrier()

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support gather")
def test_gather(self):
group, group_id, rank = self._init_global_test()
self._test_gather_helper(group, group_id, rank)

@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support gather")
def test_gather_group(self):
group, group_id, rank = self._init_group_test()
self._test_gather_helper(group, group_id, rank)
@@ -437,10 +451,10 @@ def test_barrier_group(self):
group, group_id, rank = self._init_group_test()
self._test_barrier_helper(group, group_id, rank)

if BACKEND == 'tcp':
if BACKEND == 'tcp' or BACKEND == 'gloo':
WORLD_SIZE = os.environ['WORLD_SIZE']

class TestTCP(TestCase, _DistTestBase):
class TestTCPOrGloo(TestCase, _DistTestBase):

MANAGER_PROCESS_RANK = -1
JOIN_TIMEOUT = 5
45 changes: 20 additions & 25 deletions torch/lib/THD/base/data_channels/DataChannelGloo.cpp
@@ -105,11 +105,16 @@ bool DataChannelGloo::init() {
return true;
}

DataChannelGloo::store_type DataChannelGloo::getStore() {
// This `id` \/ has to be consistent on all instances but also
// has to be different for each operation.
static std::uint64_t id = 0; // TODO: that's not a good solution
return ::gloo::rendezvous::PrefixStore(std::to_string(id++), *_store);

template<DataOperation op, typename... Args>
DataChannelGloo::store_type DataChannelGloo::getStore(THDGroup group_id, Args... args) {
std::string unique_prefix = std::to_string(op) + "-" + std::to_string(group_id);
std::vector<std::string> v = {std::to_string(args)...};
for (auto it = v.begin(); it != v.end(); ++it) {
unique_prefix += "-" + *it;
}

return ::gloo::rendezvous::PrefixStore(unique_prefix, *_store);
}


@@ -126,13 +131,10 @@ rank_type DataChannelGloo::getNumProcesses() {
template<typename T>
void DataChannelGloo::allGatherT(std::vector<thpp::Tensor*>& output,
thpp::Tensor& input, THDGroup group_id) {
auto store = getStore();
RETURN_IF_NOT_IN_GROUP(_)

std::uint64_t tensor_bytes = input.elementSize() * input.numel();
std::uint64_t all_tensor_bytes = tensor_bytes * output.size();
auto ret = GlooCache::get().getAlgorithm<DataOperation::ALL_GATHER, T>(
group_id, _groups.at(group_id), store,
group_id, _groups.at(group_id), getStore<DataOperation::ALL_GATHER>(group_id),
tensor_bytes, all_tensor_bytes, input.numel());

std::memcpy(std::get<1>(ret).get(), input.data(), tensor_bytes);
@@ -147,6 +149,7 @@ void DataChannelGloo::allGatherT(std::vector<thpp::Tensor*>& output,

void DataChannelGloo::allGather(std::vector<thpp::Tensor*>& output,
thpp::Tensor& input, THDGroup group_id) {
RETURN_IF_NOT_IN_GROUP(_)
GENERATE_ALL_TYPES(input.type(), allGatherT, output, input, group_id)
}

@@ -170,12 +173,9 @@ void DataChannelGloo::scatter(std::vector<thpp::Tensor*>& input,
template<typename T>
void DataChannelGloo::allReduceT(thpp::Tensor& t, THDReduceOp operation,
THDGroup group_id) {
auto store = getStore();
RETURN_IF_NOT_IN_GROUP(_)

std::uint64_t tensor_bytes = t.elementSize() * t.numel();
auto ret = GlooCache::get().getAlgorithm<DataOperation::ALL_REDUCE, T>(
group_id, _groups.at(group_id), store,
group_id, _groups.at(group_id), getStore<DataOperation::ALL_REDUCE>(group_id),
tensor_bytes, t.numel(), operation);

std::memcpy(std::get<1>(ret).get(), t.data(), tensor_bytes);
@@ -185,6 +185,7 @@ void DataChannelGloo::allReduceT(thpp::Tensor& t, THDReduceOp operation,

void DataChannelGloo::allReduce(thpp::Tensor& data, THDReduceOp operation,
THDGroup group_id) {
RETURN_IF_NOT_IN_GROUP(_)
GENERATE_ALL_TYPES(data.type(), allReduceT, data, operation, group_id)
}

@@ -199,12 +200,11 @@ void DataChannelGloo::reduce(thpp::Tensor& data, THDReduceOp operation,
template<typename T>
void DataChannelGloo::broadcastT(thpp::Tensor& data, rank_type src_rank,
THDGroup group_id) {
auto store = getStore();
RETURN_IF_NOT_IN_GROUP(group_rank)

std::uint64_t tensor_bytes = data.elementSize() * data.numel();
auto ret = GlooCache::get().getAlgorithm<DataOperation::BROADCAST, T>(
group_id, _groups.at(group_id), store,
group_id, _groups.at(group_id), getStore<DataOperation::BROADCAST>(group_id),
tensor_bytes, data.numel(), src_rank);

if (group_rank == src_rank)
@@ -279,11 +279,10 @@ auto DataChannelGloo::ireceive(thpp::Tensor& data, rank_type src_rank) -> Reques


void DataChannelGloo::barrier(THDGroup group_id) {
auto store = getStore();
RETURN_IF_NOT_IN_GROUP(_)

auto ret = GlooCache::get().getAlgorithm<DataOperation::BARRIER, void>(
group_id, _groups.at(group_id), store);
group_id, _groups.at(group_id), getStore<DataOperation::BARRIER>(group_id));
std::get<0>(ret)->run();
}

@@ -297,22 +296,20 @@ THDGroup DataChannelGloo::newGroup(const std::vector<rank_type>& ranks) {


void DataChannelGloo::_send(const Scalar& data, rank_type dst_rank) {
auto store = getStore();
std::unique_ptr<Scalar> data_copy(data.clone());
auto ctx = GlooCache::get().getSharedContext<DataOperation::SEND>(
_groups.at(THDGroupWORLD),
store
getStore<DataOperation::SEND>(THDGroupWORLD)
);
auto& pair = ctx->getPair(dst_rank);
pair->createSendBuffer(ctx->nextSlot(), data_copy->data(), data_copy->elementSize())->waitSend();
}


void DataChannelGloo::_send(thpp::Tensor& data, rank_type dst_rank) {
auto store = getStore();
auto ctx = GlooCache::get().getSharedContext<DataOperation::SEND>(
_groups.at(THDGroupWORLD),
store
getStore<DataOperation::SEND>(THDGroupWORLD)
);
auto& pair = ctx->getPair(dst_rank);
uint64_t tensor_bytes = data.elementSize() * data.numel();
@@ -321,21 +318,19 @@ void DataChannelGloo::_send(thpp::Tensor& data, rank_type dst_rank) {


void DataChannelGloo::_receive(Scalar& data, rank_type src_rank) {
auto store = getStore();
auto ctx = GlooCache::get().getSharedContext<DataOperation::SEND>(
_groups.at(THDGroupWORLD),
store
getStore<DataOperation::SEND>(THDGroupWORLD)
);
auto& pair = ctx->getPair(src_rank);
pair->createRecvBuffer(ctx->nextSlot(), data.data(), data.elementSize())->waitRecv();
}


void DataChannelGloo::_receive(thpp::Tensor& data, rank_type src_rank) {
auto store = getStore();
auto ctx = GlooCache::get().getSharedContext<DataOperation::SEND>(
_groups.at(THDGroupWORLD),
store
getStore<DataOperation::SEND>(THDGroupWORLD)
);
auto& pair = ctx->getPair(src_rank);
uint64_t tensor_bytes = data.elementSize() * data.numel();
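
The core of this commit is the getStore rewrite above: the removed version generated rendezvous prefixes from a static, ever-incrementing counter, which different processes could easily advance out of sync; the new version derives the prefix deterministically from the operation, the group id, and any extra arguments, so every rank computes the same key for the same collective. A minimal, self-contained sketch of that idea follows; the DataOperation values, the makePrefix name, and the printed keys are illustrative stand-ins, not THD's actual API.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for THD's DataOperation enum; the values are illustrative.
enum DataOperation : std::uint8_t { ALL_REDUCE = 0, BROADCAST = 1, BARRIER = 2 };

// Build a prefix that is identical on every process for the same call site:
// it depends only on the operation, the group id, and any extra arguments,
// unlike the removed `static std::uint64_t id = 0` counter, which could drift
// between processes that perform different sequences of operations.
template <DataOperation op, typename... Args>
std::string makePrefix(std::uint32_t group_id, Args... args) {
  std::string prefix = std::to_string(op) + "-" + std::to_string(group_id);
  for (const auto& part : std::vector<std::string>{std::to_string(args)...})
    prefix += "-" + part;
  return prefix;
}

int main() {
  // Every rank evaluating the same call gets the same key, so their
  // PrefixStore lookups rendezvous without any shared mutable state.
  std::cout << makePrefix<BROADCAST>(0u, 3) << "\n";  // prints "1-0-3"
  std::cout << makePrefix<ALL_REDUCE>(0u) << "\n";    // prints "0-0"
  return 0;
}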
5 changes: 3 additions & 2 deletions torch/lib/THD/base/data_channels/DataChannelGloo.hpp
@@ -36,7 +36,7 @@ struct hash<std::tuple<::thd::DataOperation, THDGroup, std::size_t, std::size_t,
return (
hash<::thd::DataOperation>()(std::get<0>(k)) ^
hash<THDGroup>()(std::get<1>(k)) ^
hash<std::size_t>()(std::get<2>(k)) ^
hash<std::size_t>()(std::get<2>(k)) ^
hash<std::size_t>()(std::get<3>(k)) ^
hash<THDReduceOp>()(std::get<4>(k))
);
@@ -108,7 +108,8 @@ struct DataChannelGloo : DataChannel {
void broadcastT(thpp::Tensor& data, rank_type src_rank,
THDGroup group_id = THDGroupWORLD);

store_type getStore();
template<DataOperation op, typename... Args>
store_type getStore(THDGroup group_id, Args... args);

void _send(const Scalar& data, rank_type dst_id);
void _send(thpp::Tensor& data, rank_type dst_id);

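The first hunk in this header appears to be whitespace-only, but it shows how GlooCache keys are hashed: the element hashes of the (operation, group, size, size, reduce-op) tuple are XOR-combined. Below is a small self-contained sketch of the same pattern, with placeholder enums standing in for DataOperation, THDGroup, and THDReduceOp; names and values here are illustrative, not THD's.

#include <cstddef>
#include <functional>
#include <iostream>
#include <tuple>

// Placeholder types standing in for DataOperation, THDGroup and THDReduceOp.
enum class Op : int { ALL_REDUCE = 0, BROADCAST = 1 };
enum class ReduceOp : int { SUM = 0, MAX = 1 };
using CacheKey = std::tuple<Op, int, std::size_t, std::size_t, ReduceOp>;

// XOR-combine the element hashes, mirroring the std::hash specialization in
// DataChannelGloo.hpp (simplified: the enums are hashed via their int value).
struct CacheKeyHash {
  std::size_t operator()(const CacheKey& k) const {
    return std::hash<int>()(static_cast<int>(std::get<0>(k))) ^
           std::hash<int>()(std::get<1>(k)) ^
           std::hash<std::size_t>()(std::get<2>(k)) ^
           std::hash<std::size_t>()(std::get<3>(k)) ^
           std::hash<int>()(static_cast<int>(std::get<4>(k)));
  }
};

int main() {
  CacheKey key{Op::ALL_REDUCE, 0, 1024, 256, ReduceOp::SUM};
  std::cout << CacheKeyHash{}(key) << "\n";  // one combined hash per cached algorithm
  return 0;
}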