-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CPU] Introduce FullyConnectedQuantized op and bias fusing
- Loading branch information
1 parent
102e875
commit 0250c68
Showing
49 changed files
with
1,383 additions
and
565 deletions.
There are no files selected for viewing
44 changes: 44 additions & 0 deletions
44
src/common/transformations/include/ov_ops/fully_connected.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/core/node.hpp" | ||
#include "openvino/op/op.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
/// @brief Internal fully-connected operation with an explicit bias input.
///
/// The node takes at least three inputs (A, B, bias). Its output shape is
/// inferred with MatMul rules applied to inputs 0 and 1, using
/// transpose_a = false and transpose_b = true.
class TRANSFORMATIONS_API FullyConnected : public ov::op::Op {
public:
    OPENVINO_OP("FullyConnected", "ie_internal_opset");

    /// @brief Creates an empty node; inputs and attributes are expected to be set later.
    FullyConnected() = default;

    /// @brief Builds the node from its three inputs and runs validation immediately.
    /// @param A            First MatMul operand (input 0).
    /// @param B            Second MatMul operand (input 1), treated as transposed during shape inference.
    /// @param bias         Bias input (input 2).
    /// @param output_type  Requested output element type; ov::element::undefined
    ///                     means "use the element type of input 0".
    FullyConnected(const ov::Output<Node>& A,
                   const ov::Output<Node>& B,
                   const ov::Output<Node>& bias,
                   const ov::element::Type output_type = ov::element::undefined);

    /// @brief Exposes the "output_type" attribute to the visitor.
    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    /// @brief Checks the input count and sets the output element type and shape.
    void validate_and_infer_types() override;

    /// @brief Creates a copy of this node wired to @p new_args, preserving the output type.
    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    /// @brief Returns a new node that keeps inputs 0 and 1 of this one but uses @p bias as input 2.
    /// @note Virtual so that derived fully-connected flavours can rebuild themselves with their own extra inputs.
    virtual std::shared_ptr<Node> fuse_bias(const ov::Output<Node>& bias) const;

    /// @brief Returns the output element type requested at construction (may be ov::element::undefined).
    ov::element::Type get_output_type() const {
        return m_output_type;
    }

protected:
    /// Requested output element type; when it is ov::element::undefined the
    /// element type of input 0 is used during type inference.
    ov::element::Type m_output_type;
};
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
77 changes: 77 additions & 0 deletions
77
src/common/transformations/include/ov_ops/fully_connected_quantized.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/core/node.hpp" | ||
#include "openvino/op/op.hpp" | ||
#include "ov_ops/fully_connected.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
/// @brief Fully-connected operation that additionally carries quantization
///        parameters (scales and zero points) as node inputs.
///
/// The class derives from FullyConnected and registers it as the RTTI parent,
/// so type checks and casts against FullyConnected also succeed for this node.
/// The output-type accessor get_output_type() is inherited from the base class.
///
/// NOTE(review): the constructor overloads below accept progressively fewer
/// quantization inputs; how the omitted inputs are represented is defined in
/// the implementation file, which is not part of this header.
class TRANSFORMATIONS_API FullyConnectedQuantized : public ov::op::internal::FullyConnected {
public:
    // Same internal opset as the sibling ops in ov::op::internal; the parent
    // class is passed so the type info is linked to FullyConnected.
    OPENVINO_OP("FullyConnectedQuantized", "ie_internal_opset", FullyConnected);

    /// @brief Creates an empty node; inputs and attributes are expected to be set later.
    FullyConnectedQuantized() = default;

    /// @brief Builds the node with the full set of quantization inputs.
    /// @param X                   Activations input.
    /// @param W                   Weights input.
    /// @param bias                Bias input.
    /// @param weight_scales       Scales for the weights.
    /// @param weight_zero_points  Zero points for the weights.
    /// @param input_scales        Scales for the activations.
    /// @param input_zero_points   Zero points for the activations.
    /// @param output_scales       Scales for the output.
    /// @param output_zero_points  Zero points for the output.
    /// @param output_type         Requested output element type (ov::element::undefined by default).
    FullyConnectedQuantized(const ov::Output<Node>& X,
                            const ov::Output<Node>& W,
                            const ov::Output<Node>& bias,
                            const ov::Output<Node>& weight_scales,
                            const ov::Output<Node>& weight_zero_points,
                            const ov::Output<Node>& input_scales,
                            const ov::Output<Node>& input_zero_points,
                            const ov::Output<Node>& output_scales,
                            const ov::Output<Node>& output_zero_points,
                            const ov::element::Type output_type = ov::element::undefined);

    /// @brief Overload taking weight scales, weight zero points and input scales only.
    FullyConnectedQuantized(const ov::Output<Node>& X,
                            const ov::Output<Node>& W,
                            const ov::Output<Node>& bias,
                            const ov::Output<Node>& weight_scales,
                            const ov::Output<Node>& weight_zero_points,
                            const ov::Output<Node>& input_scales,
                            const ov::element::Type output_type = ov::element::undefined);

    /// @brief Overload taking weight scales and weight zero points only.
    FullyConnectedQuantized(const ov::Output<Node>& X,
                            const ov::Output<Node>& W,
                            const ov::Output<Node>& bias,
                            const ov::Output<Node>& weight_scales,
                            const ov::Output<Node>& weight_zero_points,
                            const ov::element::Type output_type = ov::element::undefined);

