From 0d11649bb304a801e6e3743d72cf61c11db596d2 Mon Sep 17 00:00:00 2001
From: Wei-Sheng Chin
Date: Sun, 17 May 2020 14:08:33 -0700
Subject: [PATCH] Address comments from #3823 and polish code (#3964)

* Address comments from #3823 and polish code

* One line
---
 .../core/graph/pipeline_transformer.cc        | 163 ++++++++++--------
 .../orttraining/models/runner/pipeline.h      |   2 +-
 .../test/graph/gradient_graph_builder_test.cc |  15 ++
 3 files changed, 108 insertions(+), 72 deletions(-)

diff --git a/orttraining/orttraining/core/graph/pipeline_transformer.cc b/orttraining/orttraining/core/graph/pipeline_transformer.cc
index 576bb61dd92b5..85b89a2e44bb3 100644
--- a/orttraining/orttraining/core/graph/pipeline_transformer.cc
+++ b/orttraining/orttraining/core/graph/pipeline_transformer.cc
@@ -30,11 +30,11 @@ NodeArg& CreateInt64NodeArg(Graph& graph, const std::string& name) {
   return node_arg;
 }
 
-void AddInputEvent(Graph& graph, const std::string& op_name,
-                   bool is_forward,
+void AddInputEvent(Graph& graph,
+                   const std::string& event_name,
                    std::vector<NodeArg*>& input_args,
                    std::vector<std::string>& new_input_names) {
-  auto& event_id = CreateInt64NodeArg(graph, op_name + (is_forward ? "_fw" : "_bw") + "_event_id");
+  auto& event_id = CreateInt64NodeArg(graph, event_name);
   new_input_names.push_back(event_id.Name());
   input_args.push_back(&event_id);
 }
@@ -91,7 +91,7 @@ std::vector<NodeArg*> CreateMirrorNodeArgs(
 
 // Create a node with input schema [event, input1, input2, ..., inputN] and
 // output schema [input1, input2, ..., inputN]
-void CreateBottleneckNode(Graph& graph,
+Node& CreateBottleneckNode(Graph& graph,
                           const std::string& op_type,
                           const std::string& op_name,
                           const std::string& description,
@@ -102,7 +102,8 @@
   if (event) {
     input_node_args.insert(input_node_args.begin(), event);
   }
-  graph.AddNode(
+
+  return graph.AddNode(
       name,
       op_type,
       description,
@@ -112,14 +113,14 @@
       kMSDomain);
 }
 
-Node* AddRecordBackward(Graph& graph,
+Node* AddBackwardRecord(Graph& graph,
                         Node* backward_send,
                         std::vector<std::string>& new_input_names,
                         std::vector<std::string>& new_output_names,
                         std::string &event_id_tensor_name,
                         std::string &output_tensor_name) {
   std::vector<NodeArg*> input_args;
-  AddInputEvent(graph, "RecordEvent", false /* is_forward */, input_args, new_input_names);
+  AddInputEvent(graph, "backward_recorded_event_id", input_args, new_input_names);
 
   std::vector<NodeArg*> output_args{};
   if (backward_send) {
@@ -138,14 +139,9 @@
   output_args.push_back(&new_output);
   new_output_names.push_back(new_output.Name());
 
-  Node* record_node = &(graph.AddNode(
-      graph.GenerateNodeName("RecordEvent"),
-      "RecordEvent",
-      "Backward pass",
-      input_args,
-      output_args,
-      nullptr,
-      kMSDomain));
+  Node* record_node = &CreateBottleneckNode(
+      graph, "RecordEvent", "backward_record", "Backward pass", nullptr,
+      input_args, output_args);
 
   // First input argument is the recorded event ID tensor.
   event_id_tensor_name = input_args.front()->Name();
 
@@ -156,7 +152,7 @@
   return record_node;
 }
 
-Node* AddWaitForward(Graph& graph,
+Node* AddForwardWait(Graph& graph,
                      Node* /* forward_recv */,
                      std::vector<std::string>& new_input_names,
                      std::string& forward_waited_event_name,
@@ -179,7 +175,7 @@
 
   std::vector<NodeArg*> input_args;
   std::vector<NodeArg*> output_args;
-  AddInputEvent(graph, "WaitEvent", true /* is_forward */, input_args, new_input_names);
+  AddInputEvent(graph, "forward_waited_event_id", input_args, new_input_names);
 
   const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
   if (graph_inputs.size() == 0){
@@ -199,20 +195,20 @@
       }
     }
   }
-  Node* wait_node = &(graph.AddNode(
-      graph.GenerateNodeName("WaitEvent"),
-      "WaitEvent",
-      "",
-      input_args,
-      output_args,
-      nullptr,
-      kMSDomain));
+
+  Node* wait_node = &CreateBottleneckNode(
+      graph, "WaitEvent", "forward_wait", "", nullptr,
+      input_args, output_args);
+
   forward_waited_event_name = input_args.front()->Name();
   output_tensor_name = output_args.front()->Name();
+
   return wait_node;
 }
 
-Status AddOrSkipRecordForwardWaitBackward(Graph& graph,
+// If the input "graph" is the last pipeline stage, this function doesn't add any
+// event operators.
+Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
                                           Node* forward_send,
                                           Node* backward_recv,
                                           std::vector<std::string>& new_input_names,
@@ -227,11 +223,13 @@
   }
 
   if (!forward_send && !backward_recv){
-    // Last partition doesn't have send forwrad and recv backward. No insert needed.
+    // Last partition doesn't have send forward and recv backward. No insert
+    // needed.
     return Status::OK();
   }
 
-  // if we have a send forward op followed by a recv backward op, insert WaitEvent and RecordEvent in between.
+  // If we have a send forward op followed by a recv backward op, insert
+  // WaitEvent and RecordEvent in between.
   Node* record_node = nullptr;
   Node* wait_node = nullptr;
 
@@ -239,7 +237,7 @@
   {
     std::vector<NodeArg*> input_args;
     std::vector<NodeArg*> output_args;
-    AddInputEvent(graph, "RecordEvent", true /* is_forward */, input_args, new_input_names);
+    AddInputEvent(graph, "forward_recorded_event_id", input_args, new_input_names);
 
     // Add send forward op's output as record op's input and output
     for (auto& output : forward_send->MutableOutputDefs()) {
@@ -248,23 +246,19 @@
       input_args.push_back(output);
     }
 
-    auto& new_node = graph.AddNode(graph.GenerateNodeName("RecordEvent"),
-                                   "RecordEvent",
-                                   "",
-                                   input_args,
-                                   output_args, /* output */
-                                   {}, /* attribute */
-                                   kMSDomain);
-    record_node = &new_node;
+    record_node = &CreateBottleneckNode(
+        graph, "RecordEvent", "forward_record", "", nullptr,
+        input_args, output_args);
 
     forward_recorded_event_name = record_node->InputDefs()[0]->Name();
     forward_output_name = record_node->OutputDefs()[0]->Name();
   }
 
+  // Insert WaitEvent
   {
     std::vector<NodeArg*> input_args;
     std::vector<NodeArg*> output_args;
-    AddInputEvent(graph, "WaitEvent", false /* is_forward */, input_args, new_input_names);
+    AddInputEvent(graph, "backward_waited_event_id", input_args, new_input_names);
 
     input_args.insert(std::end(input_args),
                       std::begin(record_node->MutableOutputDefs()),
@@ -275,14 +269,9 @@
     output_args.push_back(&new_output);
     input = &new_output;
 
-    auto& new_node = graph.AddNode(graph.GenerateNodeName("WaitEvent"),
-                                   "WaitEvent",
-                                   "Backward pass",
-                                   input_args,
-                                   output_args, /* output */
-                                   {}, /* attribute */
-                                   kMSDomain);
-    wait_node = &new_node;
+    wait_node = &CreateBottleneckNode(
+        graph, "WaitEvent", "backward_wait", "Backward pass", nullptr,
+        input_args, output_args);
 
     backward_waited_event_name = wait_node->InputDefs()[0]->Name();
     backward_output_name = wait_node->OutputDefs()[0]->Name();
@@ -297,12 +286,14 @@ void ReplaceNodeArgs(std::vector<Node*>& nodes,
                      std::vector<NodeArg*>& new_node_args) {
   ORT_ENFORCE(node_args.size() == new_node_args.size());
 
   for (size_t i = 0; i < node_args.size(); ++i) {
-    // At this iteration, we replace node_args[i] with
+    // In this iteration, node_args[i] is replaced with new_node_args[i].
+    ORT_ENFORCE(node_args[i]->Name() != new_node_args[i]->Name());
     ORT_ENFORCE(node_args[i]->Type() == new_node_args[i]->Type());
 
     for (auto& node: nodes) {
       for (auto& node_arg: node->MutableInputDefs()) {
+        // Only replace when the node's input name matches node_args[i].
        if (node_arg->Name().compare(node_args[i]->Name()) != 0) {
          continue;
        }
@@ -346,10 +337,9 @@ std::string AddEventBeforeNode(
   auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name);
 
   // Create node which produces new_node_args from event ID and node_args.
-  auto name = graph.GenerateNodeName(event_op_name);
   CreateBottleneckNode(graph,
                        event_op_type,
-                       name,
+                       event_op_name,
                        "",
                        event_node_arg,
                        node_args,
@@ -389,10 +379,9 @@ std::string AddEventAfterNode(
   auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name);
 
   // Create node which produces new_node_args from event ID and node_args.
-  auto name = graph.GenerateNodeName(event_op_name);
   CreateBottleneckNode(graph,
                        event_op_type,
-                       name,
+                       event_op_name,
                        "",
                        event_node_arg,
                        node_args,
@@ -469,7 +458,52 @@ Status AddBackwardRecordBeforeSend(
   }
 }
 
-// Insert WaitEvent and RecordEvent to the partition.
+// This function inserts WaitEvents and RecordEvents into the input graph for
+// controlling synchronization between (batch, pipeline stage)-pairs.
+//
+// The input graph is a pipeline's stage, which contains some Send and Recv nodes.
+//
+// Different pipeline stages have different communication patterns, as
+// shown below.
+//
+// 1. First stage:
+//   FW -----------> Send ----------->
+//   ------> Recv ---------> BW
+// 2. Middle stage:
+//   Recv ---------> FW -----------> Send ----------->
+//   ------> Recv ---------> BW -----------> Send
+// 3. Last stage:
+//   Recv ---------> FW ----------------------------->
+//   ----------------------> BW -----------> Send
+//
+// This function inserts some event operators and those patterns become:
+//
+// 1. First stage:
+//   Wait ---------> Wait -> FW -> Record -> Send -> Record ->
+//   Wait -> Recv -> Wait -> BW -> Record ---------> Record
+// 2. Middle stage:
+//   Wait -> Recv -> Wait -> FW -> Record -> Send -> Record ->
+//   Wait -> Recv -> Wait -> BW -> Record -> Send -> Record
+// 3. Last stage:
+//   Wait -> Recv -> Wait -> FW ----------------------------->
+//   ----------------------> BW -> Record -> Send -> Record
+//
+// To explain the meaning of those operators, we take the middle stage's pattern
+// as an example:
+//
+//   Wait-0 -> Recv -> Wait-1 -> FW -> Record-0 -> Send -> Record-1 ->
+//   Wait-2 -> Recv -> Wait-3 -> BW -> Record-2 -> Send -> Record-3
+//
+// Their meanings are listed below.
+//
+//   Wait-0: Wait until we can start receiving forward data.
+//   Wait-1: Wait until we can start forward pass.
+//   Record-0: Tell others that forward pass is done.
+//   Record-1: Tell others that forward result has been passed to another stage.
+//   Wait-2: Wait until we can start receiving backward data.
+//   Wait-3: Wait until we can start backward pass.
+//   Record-2: Tell others that backward pass is done.
+//   Record-3: Tell others that backward result has been passed to another stage.
 Status TransformGraphForPipeline(
     Graph& graph,
     std::string& forward_waited_event_name,
@@ -508,26 +542,26 @@
   }
 
   // Names to added into this graph's input list.
-  // Their value may be provides as "feeds" when calling session.Run(...).
+  // Their values may be provided as "feeds" when calling session.Run(...).
   std::vector<std::string> new_input_names;
   // Names to added into this graph's output list.
-  // Their value may be provides as "feeds" when calling session.Run(...).
+  // Their values may be returned as "fetches" when calling session.Run(...).
   std::vector<std::string> new_output_names;
 
-  backward_record = AddRecordBackward(
+  backward_record = AddBackwardRecord(
       graph,
       backward_send,
       new_input_names,
       new_output_names,
       backward_recorded_event_name,
      backward_record_output_name);
-  forward_wait = AddWaitForward(
+  forward_wait = AddForwardWait(
      graph,
      forward_recv,
      new_input_names,
      forward_waited_event_name,
      forward_wait_output_name);
-  ORT_RETURN_IF_ERROR(AddOrSkipRecordForwardWaitBackward(
+  ORT_RETURN_IF_ERROR(AddOrSkipForwardRecordBackwardWait(
      graph,
      forward_send,
      backward_recv,
@@ -552,19 +586,6 @@
   // 3. Last stage:
   //   Wait -> Recv ---------> FW ----------------------------->
   //   ----------------------> BW -----------> Send -> Record
-  //
-  // After applying all transformations below, we will have
-  // the following patterns.
-  //
-  // 1. First stage:
-  //   Wait ---------> Wait -> FW -> Record -> Send -> Record ->
-  //   Wait -> Recv -> Wait -> BW -> Record ---------> Record
-  // 2. Middle stage:
-  //   Wait -> Recv -> Wait -> FW -> Record -> Send -> Record ->
-  //   Wait -> Recv -> Wait -> BW -> Record -> Send -> Record
-  // 3. Last stage:
-  //   Wait -> Recv -> Wait -> FW ----------------------------->
-  //   ----------------------> BW -> Record -> Send -> Record
   const bool is_first_stage = !forward_recv && forward_send && backward_recv && !backward_send;
   const bool is_middle_stage = forward_recv && forward_send && backward_recv && backward_send;
   const bool is_last_stage = forward_recv && !forward_send && !backward_recv && backward_send;
@@ -572,7 +593,7 @@
   // One and only one of is_first_stage, is_middle_stage, and is_last_stage can be true.
   const unsigned int stage_flag_sum = is_first_stage + is_middle_stage + is_last_stage;
   ORT_RETURN_IF_NOT(stage_flag_sum == 1u,
-                    "The processed graph should be classified into an stage, "
+                    "The processed graph should be classified into a stage, "
                     "but we see more than one true's in the following statements. ",
                     "Is first stage? ", is_first_stage, ". ",
                     "Is middle stage? ", is_middle_stage, ". ",
diff --git a/orttraining/orttraining/models/runner/pipeline.h b/orttraining/orttraining/models/runner/pipeline.h
index 280986d84af79..6464a78906f13 100644
--- a/orttraining/orttraining/models/runner/pipeline.h
+++ b/orttraining/orttraining/models/runner/pipeline.h
@@ -87,7 +87,7 @@ class PipelineSchedule {
   // It equals to table_.size().
   int num_stages_;
 
   // Total number of batches scheduled in this pipeline.
-  // It equals to table_[i].size(), for i = 0, ..., num_stages_.
+  // It equals to table_[i].size(), for i = 0, ..., num_stages_ - 1.
   int num_batches_;
 };
diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc
index 2958279b7c5fa..1cbf5800aa291 100644
--- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc
+++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc
@@ -1087,14 +1087,29 @@ void RetrieveSendRecvOperators(
   for (auto& node : graph.Nodes()) {
     if (node.OpType() == "Send") {
       if (is_backward(node)) {
+        // backward_send can only be assigned one valid pointer.
+        // If it is assigned more than once, it means we have multiple
+        // Sends in the backward pass and therefore our assumption doesn't hold.
+        // This check ensures that we only update *backward_send when
+        // its value is NULL and guards our one-Send assumption.
+        ASSERT_TRUE(!(*backward_send));
        *backward_send = &node;
      } else {
+        // Guard the uniqueness of Send in the forward pass by failing the test
+        // when *forward_send already carries a valid pointer.
+        ASSERT_TRUE(!(*forward_send));
        *forward_send = &node;
      }
    } else if (node.OpType() == "Recv") {
      if (is_backward(node)) {
+        // Guard the uniqueness of Recv in the backward pass by failing the test
+        // when *backward_recv already carries a valid pointer.
+        ASSERT_TRUE(!(*backward_recv));
        *backward_recv = &node;
      } else {
+        // Guard the uniqueness of Recv in the forward pass by failing the test
+        // when *forward_recv already carries a valid pointer.
+        ASSERT_TRUE(!(*forward_recv));
        *forward_recv = &node;
      }
    }
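
The comment block added to pipeline_transformer.cc above describes WaitEvent and RecordEvent as a signaling contract between (batch, pipeline stage) pairs: a Record marks a step as finished, and a Wait blocks until the corresponding event has been recorded. The sketch below is a minimal, self-contained model of that contract. It assumes a hypothetical EventPool helper with Record/Wait methods built on a mutex and a condition variable; it is not ORT's event mechanism or its API, only an illustration of the intended semantics, and the event-ID names in the comments echo the new graph inputs introduced by this patch.

// Illustration only: a hypothetical EventPool modeling the WaitEvent /
// RecordEvent contract described in the comment block above; not ORT code.
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <thread>
#include <unordered_set>

class EventPool {
 public:
  // Mark event `id` as recorded and wake up every waiter.
  void Record(int64_t id) {
    {
      std::lock_guard<std::mutex> guard(mutex_);
      recorded_.insert(id);
    }
    cv_.notify_all();
  }

  // Block until event `id` has been recorded. A negative id means
  // "no dependency", so the caller proceeds immediately.
  void Wait(int64_t id) {
    if (id < 0) return;
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&] { return recorded_.count(id) > 0; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::unordered_set<int64_t> recorded_;
};

int main() {
  EventPool pool;

  // One stage: wait for its forward dependency, do the forward work,
  // then record that the forward pass is done (Wait-1 -> FW -> Record-0).
  std::thread stage([&pool] {
    pool.Wait(/*forward_waited_event_id=*/1);
    std::cout << "forward pass runs here\n";
    pool.Record(/*forward_recorded_event_id=*/0);
  });

  pool.Record(1);  // Upstream: allow the forward pass to start.
  pool.Wait(0);    // Downstream: observe that the forward pass finished.
  stage.join();
  return 0;
}

In the transformed graph, these IDs arrive through the new graph inputs named in the patch (for example "forward_waited_event_id" and "backward_recorded_event_id"), which the training runner may provide as feeds when calling session.Run(...); the sketch only shows why a Wait must not return before the matching Record has happened.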