Skip to content

Commit

Permalink
[CPU] Introduce FullyConnectedQuantized op and bias fusing
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorDuplensky committed Sep 9, 2024
1 parent 102e875 commit 0250c68
Show file tree
Hide file tree
Showing 49 changed files with 1,383 additions and 565 deletions.
44 changes: 44 additions & 0 deletions src/common/transformations/include/ov_ops/fully_connected.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/node.hpp"
#include "openvino/op/op.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace op {
namespace internal {

/// Internal fully-connected operation taking an activation input, a weights
/// input and a bias input, with an optional explicit output element type.
class TRANSFORMATIONS_API FullyConnected : public ov::op::Op {
public:
    OPENVINO_OP("FullyConnected", "ie_internal_opset");

    /// Creates an unconfigured node (no inputs attached).
    FullyConnected() = default;

    /// @param A            First input.
    /// @param B            Second input.
    /// @param bias         Bias input.
    /// @param output_type  Requested output element type; defaults to
    ///                     ov::element::undefined.
    FullyConnected(const ov::Output<Node>& A,
                   const ov::Output<Node>& B,
                   const ov::Output<Node>& bias,
                   const ov::element::Type output_type = ov::element::undefined);

    /// Exposes the node attributes to the given visitor.
    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    /// Validates the inputs and sets the output element type and shape.
    void validate_and_infer_types() override;

    /// Creates a copy of this node connected to @p new_args.
    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    /// Returns a new node built with @p bias as its bias input.
    /// Virtual so that derived operations can return a node of their own type.
    virtual std::shared_ptr<Node> fuse_bias(const ov::Output<Node>& bias) const;

    /// @return The output element type stored on this node.
    ov::element::Type get_output_type() const {
        return m_output_type;
    }

protected:
    // Output element type passed at construction; accessible to derived ops.
    ov::element::Type m_output_type;
};

} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/node.hpp"
#include "openvino/op/op.hpp"
#include "ov_ops/fully_connected.hpp"

namespace ov {
namespace op {
namespace internal {

/// Fully-connected operation that additionally takes quantization parameters
/// (weight/input/output scales and zero points) as inputs.
///
/// Derives from FullyConnected and provides its own fuse_bias() so that bias
/// fusing yields a node of this type.
class TRANSFORMATIONS_API FullyConnectedQuantized : public ov::op::internal::FullyConnected {
public:
    // The parent class is named here so the type-info chain records the FullyConnected
    // base (type checks against FullyConnected then also accept this op), and the opset
    // name matches the one used by the FullyConnected base op.
    // NOTE(review): confirm no caller relies on the previous "gpu_opset" version id.
    OPENVINO_OP("FullyConnectedQuantized", "ie_internal_opset", ov::op::internal::FullyConnected);

    /// Creates an unconfigured node (no inputs attached).
    FullyConnectedQuantized() = default;

    // The constructor overloads below differ only in how many of the trailing
    // quantization inputs are passed explicitly. How omitted inputs are filled in
    // is decided by the implementation file, which is not part of this header.

    /// Full form: every quantization input is supplied by the caller.
    ///
    /// @param X                   Activation input.
    /// @param W                   Weights input.
    /// @param bias                Bias input.
    /// @param weight_scales       Weight scales input.
    /// @param weight_zero_points  Weight zero points input.
    /// @param input_scales        Input scales input.
    /// @param input_zero_points   Input zero points input.
    /// @param output_scales       Output scales input.
    /// @param output_zero_points  Output zero points input.
    /// @param output_type         Requested output element type; defaults to
    ///                            ov::element::undefined.
    FullyConnectedQuantized(const ov::Output<Node>& X,
                            const ov::Output<Node>& W,
                            const ov::Output<Node>& bias,
                            const ov::Output<Node>& weight_scales,
                            const ov::Output<Node>& weight_zero_points,
                            const ov::Output<Node>& input_scales,
                            const ov::Output<Node>& input_zero_points,
                            const ov::Output<Node>& output_scales,
                            const ov::Output<Node>& output_zero_points,
                            const ov::element::Type output_type = ov::element::undefined);

    /// Form with weight scales, weight zero points and input scales.
    FullyConnectedQuantized(const ov::Output<Node>& X,
                            const ov::Output<Node>& W,
                            const ov::Output<Node>& bias,
                            const ov::Output<Node>& weight_scales,
                            const ov::Output<Node>& weight_zero_points,
                            const ov::Output<Node>& input_scales,
                            const ov::element::Type output_type = ov::element::undefined);

    /// Form with weight scales and weight zero points.
    FullyConnectedQuantized(const ov::Output<Node>& X,
                            const ov::Output<Node>& W,
                            const ov::Output<Node>& bias,
                            const ov::Output<Node>& weight_scales,
                            const ov::Output<Node>& weight_zero_points,
                            const ov::element::Type output_type = ov::element::undefined);

    /// Form with weight scales only.
    FullyConnectedQuantized(const ov::Output<Node>& X,
                            const ov::Output<Node>& W,
                            const ov::Output<Node>& bias,
                            const ov::Output<Node>& weight_scales,
                            const ov::element::Type output_type = ov::element::undefined);

    /// Exposes the node attributes to the given visitor.
    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    /// Validates the inputs and sets the output element type and shape.
    void validate_and_infer_types() override;

