Skip to content

Commit

Permalink
nnapi add min max support (microsoft#6117)
Browse files Browse the repository at this point in the history
  • Loading branch information
guoyu-wang authored Dec 16, 2020
1 parent 939cc9b commit b648bf6
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 13 deletions.
118 changes: 105 additions & 13 deletions onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,35 @@ Status GetNCHWInput(ModelBuilder& model_builder, const Node& node, size_t input_
return Status::OK();
}

// Transpose layouts if necessary for element wise operators with 2 inputs
// and return the layout type of output tensor
// If both inputs have same layout, the output will have the same layout
// Otherwise we will need transpose the nhwc input back to nchw, and output will be nchw
// Align the layouts of the two inputs of an element wise operator and report
// the layout of the output tensor via `output_is_nhwc`.
// When both inputs share a layout the output keeps that layout; when they
// differ, the NHWC input is transposed back to NCHW (its operand name is
// rewritten in place) and the output is NCHW.
Status TransposeBinaryOpInputLayout(ModelBuilder& model_builder, const Node& node,
                                    size_t input1_idx, size_t input2_idx,
                                    std::string& input1, std::string& input2,
                                    bool& output_is_nhwc) ORT_MUST_USE_RESULT;
Status TransposeBinaryOpInputLayout(ModelBuilder& model_builder, const Node& node,
                                    size_t input1_idx, size_t input2_idx,
                                    std::string& input1, std::string& input2,
                                    bool& output_is_nhwc) {
  const bool input1_is_nhwc = model_builder.IsOperandNHWC(input1);
  const bool input2_is_nhwc = model_builder.IsOperandNHWC(input2);

  // The output remains NHWC only when both inputs are NHWC.
  output_is_nhwc = input1_is_nhwc && input2_is_nhwc;

  if (input1_is_nhwc != input2_is_nhwc) {
    // Layouts disagree; transpose whichever input is NHWC back to NCHW.
    if (input1_is_nhwc) {
      ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node, input1_idx, input1));
    } else {
      ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node, input2_idx, input2));
    }
  }

  return Status::OK();
}

static Status AddBinaryOperator(int32_t op_type,
ModelBuilder& model_builder,
const std::string& input1,
Expand Down Expand Up @@ -679,19 +708,9 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
std::string input2 = input_defs[b_idx]->Name();
const auto& output = node.OutputDefs()[0]->Name();

bool input1_is_nhwc = model_builder.IsOperandNHWC(input1);
bool input2_is_nhwc = model_builder.IsOperandNHWC(input2);
bool output_is_nhwc = false;

if (input1_is_nhwc == input2_is_nhwc) {
output_is_nhwc = input1_is_nhwc;
} else if (input1_is_nhwc) {
// need transpose input1 back to nchw
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node, a_idx, input1));
} else { // input2_is_nhwc
// need transpose input2 back to nchw
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node, b_idx, input2));
}
ORT_RETURN_IF_ERROR(
TransposeBinaryOpInputLayout(model_builder, node, a_idx, b_idx, input1, input2, output_is_nhwc));

float a_scale = 0.0f,
b_scale = 0.0f,
Expand Down Expand Up @@ -2221,7 +2240,75 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
return ReshapeOpBuilder::AddReshapeOperator(model_builder, node, input, shape);
}

#pragma endregion op_reshape
#pragma endregion

#pragma region op_minmax

// Builder for the ONNX Min and Max operators.
// Both op types share this single builder (registered through
// CreateSharedOpBuilder) since they differ only in the NNAPI op code
// chosen inside AddMinMaxOperator.
class MinMaxOpBuilder : public BaseOpBuilder {
public:
static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
// Emits the NNAPI MINIMUM/MAXIMUM operation for `node` using the two given
// (already layout-aligned) input operand names; `output_is_nhwc` tags the
// layout of the produced output operand.
static Status AddMinMaxOperator(ModelBuilder& model_builder, const Node& node,
const std::string& input1, const std::string& input2,
bool output_is_nhwc) ORT_MUST_USE_RESULT;

private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) const override ORT_MUST_USE_RESULT;
};

// Registers one shared MinMaxOpBuilder instance for both "Min" and "Max".
/* static */ void MinMaxOpBuilder::CreateSharedOpBuilder(
    const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  CreateSharedOpBuilderImpl<MinMaxOpBuilder>(op_type, op_registrations, {"Min", "Max"});
}

// Adds a single NNAPI MINIMUM/MAXIMUM operation for `node`.
// `input1`/`input2` are the input operand names after any layout alignment;
// `output_is_nhwc` records the layout of the output operand.
/* static */ Status MinMaxOpBuilder::AddMinMaxOperator(ModelBuilder& model_builder, const Node& node,
                                                       const std::string& input1, const std::string& input2,
                                                       bool output_is_nhwc) {
  auto& shaper(model_builder.GetShaper());
  const auto& operand_indices(model_builder.GetOperandIndices());
  const auto& operand_types(model_builder.GetOperandTypes());

  const auto& output = node.OutputDefs()[0]->Name();
  const auto& op_type(node.OpType());

  // Map the ONNX op type onto the corresponding NNAPI op code.
  int32_t op_code;
  if (op_type == "Min") {
    op_code = ANEURALNETWORKS_MINIMUM;
  } else if (op_type == "Max") {
    op_code = ANEURALNETWORKS_MAXIMUM;
  } else {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MinMaxOpBuilder, unknown op: ", op_type);
  }

  // The output shape follows element-wise (broadcast) shape inference.
  ORT_RETURN_IF_ERROR(shaper.Eltwise(input1, input2, output));

  std::vector<uint32_t> input_indices{operand_indices.at(input1),   // input 1
                                      operand_indices.at(input2)};  // input 2
  // The output inherits the element type of input 1.
  const OperandType output_operand_type(operand_types.at(input1).type, shaper[output]);
  ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices,
                                                 {output}, {output_operand_type}, {output_is_nhwc}));

  return Status::OK();
}

