Skip to content

Commit 8736865

Browse files
authored
Make NNAPI EP reject nodes with no-shape inputs (#5927)
1 parent fddbd89 commit 8736865

File tree

2 files changed

+75
-18
lines changed

2 files changed

+75
-18
lines changed

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,14 @@ class BaseOpSupportChecker : public IOpSupportChecker {
8080
return 27;
8181
}
8282

83-
virtual bool HasSupportedInputs(const Node& node) const;
83+
virtual bool HasSupportedInputsImpl(const Node& node) const;
8484

8585
virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; }
8686
virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 13; }
87+
88+
private:
8789
bool HasSupportedOpSet(const Node& node) const;
90+
bool HasSupportedInputs(const Node& node) const;
8891
};
8992

9093
/* static */ void BaseOpSupportChecker::CreateSharedOpSupportChecker(
@@ -121,16 +124,23 @@ bool BaseOpSupportChecker::IsOpSupported(const InitializedTensorSet& initializer
121124
}
122125

123126
bool BaseOpSupportChecker::HasSupportedInputs(const Node& node) const {
127+
// We do not support unknown(null) input shape
128+
for (const auto* input : node.InputDefs()) {
129+
if (!input->Shape()) {
130+
LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType()
131+
<< "] Input [" << input->Name() << "] has no shape";
132+
return false;
133+
}
134+
}
135+
136+
return HasSupportedInputsImpl(node);
137+
}
138+
139+
bool BaseOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
124140
// We only check the type of input 0 by default
125141
// specific op builder can override this
126142
const auto& input = *node.InputDefs()[0];
127143

128-
if (nullptr == input.Shape()) {
129-
LOGS_DEFAULT(VERBOSE) << "[" << node.OpType()
130-
<< "] Input shape is null";
131-
return false;
132-
}
133-
134144
int32_t input_type;
135145
if (!GetType(input, input_type))
136146
return false;
@@ -170,7 +180,7 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker {
170180
int32_t GetMinSupportedSdkVer(const Node& node, const OpSupportCheckParams& params) const override;
171181
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
172182
const OpSupportCheckParams& params) const override;
173-
bool HasSupportedInputs(const Node& node) const override;
183+
bool HasSupportedInputsImpl(const Node& node) const override;
174184
int GetMinSupportedOpSet(const Node& node) const override;
175185
};
176186

@@ -206,9 +216,9 @@ int BinaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const {
206216
return 1;
207217
}
208218

209-
bool BinaryOpSupportChecker::HasSupportedInputs(const Node& node) const {
219+
bool BinaryOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
210220
if (node.OpType() != "QLinearAdd")
211-
return BaseOpSupportChecker::HasSupportedInputs(node);
221+
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
212222

213223
// QLinearAdd
214224
if (!HasValidBinaryOpQuantizedInputs(node))
@@ -511,7 +521,7 @@ class ConvOpSupportChecker : public BaseOpSupportChecker {
511521
return params.use_nchw ? 29 : 28;
512522
}
513523

514-
bool HasSupportedInputs(const Node& node) const override;
524+
bool HasSupportedInputsImpl(const Node& node) const override;
515525
};
516526

517527
/* static */ void ConvOpSupportChecker::CreateSharedOpSupportChecker(
@@ -524,9 +534,9 @@ class ConvOpSupportChecker : public BaseOpSupportChecker {
524534
});
525535
}
526536

527-
bool ConvOpSupportChecker::HasSupportedInputs(const Node& node) const {
537+
bool ConvOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
528538
if (node.OpType() != "QLinearConv")
529-
return BaseOpSupportChecker::HasSupportedInputs(node);
539+
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
530540

531541
// QLinearConv only supports input of uint8 for now
532542
if (!HasValidBinaryOpQuantizedInputs(node))
@@ -683,13 +693,13 @@ class GemmOpSupportChecker : public BaseOpSupportChecker {
683693
private:
684694
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
685695
const OpSupportCheckParams& params) const override;
686-
bool HasSupportedInputs(const Node& node) const override;
696+
bool HasSupportedInputsImpl(const Node& node) const override;
687697
int GetMinSupportedOpSet(const Node& node) const override;
688698
};
689699

690-
bool GemmOpSupportChecker::HasSupportedInputs(const Node& node) const {
700+
bool GemmOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
691701
if (node.OpType() != "QLinearMatMul")
692-
return BaseOpSupportChecker::HasSupportedInputs(node);
702+
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
693703

694704
// QLinearMatMul
695705
if (!HasValidBinaryOpQuantizedInputs(node))
@@ -990,7 +1000,7 @@ class DequantizeLinearOpSupportChecker : public BaseOpSupportChecker {
9901000
int32_t GetMinSupportedSdkVer(const Node& /* node */, const OpSupportCheckParams& /* params */) const override {
9911001
return 29;
9921002
}
993-
bool HasSupportedInputs(const Node& node) const override;
1003+
bool HasSupportedInputsImpl(const Node& node) const override;
9941004
};
9951005

9961006
bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
@@ -1007,7 +1017,7 @@ bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensor
10071017
return true;
10081018
}
10091019

1010-
bool DequantizeLinearOpSupportChecker::HasSupportedInputs(const Node& node) const {
1020+
bool DequantizeLinearOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
10111021
int32_t input_type;
10121022
if (!GetType(*node.InputDefs()[0], input_type))
10131023
return false;

onnxruntime/test/providers/nnapi/nnapi_basic_test.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,53 @@ TEST(NnapiExecutionProviderTest, FunctionTest) {
155155
<< "Some nodes should have been taken by the NNAPI EP";
156156
#endif
157157
}
158+
159+
TEST(NnapiExecutionProviderTest, TestNoShapeInputModel) {
160+
const ORTCHAR_T* model_file_name = ORT_TSTR("input_with_no_shape_test_graph.onnx");
161+
162+
{ // Create the model with 2 add nodes, the graph has 2 inputs with no shape
163+
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
164+
auto& graph = model.MainGraph();
165+
std::vector<onnxruntime::NodeArg*> inputs;
166+
std::vector<onnxruntime::NodeArg*> outputs;
167+
168+
// FLOAT tensor without shape
169+
ONNX_NAMESPACE::TypeProto float_tensor;
170+
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
171+
172+
auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor);
173+
auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
174+
inputs.push_back(&input_arg_1);
175+
inputs.push_back(&input_arg_2);
176+
auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
177+
outputs.push_back(&output_arg);
178+
graph.AddNode("node_1", "Add", "node 1.", inputs, outputs);
179+
180+
auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor);
181+
inputs.clear();
182+
inputs.push_back(&output_arg);
183+
inputs.push_back(&input_arg_3);
184+
auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor);
185+
outputs.clear();
186+
outputs.push_back(&output_arg_2);
187+
graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
188+
189+
ASSERT_STATUS_OK(graph.Resolve());
190+
ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name));
191+
}
192+
193+
// test load only
194+
// since we know NNAPI supports Add op, but both Add ops in the graph has no input shape
195+
// verify the entire graph will not be assigned to NNAPI EP
196+
SessionOptions so;
197+
InferenceSessionWrapper session_object{so, GetEnvironment()};
198+
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<NnapiExecutionProvider>(0)));
199+
ASSERT_STATUS_OK(session_object.Load(model_file_name));
200+
ASSERT_STATUS_OK(session_object.Initialize());
201+
ASSERT_EQ(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0)
202+
<< "No node should be taken by the NNAPI EP";
203+
}
204+
158205
#endif // !(ORT_MINIMAL_BUILD
159206

160207
TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {

0 commit comments

Comments
 (0)