    /// Creates a copy of this node connected to @p new_args.
    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    /// Returns a new node built with @p bias as its bias input.
    std::shared_ptr<Node> fuse_bias(const ov::Output<Node>& bias) const override final;

    // get_output_type() is inherited from FullyConnected; redeclaring an identical
    // non-virtual getter here would only hide the base member.
};

} // namespace internal
} // namespace op
} // namespace ov
27 changes: 27 additions & 0 deletions src/common/transformations/include/ov_ops/placeholder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace op {
namespace internal {

/// Internal operation constructed without any arguments.
/// NOTE(review): presumably used as a stand-in for an absent optional input
/// (e.g. a missing bias) — confirm against the callers.
class TRANSFORMATIONS_API Placeholder : public ov::op::Op {
public:
    OPENVINO_OP("Placeholder", "ie_internal_opset");

    Placeholder();

    /// Exposes the node attributes to the given visitor.
    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    /// Sets the output element type and shape of this node.
    void validate_and_infer_types() override;

    /// Creates a copy of this node for the given inputs.
    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
};

} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API ConvertFullyConnectedToFullyConnectedCompressed;

} // namespace pass
} // namespace ov

/// Matcher pass that, going by its name, rewrites FullyConnected nodes into a
/// FullyConnectedCompressed form. The matching pattern lives in the
/// implementation file, which is not part of this header.
class ov::pass::ConvertFullyConnectedToFullyConnectedCompressed : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ConvertFullyConnectedToFullyConnectedCompressed", "0");

    /// @param convert_u4zp_to_u8  Option flag, off by default.
    ///        NOTE(review): presumably requests converting u4 zero points to u8 — confirm
    ///        in the implementation.
    ///
    /// Declared explicit because the constructor is callable with a single bool; without it
    /// a bare `true`/`false` could be implicitly converted into a pass object.
    explicit ConvertFullyConnectedToFullyConnectedCompressed(bool convert_u4zp_to_u8 = false);
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API ConvertFullyConnectedToFullyConnectedQuantized;

} // namespace pass
} // namespace ov

/// Matcher pass that, going by its name, rewrites FullyConnected nodes into
/// FullyConnectedQuantized nodes. The matching pattern lives in the
/// implementation file, which is not part of this header.
class ov::pass::ConvertFullyConnectedToFullyConnectedQuantized : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ConvertFullyConnectedToFullyConnectedQuantized", "0");

    /// Creates the pass; it takes no configuration.
    ConvertFullyConnectedToFullyConnectedQuantized();
};
62 changes: 62 additions & 0 deletions src/common/transformations/src/ov_ops/fully_connected.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ov_ops/fully_connected.hpp"

#include <memory>

#include "matmul_shape_inference.hpp"
#include "ov_ops/placeholder.hpp"

namespace ov {
namespace op {
namespace internal {

// Builds the node over exactly three inputs {A, B, bias}, stores the requested
// output element type, and runs type/shape inference right away.
FullyConnected::FullyConnected(const ov::Output<Node>& A,
                               const ov::Output<Node>& B,
                               const ov::Output<Node>& bias,
                               const ov::element::Type output_type)
    : Op({A, B, bias}),
      m_output_type(output_type) {
    // Called from a constructor, so this resolves to FullyConnected's own
    // implementation even when a derived class overrides it.
    validate_and_infer_types();
}

// Creates a new FullyConnected wired to the first three entries of `new_args`,
// carrying over this node's output element type.
std::shared_ptr<ov::Node> FullyConnected::clone_with_new_inputs(const ov::OutputVector& new_args) const {
    check_new_args_count(this, new_args);

    const auto& first_input = new_args.at(0);
    const auto& second_input = new_args.at(1);
    const auto& bias_input = new_args.at(2);

    return std::make_shared<FullyConnected>(first_input, second_input, bias_input, m_output_type);
}

// Returns a new FullyConnected that reuses this node's inputs 0 and 1 and the
// stored output element type, with `bias` as the third input. This node itself
// is left untouched.
std::shared_ptr<Node> FullyConnected::fuse_bias(const ov::Output<Node>& bias) const {
    const auto first_input = input_value(0);
    const auto second_input = input_value(1);

    return std::make_shared<FullyConnected>(first_input, second_input, bias, m_output_type);
}

void FullyConnected::validate_and_infer_types() {
const auto input_size = get_input_size();
NODE_VALIDATION_CHECK(this,
input_size >= 3,
"Number of inputs is incorrect. Current value is: ",
input_size,
", expected at least 3.");

ov::op::v0::MatMul op;
op.set_transpose_a(false);
op.set_transpose_b(true);

auto out_shapes =
ov::op::v0::shape_infer(&op,
std::vector<ov::PartialShape>{get_input_partial_shape(0), get_input_partial_shape(1)});

auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type;
set_output_type(0, output_type, out_shapes[0]);
}

// Hands the single attribute of this op, "output_type", to the visitor.
// Always reports success.
bool FullyConnected::visit_attributes(ov::AttributeVisitor& visitor) {
    visitor.on_attribute("output_type", m_output_type);

    return true;
}

} // namespace internal
} // namespace op
} // namespace ov
Loading

0 comments on commit 0250c68

Please sign in to comment.