Status MinMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) const {
  const auto input_defs(node.InputDefs());
  // Exactly two inputs; the support checker rejects other input counts.
  std::string lhs_name = input_defs[0]->Name();
  std::string rhs_name = input_defs[1]->Name();

  // Make both inputs agree on layout; the operand names may be rewritten
  // to refer to transposed (NCHW) operands.
  bool output_is_nhwc = false;
  ORT_RETURN_IF_ERROR(TransposeBinaryOpInputLayout(model_builder, node,
                                                   0 /* input1_idx */, 1 /* input2_idx */,
                                                   lhs_name, rhs_name, output_is_nhwc));

  return AddMinMaxOperator(model_builder, node, lhs_name, rhs_name, output_is_nhwc);
}

#pragma endregion

#pragma region CreateGetOpBuilders

Expand Down Expand Up @@ -2297,6 +2384,11 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
NNAPI_EP_ADD_SINGLE_OP_BUILDER("Resize", ResizeOpBuilder);
NNAPI_EP_ADD_SINGLE_OP_BUILDER("Flatten", FlattenOpBuilder);

{
NNAPI_EP_ADD_SHARED_OP_BUILDER("Min", MinMaxOpBuilder);
NNAPI_EP_ADD_SHARED_OP_BUILDER("Max", MinMaxOpBuilder);
}

return op_registrations;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1295,6 +1295,49 @@ bool FlattenOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* i

#pragma endregion

#pragma region op_minmax

// Support checker for the ONNX Min and Max operators; a single instance is
// shared between both op types (see CreateSharedOpSupportChecker).
class MinMaxOpSupportChecker : public BaseOpSupportChecker {
public:
static void CreateSharedOpSupportChecker(
const std::string& op_type, OpSupportCheckerRegistrations& op_registrations);

private:
// NNAPI's MINIMUM/MAXIMUM operations require Android API level 29
// (NNAPI feature level 3) or later.
int32_t GetMinSupportedSdkVer(const Node& /* node */, const OpSupportCheckParams& /* params */) const override {
return 29;
}

// Min/Max opset 5- uses consumed_inputs attribute which is not supported for now
int GetMinSupportedOpSet(const Node& /* node */) const override { return 6; }

bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const OpSupportCheckParams& params) const override;
};

// Registers one shared MinMaxOpSupportChecker instance for both "Min" and "Max".
/* static */ void MinMaxOpSupportChecker::CreateSharedOpSupportChecker(
    const std::string& op_type, OpSupportCheckerRegistrations& op_registrations) {
  CreateSharedOpSupportCheckerImpl<MinMaxOpSupportChecker>(op_type, op_registrations, {"Min", "Max"});
}

bool MinMaxOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
                                               const OpSupportCheckParams& /* params */) const {
  // TODO support 2+ inputs for Min/Max op
  const auto input_count = node.InputDefs().size();
  if (input_count == 2)
    return true;

  LOGS_DEFAULT(VERBOSE) << "[" << node.OpType() << "] only supports 2 inputs, "
                        << "actual input number, " << input_count;
  return false;
}

#pragma endregion

#pragma region CreateGetOpSupportCheckers

// The reason we use macros to create OpBuilders is for easy exclusion in build if certain op(s) are not used
Expand Down Expand Up @@ -1373,6 +1416,10 @@ static OpSupportCheckerRegistrations CreateOpSupportCheckerRegistrations() {
NNAPI_EP_ADD_SINGLE_OP_SUPPORT_CHECKER("Resize", ResizeOpSupportChecker);
NNAPI_EP_ADD_SINGLE_OP_SUPPORT_CHECKER("Flatten", FlattenOpSupportChecker);

{
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Min", MinMaxOpSupportChecker);
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Max", MinMaxOpSupportChecker);
}
return op_registrations;
}

Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,21 @@ TEST(MathOpTest, Min_12_Float) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); //TensorRT: Input batch size is inconsistent
}

// Min (opset 12) with two float inputs of different shapes: checks that the
// {3, 1} second input is broadcast across the columns of the {3, 3} first
// input, taking the element-wise minimum per row.
TEST(MathOpTest, Min_12_Float_2_Input) {
OpTester test("Min", 12);
test.AddInput<float>("data_2", {3, 3},
{10.0f, 20.0f, 30.0f,
40.0f, 50.0f, 60.0f,
-70.0f, -80.0f, -90.0f});
// One value per row, broadcast along the last axis.
test.AddInput<float>("data_1", {3, 1},
{-1.0f, 20.0f, 300.0f});
test.AddOutput<float>("min", {3, 3},
{-1.0f, -1.0f, -1.0f,
20.0f, 20.0f, 20.0f,
-70.0f, -80.0f, -90.0f});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); //TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Min_12_Double) {
OpTester test("Min", 12);
test.AddInput<double>("data_0", {1, 3},
Expand Down

0 comments on commit b648bf6

Please sign in to comment.