Skip to content

Commit

Permalink
[Runtime] Modify _RefSend and _RefRecv node name. (DeepRec-AI#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
shanshanpt authored Dec 8, 2021
1 parent a8adcc9 commit c1622e2
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions tensorflow/core/graph/star_server_graph_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,13 @@ Status GraphPartitionerBase::ConstructRecvNodeDef(
bool client_terminated,
NodeDef *node_def)
{
std::string real_node_name(node_name);
std::string recv_name("_Recv");
if (IsRefType(tensor_type)) recv_name = "_RefRecv";
return NodeDefBuilder(node_name, recv_name).Device(recv_device_name)
if (IsRefType(tensor_type)) {
recv_name = "_RefRecv";
real_node_name = "_ref_" + node_name;
}
return NodeDefBuilder(real_node_name, recv_name).Device(recv_device_name)
.Attr("tensor_type", IsRefType(tensor_type) ?
RemoveRefType(tensor_type) : tensor_type)
.Attr("tensor_name", tensor_name)
Expand All @@ -275,9 +279,13 @@ Status GraphPartitionerBase::ConstructSendNodeDef(
bool client_terminated,
NodeDef *node_def)
{
std::string real_node_name(node_name);
std::string send_name("_Send");
if (IsRefType(tensor_type)) send_name = "_RefSend";
return NodeDefBuilder(node_name, send_name).Device(send_device_name)
if (IsRefType(tensor_type)) {
send_name = "_RefSend";
real_node_name = "_ref_" + node_name;
}
return NodeDefBuilder(real_node_name, send_name).Device(send_device_name)
.Input(input_node_name, input_idx, tensor_type)
.Attr("T", IsRefType(tensor_type) ?
RemoveRefType(tensor_type) : tensor_type)
Expand Down

0 comments on commit c1622e2

Please sign in to comment.