-
Notifications
You must be signed in to change notification settings - Fork 2.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CPU]: Added impl of SearchSorted op #27036
base: master
Are you sure you want to change the base?
Changes from 7 commits
f721e7f
dc4b384
ea10ba8
d2e0117
8b95ffa
36026bd
88e2a94
17a8faf
2d7e17b
3114067
96a4668
e126353
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "search_sorted.h" | ||
|
||
#include "openvino/op/search_sorted.hpp" | ||
#include "openvino/reference/search_sorted.hpp" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace node { | ||
SearchSorted::SearchSorted(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context) | ||
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { | ||
std::string errorMessage; | ||
if (!isSupportedOperation(op, errorMessage)) { | ||
OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage); | ||
} | ||
const auto ss_op = ov::as_type_ptr<const ov::op::v15::SearchSorted>(op); | ||
right_mode = ss_op->get_right_mode(); | ||
} | ||
|
||
bool SearchSorted::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept { | ||
try { | ||
if (!ov::is_type<ov::op::v15::SearchSorted>(op)) { | ||
errorMessage = "Only opset15 SearchSorted operation is supported"; | ||
return false; | ||
} | ||
} catch (...) { | ||
return false; | ||
} | ||
return true; | ||
} | ||
|
||
void SearchSorted::getSupportedDescriptors() { | ||
// Validation is already done in the ov::opset15::SearchSorted. | ||
} | ||
|
||
void SearchSorted::initSupportedPrimitiveDescriptors() { | ||
if (!supportedPrimitiveDescriptors.empty()) | ||
return; | ||
|
||
ov::element::Type inputPrec = getOriginalInputPrecisionAtPort(0); | ||
ov::element::Type outputPrec = getOriginalOutputPrecisionAtPort(0); | ||
|
||
addSupportedPrimDesc({{LayoutType::ncsp, inputPrec}, {LayoutType::ncsp, inputPrec}}, | ||
{{LayoutType::ncsp, outputPrec}}, | ||
impl_desc_type::ref); | ||
} | ||
|
||
bool SearchSorted::created() const { | ||
return getType() == Type::SearchSorted; | ||
} | ||
|
||
bool SearchSorted::needPrepareParams() const { | ||
return false; | ||
} | ||
|
||
void SearchSorted::executeDynamicImpl(dnnl::stream strm) { | ||
execute(strm); | ||
} | ||
|
||
template <typename INPUT_TYPE, typename OUTPUT_TYPE> | ||
void SearchSorted::executeImpl() { | ||
ov::reference::search_sorted<INPUT_TYPE, OUTPUT_TYPE>(getSrcDataAtPortAs<const INPUT_TYPE>(0), | ||
getSrcDataAtPortAs<const INPUT_TYPE>(1), | ||
getDstDataAtPortAs<OUTPUT_TYPE>(0), | ||
ov::Shape{getSrcMemoryAtPort(0)->getStaticDims()}, | ||
ov::Shape{getSrcMemoryAtPort(1)->getStaticDims()}, | ||
right_mode); | ||
} | ||
|
||
namespace { | ||
struct SearchSortedContext { | ||
SearchSorted& node; | ||
}; | ||
} // namespace | ||
|
||
template <typename T> | ||
struct SearchSorted::SearchSortedExecute { | ||
using TInputType = typename std::tuple_element<0, T>::type; | ||
using TOutputType = typename std::tuple_element<1, T>::type; | ||
|
||
void operator()(SearchSortedContext& ctx) { | ||
ctx.node.executeImpl<TInputType, TOutputType>(); | ||
} | ||
}; | ||
void SearchSorted::execute(dnnl::stream strm) { | ||
auto inputPrecision = getParentEdgeAt(0)->getMemory().getDesc().getPrecision(); | ||
auto outputPrecision = getChildEdgeAt(0)->getMemory().getDesc().getPrecision(); | ||
|
||
SearchSortedContext ctx = {*this}; | ||
|
||
#define CASE(OV_TYPE) \ | ||
OV_CASE2(OV_TYPE, ov::element::i64, ov::element_type_traits<OV_TYPE>::value_type, int64_t), \ | ||
OV_CASE2(OV_TYPE, ov::element::i32, ov::element_type_traits<OV_TYPE>::value_type, int32_t) | ||
|
||
OV_SWITCH(intel_cpu, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since OV_SWITCH doesn't check if any condition is satisfied, we need to guarantee inputPrecision and outputPrecision are in supported range. In order to do that lets adjust precision which we use in initSupportedPrimitiveDescriptors. See as as example https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/scatter_update.cpp#L273-L283. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed - as discussed in separate channel. |
||
SearchSortedExecute, | ||
ctx, | ||
std::tie(inputPrecision, outputPrecision), | ||
dmitry-gorokhov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
CASE(ov::element::f32), | ||
CASE(ov::element::f16), | ||
CASE(ov::element::bf16), | ||
CASE(ov::element::i8), | ||
CASE(ov::element::u8)) | ||
|
||
#undef CASE | ||
} | ||
} // namespace node | ||
} // namespace intel_cpu | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "node.h" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace node { | ||
|
||
class SearchSorted : public Node { | ||
public: | ||
SearchSorted(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context); | ||
|
||
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept; | ||
void getSupportedDescriptors() override; | ||
void initSupportedPrimitiveDescriptors() override; | ||
void execute(dnnl::stream strm) override; | ||
bool created() const override; | ||
bool needPrepareParams() const override; | ||
void executeDynamicImpl(dnnl::stream strm) override; | ||
|
||
private: | ||
template <typename INPUT_TYPE, typename OUTPUT_TYPE> | ||
void executeImpl(); | ||
|
||
template <typename T> | ||
struct SearchSortedExecute; | ||
|
||
bool right_mode = false; | ||
}; | ||
|
||
} // namespace node | ||
} // namespace intel_cpu | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "single_op_tests/search_sorted.hpp" | ||
|
||
namespace ov { | ||
namespace test { | ||
|
||
INSTANTIATE_TEST_SUITE_P(smoke_SearchSortedTest, | ||
SearchSortedLayerTest, | ||
::testing::Combine(::testing::ValuesIn(SearchSortedLayerTest::GenerateParams()), | ||
testing::Values(ElementType::f32, ElementType::f16), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add i8 at least There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have added tests for i64 and u32 - since they are used in the model that needs SearchSorted. I cannot add currently test for i8 due to a bug in data_utils.hpp:317, which causes generate_input to loop infinitively. After fixing it, it will work, but fixing that bug is orthogonal to this PR and should be done in separate PR. |
||
testing::Values(ov::test::utils::DEVICE_CPU)), | ||
SearchSortedLayerTest::getTestCaseName); | ||
|
||
} // namespace test | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "shared_test_classes/single_op/search_sorted.hpp" | ||
|
||
namespace ov { | ||
namespace test { | ||
TEST_P(SearchSortedLayerTest, Inference) { | ||
run(); | ||
} | ||
} // namespace test | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "shared_test_classes/base/ov_subgraph.hpp" | ||
|
||
namespace ov { | ||
namespace test { | ||
|
||
using SearchSortedSpecificParams = std::tuple<InputShape, // sorted shape | ||
InputShape, // values shape | ||
bool>; | ||
|
||
using SearchSortedLayerTestParams = std::tuple<SearchSortedSpecificParams, ElementType, std::string>; | ||
|
||
class SearchSortedLayerTest : public testing::WithParamInterface<SearchSortedLayerTestParams>, | ||
public ov::test::SubgraphBaseTest { | ||
public: | ||
static std::string getTestCaseName(testing::TestParamInfo<SearchSortedLayerTestParams> obj); | ||
static const std::vector<SearchSortedSpecificParams> GenerateParams(); | ||
|
||
protected: | ||
void SetUp() override; | ||
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override; | ||
}; | ||
|
||
extern const std::vector<SearchSortedSpecificParams> SearchSortedParamsVector; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is it used? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. |
||
|
||
} // namespace test | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "shared_test_classes/single_op/search_sorted.hpp" | ||
|
||
#include "common_test_utils/ov_tensor_utils.hpp" | ||
#include "shared_test_classes/base/ov_subgraph.hpp" | ||
|
||
namespace ov { | ||
namespace test { | ||
|
||
static const int SEED = 7877; | ||
|
||
std::string SearchSortedLayerTest::getTestCaseName(testing::TestParamInfo<SearchSortedLayerTestParams> obj) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. |
||
SearchSortedLayerTestParams basicParamsSet; | ||
basicParamsSet = obj.param; | ||
|
||
SearchSortedSpecificParams searchSortedParams; | ||
|
||
ElementType inputPrecision; | ||
std::string targetDevice; | ||
std::tie(searchSortedParams, inputPrecision, targetDevice) = basicParamsSet; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better to use reference to avoid extra copy There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Every single test on CPU side I have checked is doing it this way - so will keep it to be that way to be consistent. |
||
|
||
InputShape sortedInputShape; | ||
InputShape valuesInputShape; | ||
bool right_mode; | ||
|
||
std::tie(sortedInputShape, valuesInputShape, right_mode) = searchSortedParams; | ||
|
||
std::ostringstream result; | ||
result << inputPrecision << "_IS="; | ||
result << ov::test::utils::partialShape2str({sortedInputShape.first}) << ","; | ||
result << ov::test::utils::partialShape2str({valuesInputShape.first}) << "_"; | ||
result << "TS="; | ||
result << "("; | ||
for (const auto& targetShape : sortedInputShape.second) { | ||
result << ov::test::utils::vec2str(targetShape) << "_"; | ||
} | ||
result << ", "; | ||
for (const auto& targetShape : valuesInputShape.second) { | ||
result << ov::test::utils::vec2str(targetShape) << "_"; | ||
} | ||
result << ")_"; | ||
result << "right_mode=" << right_mode; | ||
|
||
return result.str(); | ||
} | ||
|
||
void SearchSortedLayerTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) { | ||
inputs.clear(); | ||
const auto& funcInputs = function->inputs(); | ||
|
||
const auto dataPrecision = funcInputs[0].get_element_type(); | ||
|
||
auto sortedTensor = | ||
ov::test::utils::create_and_fill_tensor_unique_sequence(dataPrecision, targetInputStaticShapes[0], 0, 8, SEED); | ||
inputs.insert({funcInputs[0].get_node_shared_ptr(), sortedTensor}); | ||
|
||
auto valuesTensor = | ||
ov::test::utils::create_and_fill_tensor_unique_sequence(dataPrecision, targetInputStaticShapes[1], 0, 8, SEED); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is no statement in specification that the second input should be sorted and unique. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point - fixed. |
||
|
||
inputs.insert({funcInputs[1].get_node_shared_ptr(), valuesTensor}); | ||
} | ||
|
||
void SearchSortedLayerTest::SetUp() { | ||
SearchSortedLayerTestParams basicParamsSet; | ||
basicParamsSet = this->GetParam(); | ||
|
||
SearchSortedSpecificParams searchSortedParams; | ||
|
||
ElementType inputPrecision; | ||
std::tie(searchSortedParams, inputPrecision, targetDevice) = basicParamsSet; | ||
|
||
InputShape sortedInputShape; | ||
InputShape valuesInputShape; | ||
bool right_mode; | ||
std::tie(sortedInputShape, valuesInputShape, right_mode) = searchSortedParams; | ||
|
||
init_input_shapes({sortedInputShape, valuesInputShape}); | ||
auto sortedParam = std::make_shared<ov::op::v0::Parameter>(inputPrecision, inputDynamicShapes[0]); | ||
auto valuesParam = std::make_shared<ov::op::v0::Parameter>(inputPrecision, inputDynamicShapes[1]); | ||
|
||
auto op = std::make_shared<ov::op::v15::SearchSorted>(sortedParam, valuesParam, right_mode); | ||
|
||
ov::ParameterVector params{sortedParam, valuesParam}; | ||
function = std::make_shared<ov::Model>(op->outputs(), params, "SearchSorted"); | ||
} | ||
|
||
const std::vector<SearchSortedSpecificParams> SearchSortedLayerTest::GenerateParams() { | ||
const std::vector<SearchSortedSpecificParams> params = { | ||
SearchSortedSpecificParams{InputShape{{}, {{1, 18, 104}}}, InputShape{{}, {{1, 18, 104}}}, true}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add more shapes instances to check dynamic case. Like {{}, {{1, 18, 104}, {3, 50, 70}, ...}} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added more tests cases. |
||
SearchSortedSpecificParams{InputShape{{}, {{50}}}, InputShape{{1, -1, 10}, {{1, 18, 10}}}, false}, | ||
}; | ||
|
||
return params; | ||
} | ||
|
||
} // namespace test | ||
} // namespace ov |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please fix alignment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.