@@ -756,75 +756,72 @@ class ReshapeOpBuilder : public BaseOpBuilder {
756756
757757 private:
758758 Status AddToModelBuilderImpl (ModelBuilder& model_builder, const Node& node) const override ORT_MUST_USE_RESULT;
759- static bool CanSkipReshape (const Node& node, size_t input_rank, size_t output_rank);
759+ static bool CanSkipReshape (const ModelBuilder& model_builder, const Node& node, size_t input_rank, size_t output_rank);
760760};
761761
762762void ReshapeOpBuilder::AddInitializersToSkip (ModelBuilder& model_builder, const Node& node) const {
763763 model_builder.AddInitializerToSkip (node.InputDefs ()[1 ]->Name ());
764764}
765765
766- // We can skip the Reshape if all the output edges satisfies,
767- // 1. The output of the reshape/flatten is the input 0 of the GEMM/Matmul,
766+ // We can skip the Reshape if all the output edges satisfies both the following conditions
767+ // 1. The output the reshape/flatten is not an output of the graph
768+ // 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators,
769+ // and not any other types of operator,
768770// and the input rank >= 2 and output_rank == 2
769771// This is because Gemm/Matmul will map to ANEURALNETWORKS_FULLY_CONNECTED in NNAPI,
770772// ANEURALNETWORKS_FULLY_CONNECTED will flatten the 2+ dim input 0 to 2d
771- // 2. Or the output the reshape/flatten is the output of the graph
772- // (no op in the graph is using the output except can be used by Gemm/Matmul satisfying condition 1 above)
773773// The reason we want to skip Reshape is that Reshape is not running on Hardware (NPU,...) in NNAPI for
774774// some CPU (e.g. Qualcomm SD for now), skipping unnecessary Reshape will prevent context switching
775775// between NNAPI CPU impl and Hardware Accelerator impl and will speed up the execution
776776// If we are going to skip the reshape, we will still add correct shape and operand type for the output in
777777// onnxruntime::nnapi::Model.
778- // If the Reshape output is also a graph output, since NNAPI output is a void* buffer, we can find the shape
779- // information in onnxruntime::nnapi::Model and pass the correct shape information back to ORT to be used as output shape
780- /* static */ bool ReshapeOpBuilder::CanSkipReshape (const Node& node, size_t input_rank, size_t output_rank) {
781- //
782- // TEMPORARILY DISABLED. Needs refinement.
783- //
784- // const auto& output = node.OutputDefs()[0]->Name();
785- // // We will go through all the output edges
786- // for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
787- // const auto& op_type = it->GetNode().OpType();
788- // // TODO add quantized matmul when reshape support quantized input
789- // if (op_type != "Gemm" && op_type != "MatMul") {
790- // LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
791- // << " or no op is using the output (output is graph output)"
792- // << ", output name, " << output
793- // << " is used by " << op_type;
794- // return false;
795- // }
796-
797- // // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
798- // if (it->GetDstArgIndex() != 0) {
799- // LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
800- // << ", output name, " << output;
801- // return false;
802- // }
803-
804- // // We only support 2d matmul/gemm here
805- // // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
806- // if (input_rank < 2 || output_rank != 2) {
807- // LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
808- // << ", output name, " << output
809- // << ", the actual input_rank, " << input_rank
810- // << ", the actual output_rank, " << output_rank;
811- // return false;
812- // }
813- // }
814-
815- // // If we reach here, we have either,
816- // // all the Reshape outputs are used by gemm/matmul, the output can also be a model output [doesn't really matter here]
817- // // or
818- // // Reshape has no output edge ==> the output is a graph output or a dead end [which we don't care]
819- // // we can skip this Reshape now
820- // LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node ["
821- // << node.Name() << "] with output, " << output;
822- // return true;
823-
824- ORT_UNUSED_PARAMETER (node);
825- ORT_UNUSED_PARAMETER (input_rank);
826- ORT_UNUSED_PARAMETER (output_rank);
827- return false ;
778+ /* static */ bool ReshapeOpBuilder::CanSkipReshape (const ModelBuilder& model_builder, const Node& node,
779+ size_t input_rank, size_t output_rank) {
780+ const auto & output = node.OutputDefs ()[0 ]->Name ();
781+ // We will go through all the output edges
782+ for (auto it = node.OutputEdgesBegin (), end = node.OutputEdgesEnd (); it != end; ++it) {
783+ const auto & op_type = it->GetNode ().OpType ();
784+ // TODO add quantized matmul when reshape support quantized input
785+ if (op_type != " Gemm" && op_type != " MatMul" ) {
786+ LOGS_DEFAULT (VERBOSE) << " Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
787+ << " or no op is using the output (output is graph output)"
788+ << " , output name, " << output
789+ << " is used by " << op_type;
790+ return false ;
791+ }
792+
793+ // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
794+ if (it->GetDstArgIndex () != 0 ) {
795+ LOGS_DEFAULT (VERBOSE) << " Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
796+ << " , output name, " << output;
797+ return false ;
798+ }
799+
800+ // We only support 2d matmul/gemm here
801+ // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
802+ if (input_rank < 2 || output_rank != 2 ) {
803+ LOGS_DEFAULT (VERBOSE) << " Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
804+ << " , output name, " << output
805+ << " , the actual input_rank, " << input_rank
806+ << " , the actual output_rank, " << output_rank;
807+ return false ;
808+ }
809+ }
810+
811+ // If we reach here, we have all the Reshape outputs are used by gemm/matmul, or Reshape has no output edge
812+ // Check if the Reshape output is a graph output, if so we cannot skip the Reshape
813+ // We do not care the case where the Reshape output is a dead end
814+ for (const auto * node_arg : model_builder.GetGraphViewer ().GetOutputs ()) {
815+ if (node_arg->Name () == output) {
816+ LOGS_DEFAULT (VERBOSE) << " Reshape/Flatten can not be skipped when the output is a graph output"
817+ << " , output name, " << output;
818+ return false ;
819+ }
820+ }
821+
822+ LOGS_DEFAULT (VERBOSE) << " Skipping Reshape/Flatten node ["
823+ << node.Name () << " ] with output, " << output;
824+ return true ;
828825}
829826
830827/* static */ Status ReshapeOpBuilder::AddReshapeOperator (ModelBuilder& model_builder,
@@ -842,7 +839,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
842839 // Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now)
843840 // We will try to see if we the skip the Reshape to prevent context switching between
844841 // NNAPI CPU impl and NNAPI hardware accelerator impl
845- if (CanSkipReshape (node, input_rank, output_rank)) {
842+ if (CanSkipReshape (model_builder, node, input_rank, output_rank)) {
846843 // Since reshape can be skipped, only register the dimension and type, with same index and new name
847844 const OperandType output_operand_type (operand_types.at (input).type , shaper[output]);
848845 model_builder.RegisterOperand (output, operand_indices.at (input), output_operand_type, false );
0 commit comments