Update comments; Add inline accessors for value_type tuple in GlooCache
VirrageS authored and apaszke committed May 1, 2017
1 parent a17d96d commit c19fbd3
Showing 5 changed files with 47 additions and 21 deletions.
28 changes: 14 additions & 14 deletions torch/lib/THD/base/data_channels/DataChannelGloo.cpp
@@ -137,16 +137,16 @@ void DataChannelGloo::allGatherT(std::vector<thpp::Tensor*>& output,
auto ret = _cache->getAlgorithm<CollectiveType::ALL_GATHER, T>(
group_id, _groups.at(group_id), tensor_bytes, all_tensor_bytes, input.numel());

- std::memcpy(std::get<1>(ret).get(), input.data(), tensor_bytes);
+ std::memcpy(GlooCache::input_buffer(ret).get(), input.data(), tensor_bytes);

{
- std::lock_guard<std::mutex> lock(*std::get<3>(ret));
- std::get<0>(ret)->run();
+ std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
+ GlooCache::algorithm(ret)->run();
}

for (std::size_t i = 0; i < output.size(); i++) {
std::memcpy(output.at(i)->data(),
- std::get<2>(ret).get() + (i * tensor_bytes),
+ GlooCache::output_buffer(ret).get() + (i * tensor_bytes),
tensor_bytes);
}
}
@@ -188,12 +188,12 @@ void DataChannelGloo::allReduceT(thpp::Tensor& t, THDReduceOp operation,
auto ret = _cache->getAlgorithm<CollectiveType::ALL_REDUCE, T>(
group_id, _groups.at(group_id), tensor_bytes, t.numel(), operation);

- std::memcpy(std::get<1>(ret).get(), t.data(), tensor_bytes);
+ std::memcpy(GlooCache::input_buffer(ret).get(), t.data(), tensor_bytes);
{
- std::lock_guard<std::mutex> lock(*std::get<3>(ret));
- std::get<0>(ret)->run();
+ std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
+ GlooCache::algorithm(ret)->run();
}
- std::memcpy(t.data(), std::get<2>(ret).get(), tensor_bytes);
+ std::memcpy(t.data(), GlooCache::output_buffer(ret).get(), tensor_bytes);
}

void DataChannelGloo::allReduce(thpp::Tensor& data, THDReduceOp operation,
@@ -219,15 +219,15 @@ void DataChannelGloo::broadcastT(thpp::Tensor& data, rank_type src_rank,
_groups.at(group_id).mustGetGroupRank(src_rank));

if (_rank == src_rank)
- std::memcpy(std::get<1>(ret).get(), data.data(), tensor_bytes);
+ std::memcpy(GlooCache::input_buffer(ret).get(), data.data(), tensor_bytes);

{
- std::lock_guard<std::mutex> lock(*std::get<3>(ret));
- std::get<0>(ret)->run();
+ std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
+ GlooCache::algorithm(ret)->run();
}

if (_rank != src_rank)
- std::memcpy(data.data(), std::get<2>(ret).get(), tensor_bytes);
+ std::memcpy(data.data(), GlooCache::output_buffer(ret).get(), tensor_bytes);
}


@@ -278,8 +278,8 @@ void DataChannelGloo::barrier(THDGroup group_id) {
auto ret = _cache->getAlgorithm<CollectiveType::BARRIER, void>(
group_id, _groups.at(group_id));
{
- std::lock_guard<std::mutex> lock(*std::get<3>(ret));
- std::get<0>(ret)->run();
+ std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
+ GlooCache::algorithm(ret)->run();
}
}

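Every collective above follows the same cached-algorithm flow: stage the tensor in the algorithm's input buffer, run the Gloo algorithm while holding its mutex, then copy the result out of the output buffer. Below is a minimal sketch of that flow written against the tuple layout and the new accessors from GlooCache.hpp further down in this diff; the runCollective helper and its parameters are illustrative assumptions, not part of this commit.

#include <cstddef>  // std::size_t
#include <cstring>  // std::memcpy
#include <mutex>    // std::lock_guard

// Sketch only: `ret` is the GlooCache::value_type tuple returned by
// getAlgorithm(...) — (algorithm, input buffer, output buffer, mutex).
template <typename ValueType>
void runCollective(const ValueType& ret,
                   const void* src, void* dst, std::size_t tensor_bytes) {
  // Stage the caller's data in the cached input buffer.
  std::memcpy(GlooCache::input_buffer(ret).get(), src, tensor_bytes);

  {
    // Serialize runs of the same cached algorithm across threads.
    std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
    GlooCache::algorithm(ret)->run();
  }

  // Copy the gathered/reduced result back to the caller.
  std::memcpy(dst, GlooCache::output_buffer(ret).get(), tensor_bytes);
}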
10 changes: 5 additions & 5 deletions torch/lib/THD/base/data_channels/DataChannelMPI.cpp
@@ -92,11 +92,11 @@ DataChannelMPI::~DataChannelMPI() {


bool DataChannelMPI::init() {
- int* provided = NULL;
- MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, provided);
- if (*provided != MPI_THREAD_MULTIPLE) {
- std::cerr << "MPI implementation does not support multiple threads."
- << "Using same data channel in multiple thread can result in"
+ int provided;
+ MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &provided);
+ if (provided != MPI_THREAD_MULTIPLE) {
+ std::cerr << "WARNING: MPI implementation does not support multiple threads. "
+ << "Using same data channel in multiple thread can result in "
<< "wrong results or errors." << std::endl;
}

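The old code passed a null `provided` pointer to MPI_Init_thread and then dereferenced it; the fix queries the granted thread level through a local int. A standalone sketch of the corrected pattern, independent of THD:

#include <iostream>
#include <mpi.h>

int main(int argc, char* argv[]) {
  int provided = MPI_THREAD_SINGLE;
  // MPI reports the thread level it actually granted via `provided`.
  MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);

  if (provided != MPI_THREAD_MULTIPLE) {
    std::cerr << "WARNING: MPI implementation does not support multiple threads."
              << std::endl;
  }

  MPI_Finalize();
  return 0;
}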
3 changes: 2 additions & 1 deletion torch/lib/THD/base/data_channels/DataChannelTCP.hpp
@@ -87,7 +87,8 @@ struct DataChannelTCP : DataChannel {
std::vector<Process> _processes; // Other processes in network
std::unique_ptr<struct pollfd[]> _poll_events; // Events array for `poll`

- std::mutex _mutex; // General mutex for methods - to make methods run atomically.
+ // General mutex for methods - to protect access to the TCP data channel.
+ std::mutex _mutex;

// Existing groups of processes and corresponding group ids
std::unordered_map<THDGroup, DataChannel::Group> _groups;
21 changes: 20 additions & 1 deletion torch/lib/THD/base/data_channels/GlooCache.hpp
@@ -59,7 +59,7 @@ struct GlooCache {
std::shared_ptr<algorithm_type>, // algorithm
std::shared_ptr<buffer_type>, // input buffer (nullptr if not used)
std::shared_ptr<buffer_type>, // output buffer (nullptr if not used)
- std::shared_ptr<std::mutex> // mutex to make algorithms run atomically
+ std::shared_ptr<std::mutex> // mutex to protect same algorithm from running concurrently
>;

GlooCache(rank_type rank, std::shared_ptr<::gloo::transport::Device> device,
@@ -72,6 +72,25 @@
GlooCache(GlooCache const&) = delete;
void operator=(GlooCache const&) = delete;


+ // Accessors for value_type tuple
+ static inline std::shared_ptr<algorithm_type> algorithm(const value_type& t) {
+ return std::get<0>(t);
+ }
+
+ static inline std::shared_ptr<buffer_type> input_buffer(const value_type& t) {
+ return std::get<1>(t);
+ }
+
+ static inline std::shared_ptr<buffer_type> output_buffer(const value_type& t) {
+ return std::get<2>(t);
+ }
+
+ static inline std::shared_ptr<std::mutex> mutex(const value_type& t) {
+ return std::get<3>(t);
+ }


std::shared_ptr<context_type> createContext(
const DataChannel::Group& group,
prefix_store_type& store
6 changes: 6 additions & 0 deletions torch/lib/THD/test/data_channel_collectives.cpp
@@ -28,6 +28,7 @@ constexpr int BARRIER_WAIT_TIME = 200; // milliseconds
std::vector<std::thread> g_all_workers;
std::mutex g_mutex;
std::string g_data_channel_type;
+ std::unique_ptr<Barrier> g_barrier;


void test_send_recv_tensor(std::shared_ptr<thd::DataChannel> data_channel) {
@@ -684,6 +685,8 @@ void init_gloo_master(int workers) {

assert(masterChannel->init());
run_all_tests(masterChannel, workers);

+ g_barrier->wait();
}

void init_gloo_worker(unsigned int id, int workers) {
@@ -695,6 +698,8 @@ void init_gloo_worker(unsigned int id, int workers) {

assert(worker_channel->init());
run_all_tests(worker_channel, workers);

+ g_barrier->wait();
}
#endif // WITH_GLOO

@@ -733,6 +738,7 @@ int main(int argc, char const *argv[]) {
#ifdef WITH_GLOO
g_data_channel_type = "gloo";
for (auto workers : WORKERS_NUM) {
+ g_barrier.reset(new Barrier(workers + 1));
std::cout << "Gloo (workers: " << workers << "):" << std::endl;
// start gloo master
std::thread gloo_master_thread(init_gloo_master, workers);
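The test now makes the Gloo master and worker threads rendezvous on g_barrier (sized to workers + 1 participants) after run_all_tests, so all threads finish their collectives before any of them exits. The Barrier type itself is defined elsewhere in the test helpers and is not shown in this diff; the sketch below is only an assumed implementation matching the interface used here (a constructor taking the participant count and a wait() method).

#include <condition_variable>
#include <cstddef>
#include <mutex>

// Assumed counting barrier: Barrier(n) + wait(), reusable across rounds.
class Barrier {
 public:
  explicit Barrier(std::size_t count) : _threshold(count), _remaining(count) {}

  void wait() {
    std::unique_lock<std::mutex> lock(_mutex);
    std::size_t arrival_generation = _generation;
    if (--_remaining == 0) {
      // Last participant: reset for the next round and release everyone.
      ++_generation;
      _remaining = _threshold;
      _cv.notify_all();
    } else {
      _cv.wait(lock, [&] { return arrival_generation != _generation; });
    }
  }

 private:
  std::mutex _mutex;
  std::condition_variable _cv;
  std::size_t _threshold;
  std::size_t _remaining;
  std::size_t _generation = 0;
};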
