Merge pull request #14290 from Mellanox/master
Adding connectivity check, compilation fix and some code refactoring to verbs
jhseu authored Nov 29, 2017
2 parents 4b7d79b + d43d00b commit 4905969
Showing 7 changed files with 140 additions and 40 deletions.
6 changes: 4 additions & 2 deletions tensorflow/contrib/verbs/BUILD
@@ -7,6 +7,8 @@ package(default_visibility = [
 
 licenses(["notice"])  # Apache 2.0
 
+load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
+
 exports_files(["LICENSE"])
 
 filegroup(
@@ -97,7 +99,7 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
+tf_cuda_library(
     name = "rdma_rendezvous_mgr",
     srcs = ["rdma_rendezvous_mgr.cc"],
     hdrs = ["rdma_rendezvous_mgr.h"],
@@ -130,7 +132,7 @@
     ],
 )
 
-cc_library(
+tf_cuda_library(
     name = "rdma",
     srcs = ["rdma.cc"],
     hdrs = ["rdma.h"],
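Note on the BUILD change: tf_cuda_library is TensorFlow's Bazel macro for targets that must compile both with and without CUDA. When building with --config=cuda it adds the GPU dependencies and defines GOOGLE_CUDA; otherwise it behaves like an ordinary cc_library. That define is what makes the new #if GOOGLE_CUDA guards in rdma.cc and rdma_rendezvous_mgr.cc (below) effective, and is presumably the "compilation fix" named in the commit message. A minimal sketch of the guard pattern, not taken from this commit:

    // Compiled only when the target is built as tf_cuda_library under
    // --config=cuda, which defines GOOGLE_CUDA.
    #if GOOGLE_CUDA
    #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
    #endif  // GOOGLE_CUDA

    void CopyTensorForTransfer() {
    #if GOOGLE_CUDA
      // GPU staging path: copy device memory to a pinned host buffer here.
    #else
      // CPU-only build: the GPU path is compiled out entirely.
    #endif  // GOOGLE_CUDA
    }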
57 changes: 49 additions & 8 deletions tensorflow/contrib/verbs/rdma.cc
@@ -22,8 +22,10 @@ limitations under the License.
 #include "tensorflow/contrib/verbs/verbs_util.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
+#if GOOGLE_CUDA
 #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
 #include "tensorflow/core/common_runtime/gpu/process_state.h"
+#endif
 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
 #include "tensorflow/core/distributed_runtime/session_mgr.h"
 #include "tensorflow/core/framework/rendezvous.h"
@@ -32,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 
 namespace tensorflow {
 
@@ -419,9 +422,6 @@ RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
                        0);
   CHECK(cq_) << "Failed to create completion queue";
   CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
-  polling_thread_.reset(Env::Default()->StartThread(
-      ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
-  VLOG(2) << "Start RdmaAdapter: " << name();
 }
 
 RdmaAdapter::~RdmaAdapter() {
@@ -433,6 +433,12 @@ RdmaAdapter::~RdmaAdapter() {
   CHECK(!ibv_close_device(context_)) << "Failed to release context";
 }
 
+void RdmaAdapter::StartPolling() {
+  polling_thread_.reset(Env::Default()->StartThread(
+      ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
+  VLOG(2) << "Start RdmaAdapter: " << name();
+}
+
 string RdmaAdapter::name() const { return string(context_->device->name); }
 
 // Function to process incoming messages
@@ -558,9 +564,44 @@ void RdmaAdapter::Process_CQ() {
   }
 }
 
+int RdmaChannel::PingPostRecv() {
+  struct ibv_recv_wr wr, *bad_wr;
+  memset(&wr, 0, sizeof(wr));
+  wr.sg_list = &ping_sge_list_;
+  wr.num_sge = 1;
+  wr.wr_id = kPingRecvWrid;
+
+  return ibv_post_recv(qp_, &wr, &bad_wr);
+}
+
+int RdmaChannel::PingPostSend() {
+  struct ibv_send_wr wr, *bad_wr;
+  memset(&wr, 0, sizeof(wr));
+  wr.wr_id = (uint64_t) this;
+  wr.sg_list = &ping_sge_list_;
+  wr.num_sge = 1;
+  wr.opcode = IBV_WR_SEND;
+  wr.send_flags = IBV_SEND_SIGNALED;
+
+  return ibv_post_send(qp_, &wr, &bad_wr);
+}
+
 RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
                          const string remote_name)
     : adapter_(adapter), local_name_(local_name), remote_name_(remote_name) {
+
+  struct ibv_sge list;
+
+  mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize,
+                   IBV_ACCESS_LOCAL_WRITE);
+  CHECK(mr_) << "Failed to register memory region";
+
+  memset(&list, 0, sizeof(list));
+  list.addr = (uintptr_t)ping_buff_;
+  list.length = kPingBuffSize;
+  list.lkey = mr_->lkey;
+
+  ping_sge_list_ = list;
   // Create queue pair
   {
     struct ibv_qp_init_attr attr;
@@ -633,15 +674,13 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
       buffer_index_name_table_.insert({index, buffer_names[i]});
       buffer_name_index_table_.insert({buffer_names[i], index});
     }
-
-    // Initiate recv
-    for (int i = 0; i < 100; i++) {
-      Recv();
-    }
   }
+  CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_
+                             << " with error " << std::strerror(errno);
 }
 
 RdmaChannel::~RdmaChannel() {
+  ibv_dereg_mr(mr_);
   CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
   delete tx_message_buffer_;
   delete rx_message_buffer_;
@@ -1026,6 +1065,7 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
         TensorProto proto;
         if (src_dev->tensorflow_gpu_device_info() &&
             (!send_args.alloc_attrs.on_host())) {
+#if GOOGLE_CUDA
           CHECK(send_args.device_context) << "send dev name: " << src_dev->name()
                                           << " gpu_info: "
                                           << src_dev->tensorflow_gpu_device_info();
@@ -1064,6 +1104,7 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
                              &proto, NULL, send_args, recv_args);
             });
         }
+#endif  // GOOGLE_CUDA
       } else {
         // tensor is in CPU memory.
         StringPiece copy_buf;
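Taken together, the rdma.cc changes give every RdmaChannel a small registered ping buffer: the constructor registers the memory region and pre-posts a receive (PingPostRecv), and RdmaMgr::ConnectivityCheck (further down) posts the matching send (PingPostSend) on each channel. Below is an isolated sketch of that libibverbs post-receive/post-send pairing — not code from this commit; it assumes a protection domain pd and an already-connected queue pair qp set up elsewhere:

    #include <cstdint>
    #include <cstring>
    #include <infiniband/verbs.h>

    static const int kPingRecvWrid = 0;    // tags recv completions on the CQ
    static const int kPingBuffSize = 1024;

    // Registers buf, pre-posts a receive for the peer's ping, then posts a
    // signaled send so the sender also gets a work completion. Returns 0 on
    // success, like the underlying verbs calls.
    int PostPing(ibv_pd* pd, ibv_qp* qp, char* buf) {
      ibv_mr* mr = ibv_reg_mr(pd, buf, kPingBuffSize, IBV_ACCESS_LOCAL_WRITE);
      if (mr == nullptr) return -1;

      ibv_sge sge;
      memset(&sge, 0, sizeof(sge));
      sge.addr = (uintptr_t)buf;
      sge.length = kPingBuffSize;
      sge.lkey = mr->lkey;

      // The receive must be posted before the remote side's send can land.
      ibv_recv_wr rwr, *bad_rwr;
      memset(&rwr, 0, sizeof(rwr));
      rwr.wr_id = kPingRecvWrid;
      rwr.sg_list = &sge;
      rwr.num_sge = 1;
      if (ibv_post_recv(qp, &rwr, &bad_rwr)) return -1;

      // IBV_SEND_SIGNALED requests a completion for the send as well.
      ibv_send_wr swr, *bad_swr;
      memset(&swr, 0, sizeof(swr));
      swr.wr_id = (uint64_t)buf;  // any value that identifies the sender
      swr.sg_list = &sge;
      swr.num_sge = 1;
      swr.opcode = IBV_WR_SEND;
      swr.send_flags = IBV_SEND_SIGNALED;
      return ibv_post_send(qp, &swr, &bad_swr);
    }

In the commit itself the two halves are split: PingPostRecv runs once per channel at construction time, and the send side is driven by the manager during the connectivity check.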
10 changes: 10 additions & 0 deletions tensorflow/contrib/verbs/rdma.h
@@ -107,6 +107,7 @@ class RdmaAdapter {
   ~RdmaAdapter();
   // Adapter name, e.g. mlx5_0.
   string name() const;
+  void StartPolling();
   void Process_CQ();
 
  protected:
@@ -161,6 +162,15 @@ class RdmaChannel {
   void RemoveRecvCallback(const string& key);
   void RunRecvCallback(const string& key);
   static const int kNumMessageBuffers = 4;
+  static const int kPingRecvWrid = 0;
+
+ private:
+  static const int kPingBuffSize = 1024;
+  char ping_buff_[kPingBuffSize];
+  struct ibv_mr* mr_;
+  struct ibv_sge ping_sge_list_;
+  int PingPostRecv();
+  int PingPostSend();
 
  protected:
   const RdmaAdapter* adapter_;
51 changes: 51 additions & 0 deletions tensorflow/contrib/verbs/rdma_mgr.cc
@@ -115,6 +115,57 @@ void RdmaMgr::SetupChannels() {
   }
 }
 
+// Check connectivity by pinging every channel
+bool RdmaMgr::ConnectivityCheck() {
+  int i, rcnt = 0, scnt = 0;
+
+  for (const auto& p : channel_table_) {
+    string worker_name = p.first;
+    RdmaChannel* rc = p.second;
+
+    VLOG(2) << "Ping to " << worker_name;
+    CHECK(rc->PingPostSend() == 0) << "Couldn't post send to " << worker_name
+                                   << " with error: " << std::strerror(errno);
+    for (i = 0; i < rc->adapter_->params_.queue_depth - 1; i++) {
+      rc->Recv();
+    }
+  }
+
+  while (rcnt < num_remote_workers_ || scnt < num_remote_workers_) {
+    int ne;
+    do {
+      ne = ibv_poll_cq(rdma_adapter_->cq_, 2 * num_remote_workers_,
+                       rdma_adapter_->wc_);
+      CHECK(ne >= 0) << "poll CQ failed " << ne << " with error "
+                     << std::strerror(errno);
+    } while (ne < 1);
+
+    for (i = 0; i < ne; ++i) {
+      ibv_wc_status s = rdma_adapter_->wc_[i].status;
+      // recv complete
+      if ((int)rdma_adapter_->wc_[i].wr_id == RdmaChannel::kPingRecvWrid) {
+        CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str(
+                                          rdma_adapter_->wc_[i].status)
+                                   << "(" << rdma_adapter_->wc_[i].status
+                                   << ") for PING_RECV_WRID";
+        ++rcnt;
+        // send complete
+      } else {
+        RdmaChannel* rc =
+            reinterpret_cast<RdmaChannel*>(rdma_adapter_->wc_[i].wr_id);
+        CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str(
+                                          rdma_adapter_->wc_[i].status)
+                                   << "(" << rdma_adapter_->wc_[i].status
+                                   << ") to " << rc->remote_name_;
+        ++scnt;
+      }
+    }  // for
+  }  // while
+  CHECK(rcnt == scnt) << "Connectivity check failed!";
+  rdma_adapter_->StartPolling();
+  return (num_remote_workers_ == rcnt) && (num_remote_workers_ == scnt);
+}
+
 RdmaMgr::~RdmaMgr() {
   for (const auto& p : channel_table_) delete p.second;
   channel_table_.clear();
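Worth calling out: the loop distinguishes the two completion kinds purely by wr_id. Ping receives were posted with wr_id == kPingRecvWrid (0), while each ping send was posted with wr_id holding the RdmaChannel pointer, which is never zero, so both kinds can share one completion queue. A stripped-down sketch of that dispatch, assuming an existing CQ and work-completion array (illustrative, not from the commit):

    #include <cstdio>
    #include <infiniband/verbs.h>

    // Polls until `expected` ping completions (sends plus receives) have
    // been seen, dispatching on wr_id the same way ConnectivityCheck does.
    void DrainPings(ibv_cq* cq, ibv_wc* wc, int expected) {
      int seen = 0;
      while (seen < expected) {
        int ne = ibv_poll_cq(cq, expected, wc);
        if (ne < 0) return;  // real code would CHECK-fail here
        for (int i = 0; i < ne; ++i, ++seen) {
          if (wc[i].status != IBV_WC_SUCCESS) {
            fprintf(stderr, "bad wc: %s\n", ibv_wc_status_str(wc[i].status));
          } else if (wc[i].wr_id == 0 /* kPingRecvWrid */) {
            // A peer's ping arrived.
          } else {
            // Our ping send completed; wr_id holds the channel pointer.
          }
        }
      }
    }

Also note the ordering: ConnectivityCheck may poll cq_ directly only because the Process_CQ thread is not running yet — StartPolling() is deliberately deferred until every ping is accounted for, since a concurrent poller would consume the very completions the check is counting.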
5 changes: 4 additions & 1 deletion tensorflow/contrib/verbs/rdma_mgr.h
@@ -28,12 +28,16 @@ limitations under the License.
 namespace tensorflow {
 
 class RdmaMgr {
+  friend class RdmaChannel;
+  friend class RdmaAdapter;
+
  public:
   explicit RdmaMgr(const WorkerEnv* const worker_env,
                    GrpcChannelCache* const channel_cache);
   ~RdmaMgr();
   RdmaChannel* FindChannel(const string& key);
   void SetupChannels();
+  bool ConnectivityCheck();
   const string& local_worker() { return local_worker_; }
 
  private:
@@ -44,7 +48,6 @@ class RdmaMgr {
   RdmaAdapter* rdma_adapter_;
   typedef std::unordered_map<string, RdmaChannel*> ChannelTable;
   ChannelTable channel_table_;
-
   TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr);
 };

46 changes: 19 additions & 27 deletions tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -21,8 +21,10 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
+#if GOOGLE_CUDA
 #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
 #include "tensorflow/core/common_runtime/gpu/process_state.h"
+#endif  // GOOGLE_CUDA
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -58,20 +60,13 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
   // parse src_name and dst_name
   string src_name, dst_name, unused;
   if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name,
-                                        &unused)) {
-    s = errors::Internal("Could not parse src name.");
-  }
-  CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
-  if (!s.ok()) {
-    done(s, Args(), recv_args, Tensor{}, false);
-    return;
-  }
-  if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name,
+                                        &unused) ||
+      !DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name,
                                         &unused)) {
-    s = errors::Internal("Could not parse dst name.");
+    s = errors::Internal("Could not parse src or dst name.");
   }
-  CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
   if (!s.ok()) {
+    LOG(ERROR) << "s is not ok, error code " << s.error_message();
     done(s, Args(), recv_args, Tensor{}, false);
     return;
   }
@@ -82,18 +77,13 @@
   // insert callback
   rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc,
                                             recv_args, parsed, done]() {
-    Status s;
-    Device* src_dev;
-    s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
-    CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
-    if (!s.ok()) {
-      done(s, Args(), recv_args, Tensor(), true);
-      return;
-    }
-    Device* dst_dev;
-    s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
-    CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
-    if (!s.ok()) {
+    Status src_s, dst_s, s;
+    Device* src_dev, *dst_dev;
+    src_s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
+    dst_s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
+    if (!src_s.ok() || !dst_s.ok()) {
+      s = src_s.ok() ? dst_s : src_s;
+      LOG(ERROR) << "s is not ok, error code " << s.error_message();
       done(s, Args(), recv_args, Tensor(), true);
       return;
     }
@@ -110,9 +100,10 @@
     if (can_memcpy) {
       if (dst_dev->tensorflow_gpu_device_info() &&
           (!recv_args.alloc_attrs.on_host())) {
+#if GOOGLE_CUDA
         CHECK(recv_args.device_context)
-            << "send dev name: " << src_dev->name()
-            << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+              << "send dev name: " << src_dev->name()
+              << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
         Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
         Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
         memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
@@ -122,14 +113,15 @@
 
           GPUUtil::CopyCPUTensorToGPU(
               &copy, recv_args.device_context, dst_dev, &gpu_copy,
-              [this, gpu_copy, key, key_with_step_id, recv_args, done, rm,
-               rc](const Status& s) {
+              [this, gpu_copy, key, key_with_step_id, recv_args, done, rm, rc](
+                  const Status& s) {
                 CHECK(s.ok()) << "copy tensor to gpu sync";
                 Tensor val;
                 val = std::move(gpu_copy);
                 RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc,
                                 val, s);
               });
+#endif  // GOOGLE_CUDA
           return;
         } else {
           AllocatorAttributes host_alloc_attrs;
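The refactor above folds two copy-pasted lookup/error blocks into one: both device lookups run first, the first failing Status wins, and the failure is reported through LOG(ERROR) plus the done callback instead of a CHECK, which would abort the process before the rendezvous could surface the error to its caller. A self-contained sketch of the pattern with stand-in types (nothing below is TensorFlow API):

    #include <functional>
    #include <iostream>
    #include <string>

    struct Status {
      std::string msg;                        // empty means success
      bool ok() const { return msg.empty(); }
    };

    void RecvExample(const std::function<void(const Status&)>& done) {
      Status src_s;                           // pretend CPU:0 lookup succeeded
      Status dst_s{"unknown device GPU:7"};   // pretend dst lookup failed
      if (!src_s.ok() || !dst_s.ok()) {
        Status s = src_s.ok() ? dst_s : src_s;  // propagate the first failure
        std::cerr << "lookup failed: " << s.msg << "\n";
        done(s);  // the caller sees the error; no process abort
        return;
      }
      done(Status{});
    }

    int main() {
      RecvExample([](const Status& s) {
        std::cout << (s.ok() ? "ok" : s.msg) << "\n";
      });
    }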
5 changes: 3 additions & 2 deletions tensorflow/contrib/verbs/verbs_server_lib.cc
@@ -49,8 +49,8 @@ VerbsServer::~VerbsServer() {
 Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
                                         GrpcChannelCache** channel_cache) {
   string name_prefix =
-      strings::StrCat("/job:", server_def.job_name(), "/replica:0",
-                      "/task:", server_def.task_index());
+      strings::StrCat("/job:", server_def.job_name(), "/replica:0", "/task:",
+                      server_def.task_index());
 
   GrpcChannelSpec channel_spec;
   TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
@@ -103,6 +103,7 @@ Status VerbsServer::Start() {
         ThreadOptions(), "TF_verbs_service",
         [this] { verbs_service_->HandleRPCsLoop(); }));
     rdma_mgr_->SetupChannels();
+    CHECK(rdma_mgr_->ConnectivityCheck()) << "Connectivity check failed!";
     verbs_state_ = CONNECTED;
   }
 }
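This last hunk ties the pieces together: after SetupChannels() exchanges queue-pair state over gRPC, the server refuses to enter CONNECTED until ConnectivityCheck() has heard back from every peer, so a dead or misconfigured worker now fails fast at startup instead of hanging later in the data path. A schematic of the resulting startup order (illustrative class, not the TensorFlow API):

    // Sketch only: models the Start() ordering this commit enforces.
    class MiniVerbsServer {
     public:
      void Start() {
        SetupChannels();             // 1. connect queue pairs to all peers
        if (!ConnectivityCheck()) {  // 2. ping each peer, polling inline
          return;                    //    (the real code CHECK-fails here)
        }
        connected_ = true;           // 3. CQ polling thread is now live
      }

     private:
      void SetupChannels() {}
      bool ConnectivityCheck() {
        // Must run before the background poller exists; it polls the same
        // completion queue inline and hands it over only when done.
        StartPolling();
        return true;
      }
      void StartPolling() {}
      bool connected_ = false;
    };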
