Skip to content

Commit c49d5f1

Browse files
authored
Reenable skip flatten/reshape if it's Gemm's input (#5904)
1 parent 7823033 commit c49d5f1

File tree

1 file changed

+53
-56
lines changed
  • onnxruntime/core/providers/nnapi/nnapi_builtin/builders

1 file changed

+53
-56
lines changed

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc

Lines changed: 53 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -756,75 +756,72 @@ class ReshapeOpBuilder : public BaseOpBuilder {
756756

757757
private:
758758
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) const override ORT_MUST_USE_RESULT;
759-
static bool CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank);
759+
static bool CanSkipReshape(const ModelBuilder& model_builder, const Node& node, size_t input_rank, size_t output_rank);
760760
};
761761

762762
void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
763763
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
764764
}
765765

766-
// We can skip the Reshape if all the output edges satisfies,
767-
// 1. The output of the reshape/flatten is the input 0 of the GEMM/Matmul,
766+
// We can skip the Reshape if all the output edges satisfy both of the following conditions
767+
// 1. The output of the reshape/flatten is not an output of the graph
768+
// 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators,
769+
// and not any other types of operator,
768770
// and the input rank >= 2 and output_rank == 2
769771
// This is because Gemm/Matmul will map to ANEURALNETWORKS_FULLY_CONNECTED in NNAPI,
770772
// ANEURALNETWORKS_FULLY_CONNECTED will flatten the 2+ dim input 0 to 2d
771-
// 2. Or the output the reshape/flatten is the output of the graph
772-
// (no op in the graph is using the output except can be used by Gemm/Matmul satisfying condition 1 above)
773773
// The reason we want to skip Reshape is that Reshape is not running on Hardware (NPU,...) in NNAPI for
774774
// some CPU (e.g. Qualcomm SD for now), skipping unnecessary Reshape will prevent context switching
775775
// between NNAPI CPU impl and Hardware Accelerator impl and will speed up the execution
776776
// If we are going to skip the reshape, we will still add correct shape and operand type for the output in
777777
// onnxruntime::nnapi::Model.
778-
// If the Reshape output is also a graph output, since NNAPI output is a void* buffer, we can find the shape
779-
// information in onnxruntime::nnapi::Model and pass the correct shape information back to ORT to be used as output shape
780-
/* static */ bool ReshapeOpBuilder::CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank) {
781-
//
782-
// TEMPORARILY DISABLED. Needs refinement.
783-
//
784-
// const auto& output = node.OutputDefs()[0]->Name();
785-
// // We will go through all the output edges
786-
// for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
787-
// const auto& op_type = it->GetNode().OpType();
788-
// // TODO add quantized matmul when reshape support quantized input
789-
// if (op_type != "Gemm" && op_type != "MatMul") {
790-
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
791-
// << " or no op is using the output (output is graph output)"
792-
// << ", output name, " << output
793-
// << " is used by " << op_type;
794-
// return false;
795-
// }
796-
797-
// // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
798-
// if (it->GetDstArgIndex() != 0) {
799-
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
800-
// << ", output name, " << output;
801-
// return false;
802-
// }
803-
804-
// // We only support 2d matmul/gemm here
805-
// // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
806-
// if (input_rank < 2 || output_rank != 2) {
807-
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
808-
// << ", output name, " << output
809-
// << ", the actual input_rank, " << input_rank
810-
// << ", the actual output_rank, " << output_rank;
811-
// return false;
812-
// }
813-
// }
814-
815-
// // If we reach here, we have either,
816-
// // all the Reshape outputs are used by gemm/matmul, the output can also be a model output [doesn't really matter here]
817-
// // or
818-
// // Reshape has no output edge ==> the output is a graph output or a dead end [which we don't care]
819-
// // we can skip this Reshape now
820-
// LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node ["
821-
// << node.Name() << "] with output, " << output;
822-
// return true;
823-
824-
ORT_UNUSED_PARAMETER(node);
825-
ORT_UNUSED_PARAMETER(input_rank);
826-
ORT_UNUSED_PARAMETER(output_rank);
827-
return false;
778+
/* static */ bool ReshapeOpBuilder::CanSkipReshape(const ModelBuilder& model_builder, const Node& node,
779+
size_t input_rank, size_t output_rank) {
780+
const auto& output = node.OutputDefs()[0]->Name();
781+
// We will go through all the output edges
782+
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
783+
const auto& op_type = it->GetNode().OpType();
784+
// TODO add quantized matmul when reshape support quantized input
785+
if (op_type != "Gemm" && op_type != "MatMul") {
786+
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
787+
<< " or no op is using the output (output is graph output)"
788+
<< ", output name, " << output
789+
<< " is used by " << op_type;
790+
return false;
791+
}
792+
793+
// NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
794+
if (it->GetDstArgIndex() != 0) {
795+
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
796+
<< ", output name, " << output;
797+
return false;
798+
}
799+
800+
// We only support 2d matmul/gemm here
801+
// And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
802+
if (input_rank < 2 || output_rank != 2) {
803+
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
804+
<< ", output name, " << output
805+
<< ", the actual input_rank, " << input_rank
806+
<< ", the actual output_rank, " << output_rank;
807+
return false;
808+
}
809+
}
810+
811+
// If we reach here, all the Reshape outputs are used by Gemm/MatMul, or Reshape has no output edge
812+
// Check if the Reshape output is a graph output, if so we cannot skip the Reshape
813+
// We do not care about the case where the Reshape output is a dead end
814+
for (const auto* node_arg : model_builder.GetGraphViewer().GetOutputs()) {
815+
if (node_arg->Name() == output) {
816+
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can not be skipped when the output is a graph output"
817+
<< ", output name, " << output;
818+
return false;
819+
}
820+
}
821+
822+
LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node ["
823+
<< node.Name() << "] with output, " << output;
824+
return true;
828825
}
829826

830827
/* static */ Status ReshapeOpBuilder::AddReshapeOperator(ModelBuilder& model_builder,
@@ -842,7 +839,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
842839
// Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now)
843840
// We will try to see if we can skip the Reshape to prevent context switching between
844841
// NNAPI CPU impl and NNAPI hardware accelerator impl
845-
if (CanSkipReshape(node, input_rank, output_rank)) {
842+
if (CanSkipReshape(model_builder, node, input_rank, output_rank)) {
846843
// Since reshape can be skipped, only register the dimension and type, with same index and new name
847844
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
848845
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false);

0 commit comments

Comments
 (0)