Address comments from #3823 and polish code (#3964)
* Address comments from #3823 and polish code

* One line
wschin authored May 17, 2020
1 parent 4ff73d0 commit 0d11649
Showing 3 changed files with 107 additions and 72 deletions.
163 changes: 92 additions & 71 deletions orttraining/orttraining/core/graph/pipeline_transformer.cc
@@ -30,11 +30,11 @@ NodeArg& CreateInt64NodeArg(Graph& graph, const std::string& name) {
return node_arg;
}

void AddInputEvent(Graph& graph, const std::string& op_name,
bool is_forward,
void AddInputEvent(Graph& graph,
const std::string& event_name,
std::vector<NodeArg*>& input_args,
std::vector<std::string>& new_input_names) {
auto& event_id = CreateInt64NodeArg(graph, op_name + (is_forward ? "_fw" : "_bw") + "_event_id");
auto& event_id = CreateInt64NodeArg(graph, event_name);
new_input_names.push_back(event_id.Name());
input_args.push_back(&event_id);
}
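Note: with this change the caller names the event-ID input explicitly instead of having the helper derive it from an op name plus a "_fw"/"_bw" suffix. A minimal usage sketch, assuming an onnxruntime::Graph named graph is in scope; the event names match the ones used later in this diff:

std::vector<NodeArg*> input_args;
std::vector<std::string> new_input_names;
// The explicit names below become new int64 graph inputs holding event IDs.
AddInputEvent(graph, "forward_waited_event_id", input_args, new_input_names);
AddInputEvent(graph, "backward_recorded_event_id", input_args, new_input_names);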
@@ -91,7 +91,7 @@ std::vector<NodeArg*> CreateMirrorNodeArgs(

// Create a node with input schema [event, input1, input2, ..., inputN] and
// output schema [input1, input2, ..., inputN]
void CreateBottleneckNode(Graph& graph,
Node& CreateBottleneckNode(Graph& graph,
const std::string& op_type,
const std::string& op_name,
const std::string& description,
@@ -102,7 +102,8 @@ void CreateBottleneckNode(Graph& graph,
if (event) {
input_node_args.insert(input_node_args.begin(), event);
}
graph.AddNode(

return graph.AddNode(
name,
op_type,
description,
@@ -112,14 +113,14 @@ void CreateBottleneckNode(Graph& graph,
kMSDomain);
}
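Note: since CreateBottleneckNode now returns Node& instead of void, call sites can keep a handle to the inserted event node. An illustrative sketch, assuming graph, input_args, and output_args are in scope and the event-ID arg was pushed into input_args first via AddInputEvent (this mirrors the call sites further down in this diff):

Node& record_node = CreateBottleneckNode(
    graph, "RecordEvent", "forward_record", "", /*event=*/nullptr,
    input_args, output_args);
// First input is the event-ID tensor added by AddInputEvent above.
const std::string recorded_event_name = record_node.InputDefs()[0]->Name();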

Node* AddRecordBackward(Graph& graph,
Node* AddBackwardRecord(Graph& graph,
Node* backward_send,
std::vector<std::string>& new_input_names,
std::vector<std::string>& new_output_names,
std::string &event_id_tensor_name,
std::string &output_tensor_name) {
std::vector<NodeArg*> input_args;
AddInputEvent(graph, "RecordEvent", false /* is_forward */, input_args, new_input_names);
AddInputEvent(graph, "backward_recorded_event_id", input_args, new_input_names);
std::vector<NodeArg*> output_args{};

if (backward_send) {
@@ -138,14 +139,9 @@ Node* AddRecordBackward(Graph& graph,
output_args.push_back(&new_output);
new_output_names.push_back(new_output.Name());

Node* record_node = &(graph.AddNode(
graph.GenerateNodeName("RecordEvent"),
"RecordEvent",
"Backward pass",
input_args,
output_args,
nullptr,
kMSDomain));
Node* record_node = &CreateBottleneckNode(
graph, "RecordEvent", "backward_record", "Backward pass", nullptr,
input_args, output_args);

// First input argument is the recorded event ID tensor.
event_id_tensor_name = input_args.front()->Name();
@@ -156,7 +152,7 @@ Node* AddRecordBackward(Graph& graph,
return record_node;
}

Node* AddWaitForward(Graph& graph,
Node* AddForwardWait(Graph& graph,
Node* /* forward_recv */,
std::vector<std::string>& new_input_names,
std::string& forward_waited_event_name,
@@ -179,7 +175,7 @@ Node* AddWaitForward(Graph& graph,

std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "WaitEvent", true /* is_forward */, input_args, new_input_names);
AddInputEvent(graph, "forward_waited_event_id", input_args, new_input_names);
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();

if (graph_inputs.size() == 0){
@@ -199,20 +195,20 @@ Node* AddWaitForward(Graph& graph,
}
}
}
Node* wait_node = &(graph.AddNode(
graph.GenerateNodeName("WaitEvent"),
"WaitEvent",
"",
input_args,
output_args,
nullptr,
kMSDomain));

Node* wait_node = &CreateBottleneckNode(
graph, "WaitEvent", "backward_record", "", nullptr,
input_args, output_args);

forward_waited_event_name = input_args.front()->Name();
output_tensor_name = output_args.front()->Name();

return wait_node;
}

Status AddOrSkipRecordForwardWaitBackward(Graph& graph,
// If the input "graph" is the last pipeline stage, this function doesn't add any
// event operators.
Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
Node* forward_send,
Node* backward_recv,
std::vector<std::string>& new_input_names,
@@ -227,19 +223,21 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph,
}

if (!forward_send && !backward_recv){
// Last partition doesn't have send forward and recv backward. No insert needed.
// Last partition doesn't have send forward and recv backward. No insert
// needed.
return Status::OK();
}

// If we have a send forward op followed by a recv backward op, insert WaitEvent and RecordEvent in between.
// If we have a send forward op followed by a recv backward op, insert
// WaitEvent and RecordEvent in between.
Node* record_node = nullptr;
Node* wait_node = nullptr;

// Insert RecordEvent
{
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "RecordEvent", true /* is_forward */, input_args, new_input_names);
AddInputEvent(graph, "forward_recorded_event_id", input_args, new_input_names);

// Add send forward op's output as record op's input and output
for (auto& output : forward_send->MutableOutputDefs()) {
@@ -248,23 +246,19 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph,
input_args.push_back(output);
}

auto& new_node = graph.AddNode(graph.GenerateNodeName("RecordEvent"),
"RecordEvent",
"",
input_args,
output_args, /* output */
{}, /* attribute */
kMSDomain);
record_node = &new_node;
record_node = &CreateBottleneckNode(
graph, "RecordEvent", "forward_record", "", nullptr,
input_args, output_args);

forward_recorded_event_name = record_node->InputDefs()[0]->Name();
forward_output_name = record_node->OutputDefs()[0]->Name();
}

// Insert WaitEvent
{
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "WaitEvent", false /* is_forward */, input_args, new_input_names);
AddInputEvent(graph, "backward_waited_event_id", input_args, new_input_names);

input_args.insert(std::end(input_args),
std::begin(record_node->MutableOutputDefs()),
@@ -275,14 +269,9 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph,
output_args.push_back(&new_output);
input = &new_output;

auto& new_node = graph.AddNode(graph.GenerateNodeName("WaitEvent"),
"WaitEvent",
"Backward pass",
input_args,
output_args, /* output */
{}, /* attribute */
kMSDomain);
wait_node = &new_node;
wait_node = &CreateBottleneckNode(
graph, "WaitEvent", "backward_wait", "Backward pass", nullptr,
input_args, output_args);

backward_waited_event_name = wait_node->InputDefs()[0]->Name();
backward_output_name = wait_node->OutputDefs()[0]->Name();
Expand All @@ -297,12 +286,14 @@ void ReplaceNodeArgs(std::vector<Node*>& nodes,
std::vector<NodeArg*>& new_node_args) {
ORT_ENFORCE(node_args.size() == new_node_args.size());
for (size_t i = 0; i < node_args.size(); ++i) {
// At this iteration, we replace node_args[i] with
// This iteration replaces node_args[i] with new_node_args[i].

ORT_ENFORCE(node_args[i]->Name() != new_node_args[i]->Name());
ORT_ENFORCE(node_args[i]->Type() == new_node_args[i]->Type());

for (auto& node: nodes) {
for (auto& node_arg: node->MutableInputDefs()) {
// Only replace when the node's input name matches node_args[i].
if (node_arg->Name().compare(node_args[i]->Name()) != 0) {
continue;
}
@@ -346,10 +337,9 @@ std::string AddEventBeforeNode(
auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name);

// Create node which produces new_node_args from event ID and node_args.
auto name = graph.GenerateNodeName(event_op_name);
CreateBottleneckNode(graph,
event_op_type,
name,
event_op_name,
"",
event_node_arg,
node_args,
@@ -389,10 +379,9 @@ std::string AddEventAfterNode(
auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name);

// Create node which produces new_node_args from event ID and node_args.
auto name = graph.GenerateNodeName(event_op_name);
CreateBottleneckNode(graph,
event_op_type,
name,
event_op_name,
"",
event_node_arg,
node_args,
@@ -469,7 +458,52 @@ Status AddBackwardRecordBeforeSend(
}
}

// Insert WaitEvent and RecordEvent to the partition.
// This function inserts WaitEvent's and RecordEvent's into the input graph for
// controlling synchronization between (batch, pipeline stage)-pairs.
//
// The input graph is a pipeline's stage, which contains some Send's and Recv's.
//
// Different pipeline stages have different communication patterns, as shown
// below.
//
// 1. First stage:
// FW -----------> Send ----------->
// ------> Recv ---------> BW
// 2. Middle stage:
// Recv ---------> FW -----------> Send ----------->
// ------> Recv ---------> BW -----------> Send
// 3. Last stage:
// Recv ---------> FW ----------------------------->
// ----------------------> BW -----------> Send
//
// This function inserts some event operators and those patterns become
//
// 1. First stage:
// Wait ---------> Wait -> FW -> Record -> Send -> Record ->
// Wait -> Recv -> Wait -> BW -> Record ---------> Record
// 2. Middle stage:
// Wait -> Recv -> Wait -> FW -> Record -> Send -> Record ->
// Wait -> Recv -> Wait -> BW -> Record -> Send -> Record
// 3. Last stage:
// Wait -> Recv -> Wait -> FW ----------------------------->
// ----------------------> BW -> Record -> Send -> Record
//
// To explain the meaning of those operators, we take the middle stage's pattern
// as an example:
//
// Wait-0 -> Recv -> Wait-1 -> FW -> Record-0 -> Send -> Record-1 ->
// Wait-2 -> Recv -> Wait-3 -> BW -> Record-2 -> Send -> Record-3
//
// Their meanings are listed below.
//
// Wait-0: Wait until we can start receiving forward data.
// Wait-1: Wait until we can start forward pass.
// Record-0: Tell others that forward pass is done.
// Record-1: Tell others that forward result has been passed to another stage.
// Wait-2: Wait until we can start receiving backward data.
// Wait-3: Wait until we can start backward pass.
// Record-2: Tell others that backward pass is done.
// Record-3: Tell others that backward result has been passed to another stage.
Status TransformGraphForPipeline(
Graph& graph,
std::string& forward_waited_event_name,
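Note: the Wait/Record pairs described in the comment block above act as one-shot synchronization points keyed by an event ID. The sketch below only illustrates that contract; it is not ORT's actual WaitEvent/RecordEvent kernel implementation, and the EventPool type is assumed for illustration:

#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <unordered_set>

// Record(id) marks an event as signaled; Wait(id) blocks until it has been.
class EventPool {
 public:
  void Record(int64_t id) {
    std::lock_guard<std::mutex> guard(mu_);
    signaled_.insert(id);
    cv_.notify_all();
  }
  void Wait(int64_t id) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [&] { return signaled_.count(id) > 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::unordered_set<int64_t> signaled_;
};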
@@ -508,26 +542,26 @@ Status TransformGraphForPipeline(
}

// Names to be added into this graph's input list.
// Their value may be provided as "feeds" when calling session.Run(...).
// Their values may be provided as "feeds" when calling session.Run(...).
std::vector<std::string> new_input_names;
// Names to be added into this graph's output list.
// Their value may be provided as "feeds" when calling session.Run(...).
// Their values may be returned as "fetches" when calling session.Run(...).
std::vector<std::string> new_output_names;

backward_record = AddRecordBackward(
backward_record = AddBackwardRecord(
graph,
backward_send,
new_input_names,
new_output_names,
backward_recorded_event_name,
backward_record_output_name);
forward_wait = AddWaitForward(
forward_wait = AddForwardWait(
graph,
forward_recv,
new_input_names,
forward_waited_event_name,
forward_wait_output_name);
ORT_RETURN_IF_ERROR(AddOrSkipRecordForwardWaitBackward(
ORT_RETURN_IF_ERROR(AddOrSkipForwardRecordBackwardWait(
graph,
forward_send,
backward_recv,
@@ -552,27 +586,14 @@ Status TransformGraphForPipeline(
// 3. Last stage:
// Wait -> Recv ---------> FW ----------------------------->
// ----------------------> BW -----------> Send -> Record
//
// After applying all transformations below, we will have
// the following patterns.
//
// 1. First stage:
// Wait ---------> Wait -> FW -> Record -> Send -> Record ->
// Wait -> Recv -> Wait -> BW -> Record ---------> Record
// 2. Middle stage:
// Wait -> Recv -> Wait -> FW -> Record -> Send -> Record ->
// Wait -> Recv -> Wait -> BW -> Record -> Send -> Record
// 3. Last stage:
// Wait -> Recv -> Wait -> FW ----------------------------->
// ----------------------> BW -> Record -> Send -> Record
const bool is_first_stage = !forward_recv && forward_send && backward_recv && !backward_send;
const bool is_middle_stage = forward_recv && forward_send && backward_recv && backward_send;
const bool is_last_stage = forward_recv && !forward_send && !backward_recv && backward_send;

// One and only one of is_first_stage, is_middle_stage, and is_last_stage can be true.
const unsigned int stage_flag_sum = is_first_stage + is_middle_stage + is_last_stage;
ORT_RETURN_IF_NOT(stage_flag_sum == 1u,
"The processed graph should be classified into an stage, "
"The processed graph should be classified into a stage, "
"but we see more than one true's in the following statements. ",
"Is first stage? ", is_first_stage, ". ",
"Is middle stage? ", is_middle_stage, ". ",
2 changes: 1 addition & 1 deletion orttraining/orttraining/models/runner/pipeline.h
@@ -87,7 +87,7 @@ class PipelineSchedule {
// It equals to table_.size().
int num_stages_;
// Total number of batches scheduled in this pipeline.
// It equals to table_[i].size(), for i = 0, ..., num_stages_.
// It equals to table_[i].size(), for i = 0, ..., num_stages_ - 1.
int num_batches_;
};
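Note: the corrected comment states that every per-stage row of the schedule table has the same length. A minimal sketch of that invariant, with a hypothetical Slot type standing in for the table's element type (which this diff does not show):

#include <cassert>
#include <vector>

struct Slot { int batch_id = -1; };  // hypothetical per-(stage, batch) entry

void CheckScheduleShape(const std::vector<std::vector<Slot>>& table,
                        int num_stages, int num_batches) {
  // One row per stage ...
  assert(static_cast<int>(table.size()) == num_stages);
  // ... and row i has num_batches entries for i = 0, ..., num_stages - 1.
  for (int i = 0; i < num_stages; ++i) {
    assert(static_cast<int>(table[i].size()) == num_batches);
  }
}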

14 changes: 14 additions & 0 deletions orttraining/orttraining/test/graph/gradient_graph_builder_test.cc
@@ -1087,14 +1087,28 @@ void RetrieveSendRecvOperators(
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Send") {
if (is_backward(node)) {
// backward_send can only be assigned one valid pointer.
// If it is assigned more than once, it means we have multiple
// Sends in the backward pass and therefore our assumption doesn't hold.
// This check ensures that we only update *backward_send when
// its value is NULL and guards our one-Send assumption.
ASSERT_TRUE(!(*backward_send));
*backward_send = &node;
} else {
// Guard the uniqueness of Send in the forward pass by failing
// when *forward_send already carries a valid pointer.
ASSERT_TRUE(!(*forward_send));
*forward_send = &node;
}
} else if (node.OpType() == "Recv") {
if (is_backward(node)) {
// Guard the uniqueness of Recv in the backward pass by failing
// when *backward_recv already carries a valid pointer.
ASSERT_TRUE(!(*backward_recv));
*backward_recv = &node;
} else {
// Guard the uniqueness of Recv in the forward pass by failing
// when *forward_recv already carries a valid pointer.
*forward_recv = &node;
}
}
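Note: the added comments all describe the same null-then-assign uniqueness guard. A standalone illustration of that pattern, with hypothetical names (not the test's actual types):

#include <cassert>

struct FakeNode {};  // stand-in for an onnxruntime Node

// A slot may be claimed at most once; a second claim trips the assertion,
// mirroring ASSERT_TRUE(!(*backward_send)) in the test helper.
void Claim(const FakeNode** slot, const FakeNode* node) {
  assert(*slot == nullptr);
  *slot = node;
}

int main() {
  const FakeNode node{};
  const FakeNode* backward_send = nullptr;
  Claim(&backward_send, &node);    // first claim succeeds
  // Claim(&backward_send, &node); // a second claim would fail the assert
  return 0;
}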