diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index 9209b5fdeae22..0d4291a3b8b31 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -359,6 +359,30 @@ void IterateSubgraphFromNode(Graph& graph, } } // namespace +void RemovePrintDensityFlag(Graph& graph, + const std::vector& node_topology_list, + bool& modified, + const logging::Logger& logger) { + for (auto node_index : node_topology_list) { + Node* node = graph.GetNode(node_index); + if (node == nullptr) { + continue; + } + if (graph_utils::IsSupportedOptypeVersionAndDomain(*node, "PythonOp", {1}, kMSDomain) && + static_cast(node->GetAttributes().at("func_name").s()) == kFlagAndPrintDensityFuncName) { + if (graph_utils::CanRemoveNode(graph, *node, logger)) { + if (graph_utils::RemoveNode(graph, *node)) { + modified = true; + } else { + LOG_DEBUG_INFO(logger, "Failed to remove node " + node->Name() + "(" + node->OpType() + ")"); + } + } else { + LOG_DEBUG_INFO(logger, "Can not remove node " + node->Name() + "(" + node->OpType() + ")"); + } + } + } +} + Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { LOG_DEBUG_INFO(logger, "Enter PaddingElimination"); @@ -392,10 +416,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev node.InputDefs()[1]->Exists() && node.InputDefs()[1]->Shape() && node.InputDefs()[1]->Shape()->dim_size() >= 2) { - const auto outputNodeCount = std::distance(node.OutputEdgesBegin(), node.OutputEdgesEnd()); - if (outputNodeCount != 1) { - continue; - } Node* embedding_input_node = graph.GetMutableProducerNode(node.MutableInputDefs()[1]->Name()); if (embedding_input_node == nullptr || !graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_input_node, "PythonOp", {1}, kMSDomain) || @@ -404,21 +424,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev LOG_DEBUG_INFO(logger, "not find PythonOp of flagPaddingElimination after embedding node"); continue; } - if (!print_density_) { - if (graph_utils::CanRemoveNode(graph, *embedding_input_node, logger)) { - if (graph_utils::RemoveNode(graph, *embedding_input_node)) { - modified = true; - } else { - LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_input_node->Name() + - "(" + embedding_input_node->OpType() + ")"); - continue; - } - } else { - LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_input_node->Name() + - "(" + embedding_input_node->OpType() + ")"); - continue; - } - } const ONNX_NAMESPACE::TensorProto* padding_initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name()); if (padding_initializer != nullptr && @@ -430,19 +435,22 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev continue; } embedding_node = &node; - input_ids_arg = embedding_node->MutableInputDefs()[1]; - for (auto output_defs : embedding_node->MutableOutputDefs()) { - subgraph.insert(output_defs); - } break; } } } + if (!print_density_) { + RemovePrintDensityFlag(graph, node_topology_list, modified, logger); + } if (!embedding_node) { LOG_DEBUG_INFO(logger, "Exit PaddingElimination optimization for not finding any valid embedding node."); return Status::OK(); } + input_ids_arg = embedding_node->MutableInputDefs()[1]; + for (auto output_defs : embedding_node->MutableOutputDefs()) { + subgraph.insert(output_defs); + } if (!input_ids_arg->Shape()) { LOG_DEBUG_INFO(logger, "Exit PaddingElimination optimization for not finding shape of input_ids.");