# [NPUW] Extend DQ & PMM processing and make ReduceSum not keep the axis (openvinotoolkit#26779)

### Details:
- Extend DQ (dequantize MatMul) and PMM (parallel MatMul merge) processing to support more patterns, e.g. fp16 MatMuls.
- Make ReduceSum not keep the reduced axis, so that the compiler converts it to poolings. With the axis kept, ReduceSum is converted to a convolution instead, which is less efficient than poolings. A minimal shape sketch follows this list.
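For illustration, a minimal sketch (not from the commit; the shapes are hypothetical and OpenVINO dev headers are assumed) of what `keep_dims` changes: with `keep_dims=false` the reduced axis is dropped from the output shape, which is the form the compiler lowers to poolings.

```cpp
#include <memory>

#include <openvino/op/constant.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/reduce_sum.hpp>

int main() {
    // Hypothetical input {N, groups, K}; the reduction runs over axis 1.
    auto in = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::Shape{1, 8, 4096});
    auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});

    // keep_dims = true: the reduced axis stays as 1 -> output shape {1, 1, 4096}.
    auto keep = std::make_shared<ov::op::v1::ReduceSum>(in, axis, true);
    // keep_dims = false: the reduced axis is dropped -> output shape {1, 4096}.
    auto drop = std::make_shared<ov::op::v1::ReduceSum>(in, axis, false);
    return 0;
}
```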

### Tickets:
 - E-140570

---------

Co-authored-by: Dmitry Matveev <dmitry.matveev@intel.com>
shaojun-yao and dmatveev authored Oct 15, 2024
Commit c6801aa (1 parent: d52cf4a)
Showing 1 changed file with 10 additions and 5 deletions.
```diff
@@ -335,7 +335,7 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
     auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
     auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
     auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qmuls, opp::any_input()});
-    auto qcvtr = opp::wrap_type<ov::op::v0::Convert>({qreshp});
+    auto qcvtr = opp::optional<ov::op::v0::Convert>({qreshp->output(0)});
     auto qmmi = opp::any_input();
     auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qcvtr});
 
```
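For context (not part of the diff): `opp` aliases `ov::pass::pattern` in this file, and `opp::optional` makes the `Convert` node optional, so the same pattern now matches both the f32 path (Reshape → Convert → MatMul) and the fp16 path (Reshape → MatMul). A hedged, standalone sketch of the pattern:

```cpp
// Sketch only; assumes opp = ov::pass::pattern, as used in the diff above.
#include <openvino/op/convert.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/pass/pattern/op/label.hpp>
#include <openvino/pass/pattern/op/optional.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>

namespace opp = ov::pass::pattern;

std::shared_ptr<ov::Node> make_dq_matmul_pattern() {
    auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({opp::any_input(), opp::any_input()});
    auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qmuls, opp::any_input()});
    // optional<> matches with or without the Convert: an fp16 MatMul consuming
    // the Reshape directly is matched too, unlike the old wrap_type<Convert>.
    auto qcvtr = opp::optional<ov::op::v0::Convert>({qreshp->output(0)});
    return opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), qcvtr});
}
```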

```diff
@@ -409,13 +409,18 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
     auto rshp_ccat = std::make_shared<ov::op::v1::Reshape>(scaled, rshp_ccat_c, false);
 
     auto reduce_axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 1);
-    auto reduce = std::make_shared<ov::op::v1::ReduceSum>(rshp_ccat, reduce_axis, true);
+    // Don't keep the reduced axis so the compiler converts the ReduceSum to poolings;
+    // otherwise it is converted to a convolution, which is less efficient than poolings.
+    auto reduce = std::make_shared<ov::op::v1::ReduceSum>(rshp_ccat, reduce_axis, false);
 
     auto rshp_out_c = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, out_shape);
     auto rshp_out = std::make_shared<ov::op::v1::Reshape>(reduce, rshp_out_c, false);
 
-    // Convert the result to f32 to maintain the graph contracts. FIXME should be avoided
-    auto out = std::make_shared<ov::op::v0::Convert>(rshp_out, ov::element::f32);
+    // Convert the result to f32 to maintain the graph contracts, if required.
+    std::shared_ptr<ov::Node> out = rshp_out;
+    if (matched_matmul->get_element_type() == ov::element::f32) {
+        out = std::make_shared<ov::op::v0::Convert>(rshp_out, ov::element::f32);
+    }
 
     // Now reconnect the matmul readers to the new output (the reduce subgraph)
     for (auto&& r : matched_matmul->output(0).get_target_inputs()) {
```
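The guarded Convert above replaces the unconditional f32 Convert (and its FIXME). The same idiom as a standalone, hedged sketch (`maybe_to_f32` is a hypothetical helper, not in the commit):

```cpp
// Hypothetical helper illustrating the idiom: restore f32 only when the node
// being replaced produced f32; leave fp16 results in fp16.
#include <memory>

#include <openvino/op/convert.hpp>

std::shared_ptr<ov::Node> maybe_to_f32(const std::shared_ptr<ov::Node>& node,
                                       const ov::element::Type& original_type) {
    if (original_type == ov::element::f32) {
        // Keep the graph contract: downstream readers still see f32.
        return std::make_shared<ov::op::v0::Convert>(node, ov::element::f32);
    }
    return node;  // fp16 (or other) results flow through unchanged
}
```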
```diff
@@ -752,7 +757,7 @@ void mergeParallelMatMuls(const std::shared_ptr<ov::Model>& m, Context& ctx) {
         auto new_cvt = std::make_shared<ov::op::v0::Convert>(new_w, new_s->get_element_type());
 
         std::shared_ptr<ov::Node> new_mul = std::make_shared<ov::op::v1::Multiply>(new_cvt, new_s);
-        if (new_s->get_element_type() == ov::element::f16) {
+        if ((new_s->get_element_type() == ov::element::f16) && (orig_multiply.get_element_type() == ov::element::f32)) {
             new_mul = std::make_shared<ov::op::v0::Convert>(new_mul, ov::element::f32);
         }
         auto new_w_shape = new_w->get_shape();
```
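With the extra `orig_multiply` element-type check, fp16 parallel MatMuls merged by PMM keep their fp16 `Multiply` output instead of being force-converted to f32, mirroring the conditional `Convert` in the DQ path above.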