    /// @brief Overload taking weight scales only.
    FullyConnectedQuantized(const ov::Output<Node>& X,
                            const ov::Output<Node>& W,
                            const ov::Output<Node>& bias,
                            const ov::Output<Node>& weight_scales,
                            const ov::element::Type output_type = ov::element::undefined);

    /// @brief Exposes the node attributes to the visitor.
    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    /// @brief Validates the inputs and sets the output element type and shape.
    void validate_and_infer_types() override;

    /// @brief Creates a copy of this node wired to @p new_args.
    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    /// @brief Returns a new node equivalent to this one but using @p bias as the bias input.
    std::shared_ptr<Node> fuse_bias(const ov::Output<Node>& bias) const override final;
};
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/op.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
/// @brief Internal operation constructed without any inputs.
///
/// NOTE(review): presumably used as a stand-in for an absent optional input of
/// other internal ops; confirm against the implementation and its users.
class TRANSFORMATIONS_API Placeholder : public ov::op::Op {
public:
    OPENVINO_OP("Placeholder", "ie_internal_opset");

    /// @brief Creates the node; no inputs or attributes are taken.
    Placeholder();

    /// @brief Attribute visitation hook required by the Node interface.
    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    /// @brief Type/shape inference hook required by the Node interface.
    void validate_and_infer_types() override;

    /// @brief Creates a copy of this node for the given replacement inputs.
    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
};
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
22 changes: 22 additions & 0 deletions
22
...ommon/transformations/include/transformations/op_conversions/convert_fc_to_compressed.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/graph_rewrite.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API ConvertFullyConnectedToFullyConnectedCompressed; | ||
|
||
} // namespace pass | ||
} // namespace ov | ||
|
||
class ov::pass::ConvertFullyConnectedToFullyConnectedCompressed : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("ConvertFullyConnectedToFullyConnectedCompressed", "0"); | ||
ConvertFullyConnectedToFullyConnectedCompressed(bool convert_u4zp_to_u8 = false); | ||
}; |
22 changes: 22 additions & 0 deletions
22
...common/transformations/include/transformations/op_conversions/convert_fc_to_quantized.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/graph_rewrite.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API ConvertFullyConnectedToFullyConnectedQuantized; | ||
|
||
} // namespace pass | ||
} // namespace ov | ||
|
||
class ov::pass::ConvertFullyConnectedToFullyConnectedQuantized : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("ConvertFullyConnectedToFullyConnectedQuantized", "0"); | ||
ConvertFullyConnectedToFullyConnectedQuantized(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "ov_ops/fully_connected.hpp" | ||
|
||
#include <memory> | ||
|
||
#include "matmul_shape_inference.hpp" | ||
#include "ov_ops/placeholder.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
/// Builds a FullyConnected node from its three inputs (A, B, bias), stores the
/// requested output element type and validates the node right away.
FullyConnected::FullyConnected(const ov::Output<Node>& A,
                               const ov::Output<Node>& B,
                               const ov::Output<Node>& bias,
                               const ov::element::Type output_type)
    : Op(ov::OutputVector{A, B, bias}),
      m_output_type(output_type) {
    // Output type and shape are known as soon as the node is constructed.
    validate_and_infer_types();
}
|
||
/// Creates a copy of this node attached to the first three entries of
/// `new_args`, carrying over the stored output element type.
std::shared_ptr<ov::Node> FullyConnected::clone_with_new_inputs(const ov::OutputVector& new_args) const {
    // The replacement inputs must match this node's input count.
    check_new_args_count(this, new_args);

    const auto& a = new_args.at(0);
    const auto& b = new_args.at(1);
    const auto& bias = new_args.at(2);
    return std::make_shared<FullyConnected>(a, b, bias, m_output_type);
}
|
||
/// Returns a new FullyConnected node that reuses inputs 0 and 1 of this node
/// and takes `bias` as its third input; the output type is preserved.
std::shared_ptr<Node> FullyConnected::fuse_bias(const ov::Output<Node>& bias) const {
    const auto a = input_value(0);
    const auto b = input_value(1);
    return std::make_shared<FullyConnected>(a, b, bias, m_output_type);
}
|
||
void FullyConnected::validate_and_infer_types() { | ||
const auto input_size = get_input_size(); | ||
NODE_VALIDATION_CHECK(this, | ||
input_size >= 3, | ||
"Number of inputs is incorrect. Current value is: ", | ||
input_size, | ||
", expected at least 3."); | ||
|
||
ov::op::v0::MatMul op; | ||
op.set_transpose_a(false); | ||
op.set_transpose_b(true); | ||
|
||
auto out_shapes = | ||
ov::op::v0::shape_infer(&op, | ||
std::vector<ov::PartialShape>{get_input_partial_shape(0), get_input_partial_shape(1)}); | ||
|
||
auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type; | ||
set_output_type(0, output_type, out_shapes[0]); | ||
} | ||
|
||
/// Exposes the node's single attribute, "output_type", to the visitor.
/// Always reports success.
bool FullyConnected::visit_attributes(ov::AttributeVisitor& visitor) {
    visitor.on_attribute("output_type", m_output_type);

    return true;
}
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
Oops, something went wrong.