@@ -1767,6 +1767,10 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
17671767class SqueezeOpBuilder : public BaseOpBuilder {
17681768 public:
17691769 void AddInitializersToSkip (ModelBuilder& model_builder, const Node& node) const override ;
1770+ static Status AddSqueezeOp (ModelBuilder& model_builder,
1771+ const std::string& node_name,
1772+ const std::string& input, const std::string& output,
1773+ vector<int32_t > axes) ORT_MUST_USE_RESULT;
17701774
17711775 private:
17721776 Status AddToModelBuilderImpl (ModelBuilder& model_builder, const Node& node) const override ORT_MUST_USE_RESULT;
@@ -1779,6 +1783,49 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
17791783 }
17801784}
17811785
1786+ /* static */ Status SqueezeOpBuilder::AddSqueezeOp (ModelBuilder& model_builder,
1787+ const std::string& node_name,
1788+ const std::string& input, const std::string& output,
1789+ vector<int32_t > axes) {
1790+ auto & shaper (model_builder.GetShaper ());
1791+ const auto & operand_indices (model_builder.GetOperandIndices ());
1792+ const auto & operand_types (model_builder.GetOperandTypes ());
1793+
1794+ const auto & input_shape (shaper[input]);
1795+ auto input_dims = input_shape.size ();
1796+ for (auto & axis : axes) {
1797+ axis = static_cast <int32_t >(HandleNegativeAxis (axis, input_dims));
1798+ }
1799+
1800+ // Despite the spec of ANEURALNETWORKS_SQUEEZE at
1801+ // https://developer.android.com/ndk/reference/group/neural-networks
1802+ // states, that the axes (input 1 of ANEURALNETWORKS_SQUEEZE) is optional.
1803+ //
1804+ // The actual code of NNAPI requires the axes to be provided
1805+ // https://android.googlesource.com/platform/frameworks/ml/+/master/nn/common/operations/Squeeze.cpp#31
1806+ if (axes.empty ()) { // Squeeze all
1807+ for (size_t i = 0 ; i < input_dims; i++) {
1808+ if (input_shape[i] == 1 )
1809+ axes.push_back (i);
1810+ }
1811+ }
1812+
1813+ const auto axes_name = model_builder.GetUniqueName (node_name + input + " _axes" );
1814+ Shape axes_dimen = {static_cast <uint32_t >(axes.size ())};
1815+ const OperandType axes_operand_type (Type::TENSOR_INT32, axes_dimen);
1816+ ORT_RETURN_IF_ERROR (model_builder.AddOperandFromPersistMemoryBuffer (axes_name, axes.data (), axes_operand_type));
1817+
1818+ std::vector<uint32_t > input_indices;
1819+ input_indices.push_back (operand_indices.at (input)); // input
1820+ input_indices.push_back (operand_indices.at (axes_name)); // axes
1821+
1822+ ORT_RETURN_IF_ERROR (shaper.Squeeze (input, axes, output));
1823+ const OperandType output_operand_type (operand_types.at (input).type , shaper[output]);
1824+ ORT_RETURN_IF_ERROR (model_builder.AddOperation (ANEURALNETWORKS_SQUEEZE, input_indices,
1825+ {output}, {output_operand_type}, {false }));
1826+ return Status::OK ();
1827+ }
1828+
17821829/* static */ vector<int32_t > SqueezeOpBuilder::GetAxes (ModelBuilder& model_builder, const Node& node) {
17831830 vector<int32_t > axes;
17841831 // Squeeze opset 13 use input as axes
@@ -1804,47 +1851,13 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
18041851}
18051852
18061853Status SqueezeOpBuilder::AddToModelBuilderImpl (ModelBuilder& model_builder, const Node& node) const {
1807- auto & shaper (model_builder.GetShaper ());
1808- const auto & operand_indices (model_builder.GetOperandIndices ());
1809- const auto & operand_types (model_builder.GetOperandTypes ());
1810-
18111854 auto input = node.InputDefs ()[0 ]->Name ();
18121855 if (model_builder.IsOperandNHWC (input)) {
18131856 // We want to transpose nhwc operand back to nchw before squeeze
18141857 ORT_RETURN_IF_ERROR (GetNCHWInput (model_builder, node, 0 , input));
18151858 }
18161859
1817- NodeAttrHelper helper (node);
1818- vector<int32_t > axes = GetAxes (model_builder, node);
1819- const auto & input_shape (shaper[input]);
1820- auto input_dims = input_shape.size ();
1821- for (auto & axis : axes) {
1822- axis = static_cast <int32_t >(HandleNegativeAxis (axis, input_dims));
1823- }
1824-
1825- if (axes.empty ()) { // Squeeze all
1826- for (size_t i = 0 ; i < input_dims; i++) {
1827- if (input_shape[i] == 1 )
1828- axes.push_back (i);
1829- }
1830- }
1831-
1832- const auto axes_name = model_builder.GetUniqueName (node.Name () + input + " _axes" );
1833- Shape axes_dimen = {static_cast <uint32_t >(axes.size ())};
1834- shaper.AddShape (axes_name, axes_dimen);
1835- const OperandType axes_operand_type (Type::TENSOR_INT32, axes_dimen);
1836- ORT_RETURN_IF_ERROR (model_builder.AddOperandFromPersistMemoryBuffer (axes_name, axes.data (), axes_operand_type));
1837-
1838- std::vector<uint32_t > input_indices;
1839- input_indices.push_back (operand_indices.at (input)); // input
1840- input_indices.push_back (operand_indices.at (axes_name)); // axes
1841-
1842- const auto & output = node.OutputDefs ()[0 ]->Name ();
1843- ORT_RETURN_IF_ERROR (shaper.Squeeze (input, axes, output));
1844- const OperandType output_operand_type (operand_types.at (input).type , shaper[output]);
1845- ORT_RETURN_IF_ERROR (model_builder.AddOperation (ANEURALNETWORKS_SQUEEZE, input_indices,
1846- {output}, {output_operand_type}, {false }));
1847- return Status::OK ();
1860+ return AddSqueezeOp (model_builder, node.Name (), input, node.OutputDefs ()[0 ]->Name (), GetAxes (model_builder, node));
18481861}
18491862
18501863#pragma endregion
0 commit comments