Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Transformations] Add SliceScatter-15 decomposition transformation #27136

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API ConvertSliceScatter;

} // namespace pass
} // namespace ov

class ov::pass::ConvertSliceScatter : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConvertSliceScatter", "0");
ConvertSliceScatter();
};
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
#include "transformations/op_conversions/convert_scatter_elements_update12_downgrade.hpp"
#include "transformations/op_conversions/convert_scatter_nd_update15_downgrade.hpp"
#include "transformations/op_conversions/convert_slice_to_strided_slice.hpp"
#include "transformations/op_conversions/convert_slicescatter.hpp"
#include "transformations/op_conversions/convert_softmax_downgrade.hpp"
#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
#include "transformations/op_conversions/convert_space_to_depth.hpp"
Expand Down Expand Up @@ -233,6 +234,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
REGISTER_PASS(manager, ConvertEmbeddingBagOffsets15ToEmbeddingBagOffsetsSum3)
REGISTER_PASS(manager, ConvertEmbeddingBagPacked15ToEmbeddingBagPackedSum3)
REGISTER_PASS(manager, ConvertScatterNDUpdate15ToScatterNDUpdate3)
REGISTER_PASS(manager, ConvertSliceScatter)

auto fq_fusions = manager.register_pass<GraphRewrite>();
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/convert_slicescatter.hpp"

#include <memory>
#include <vector>

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/slice_scatter.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

ov::pass::ConvertSliceScatter::ConvertSliceScatter() {
MATCHER_SCOPE(ConvertSliceScatter);

auto slicescatter = pattern::wrap_type<ov::op::v15::SliceScatter>();

matcher_pass_callback callback = [](pattern::Matcher& m) {
auto slice_node = ov::as_type_ptr<ov::op::v15::SliceScatter>(m.get_match_root());
if (!slice_node) {
return false;
}
NodeRegistry node_registry;
auto const_0 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 0);
auto const_1 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 1);
auto const_1d_neg_1 =
node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{1}, std::vector<int64_t>{-1});
auto const_scatter_indices_shape =
node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{-1, 1});
auto data_shape = node_registry.make<ov::op::v3::ShapeOf>(slice_node->input_value(0), ov::element::i64);
auto num_elements_data = node_registry.make<ov::op::v1::ReduceProd>(data_shape, const_0, false);
auto data_indices_flatten =
node_registry.make<ov::op::v4::Range>(const_0, num_elements_data, const_1, ov::element::i64);
auto full_data_indices = node_registry.make<ov::op::v1::Reshape>(data_indices_flatten, data_shape, false);
std::shared_ptr<ov::op::v8::Slice> slice_indices;
if (slice_node->get_input_size() == 5) {
slice_indices = node_registry.make<ov::op::v8::Slice>(full_data_indices,
slice_node->input_value(2),
slice_node->input_value(3),
slice_node->input_value(4));
} else {
slice_indices = node_registry.make<ov::op::v8::Slice>(full_data_indices,
slice_node->input_value(2),
slice_node->input_value(3),
slice_node->input_value(4),
slice_node->input_value(5));
}
auto slice_indices_flatten =
node_registry.make<ov::op::v1::Reshape>(slice_indices, const_scatter_indices_shape, false);
auto updates_flatten =
node_registry.make<ov::op::v1::Reshape>(slice_node->input_value(1), const_1d_neg_1, false);
auto data_flatten = node_registry.make<ov::op::v1::Reshape>(slice_node->input_value(0), const_1d_neg_1, false);
auto output_flatten =
node_registry.make<ov::op::v3::ScatterNDUpdate>(data_flatten, slice_indices_flatten, updates_flatten);
auto output = node_registry.make<ov::op::v1::Reshape>(output_flatten, data_shape, false);

output->set_friendly_name(slice_node->get_friendly_name());
copy_runtime_info(slice_node, node_registry.get());
replace_node(slice_node, output);
slice_node->clear_control_dependencies();

return true;
};

auto m = std::make_shared<pattern::Matcher>(slicescatter, matcher_name);
this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <memory>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/opsets/opset15.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/op_conversions/convert_slicescatter.hpp"
#include "transformations/utils/utils.hpp"
using namespace testing;

namespace {

std::shared_ptr<ov::Model> create_v15_model(bool with_axes) {
const auto data = std::make_shared<ov::opset15::Parameter>(ov::element::f32, ov::Shape{256, 10, 15});
const auto updates = std::make_shared<ov::opset15::Parameter>(ov::element::f32, ov::Shape{4, 7, 2});
const auto start = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 0, 0});
const auto stop = ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, 2});
const auto step = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, 1});
const auto axes = ov::op::v0::Constant::create(ov::element::i32, {3}, {0, 1, 2});
std::shared_ptr<ov::opset15::SliceScatter> slicescatter;
if (!with_axes) {
slicescatter = std::make_shared<ov::opset15::SliceScatter>(data, updates, start, stop, step);
} else {
slicescatter = std::make_shared<ov::opset15::SliceScatter>(data, updates, start, stop, step, axes);
}
slicescatter->set_friendly_name("slicescatter15");
return std::make_shared<ov::Model>(slicescatter->outputs(), ov::ParameterVector{data, updates});
}

std::shared_ptr<ov::Model> create_decomposed_model(bool with_axes) {
const auto data = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::Shape{256, 10, 15});
const auto updates = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::Shape{4, 7, 2});
const auto start = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 0, 0});
const auto stop = ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, 2});
const auto step = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, 1});
const auto axes = ov::op::v0::Constant::create(ov::element::i32, {3}, {0, 1, 2});
auto zero = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {}, {1});
auto neg_one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
auto scatter_shape = ov::op::v0::Constant::create(ov::element::i64, {2}, {-1, 1});
auto data_shape = std::make_shared<ov::opset8::ShapeOf>(data, ov::element::i64);
auto num_elements_data = std::make_shared<ov::opset8::ReduceProd>(data_shape, zero, false);
auto data_indices_flattened = std::make_shared<ov::opset8::Range>(zero, num_elements_data, one, ov::element::i64);
auto full_data_indices = std::make_shared<ov::opset8::Reshape>(data_indices_flattened, data_shape, false);
std::shared_ptr<ov::opset8::Slice> slice_indices;
if (!with_axes) {
slice_indices = std::make_shared<ov::opset8::Slice>(full_data_indices, start, stop, step);
} else {
slice_indices = std::make_shared<ov::opset8::Slice>(full_data_indices, start, stop, step, axes);
}
auto slice_indices_flatten = std::make_shared<ov::opset8::Reshape>(slice_indices, scatter_shape, false);
auto updates_flatten = std::make_shared<ov::opset8::Reshape>(updates, neg_one_1d, false);
auto data_flatten = std::make_shared<ov::opset8::Reshape>(data, neg_one_1d, false);
auto output_flatten =
std::make_shared<ov::opset8::ScatterNDUpdate>(data_flatten, slice_indices_flatten, updates_flatten);
auto slicescatter = std::make_shared<ov::opset8::Reshape>(output_flatten, data_shape, false);
slicescatter->set_friendly_name("slicescatter15");

return std::make_shared<ov::Model>(slicescatter->outputs(), ov::ParameterVector{data, updates});
}

} // namespace

TEST_F(TransformationTestsF, ConvertSliceScatter15Decomposition_axes) {
manager.register_pass<ov::pass::ConvertSliceScatter>();
model = create_v15_model(true);
model_ref = create_decomposed_model(true);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}

TEST_F(TransformationTestsF, ConvertSliceScatter15Decomposition_no_axes) {
manager.register_pass<ov::pass::ConvertSliceScatter>();
model = create_v15_model(false);
model_ref = create_decomposed_model(false);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
#include "transformations/op_conversions/convert_scatter_nd_update15_downgrade.hpp"
#include "transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp"
#include "transformations/op_conversions/convert_shuffle_channels3.hpp"
#include "transformations/op_conversions/convert_slicescatter.hpp"
#include "transformations/op_conversions/convert_slice_to_strided_slice.hpp"
#include "transformations/op_conversions/convert_space_to_batch.hpp"
#include "transformations/op_conversions/convert_space_to_depth.hpp"
Expand Down Expand Up @@ -644,6 +645,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_DISABLE_PASS_COMMON(manager, ov::pass::HSwishDecomposition);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertScatterNDUpdate15ToScatterNDUpdate3);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertSliceScatter);
CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition);

CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition);
Expand Down
Loading