Skip to content

Commit

Permalink
Implement CPU plugin just-in-time emitter for Mod operation (#24574)
Browse files Browse the repository at this point in the history
Closes #24061
### Details:
 - Mod operation

### Tickets:
 - [CVS-137689](https://jira.devtools.intel.com/browse/CVS-137689)
  • Loading branch information
awayzjj authored May 23, 2024
1 parent 00ef5c9 commit d0e6f8f
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,48 @@ std::set<std::vector<element::Type>> jit_mish_emitter::get_supported_precisions(
return {{element::f32}};
}

/// MOD ///
jit_mod_emitter::jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
}

jit_mod_emitter::jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc): jit_emitter(host, host_isa, exec_prc) {
}

size_t jit_mod_emitter::get_inputs_count() const { return 2; }

void jit_mod_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_mod_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;

TReg divend = TReg(in_vec_idxs[0]);
TReg divisor = TReg(in_vec_idxs[1]);
TReg r = TReg(out_vec_idxs[0]);

h->uni_fdiv(r.s, divend.s, divisor.s);
h->frintz(r.s, r.s);
h->uni_fmul(r.s, r.s, divisor.s);
h->uni_fsub(r.s, divend.s, r.s);
}

std::set<std::vector<element::Type>> jit_mod_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32, element::f32}};
}

/// MUL_ADD ///
jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,27 @@ class jit_mish_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_mod_emitter : public jit_emitter {
public:
jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32);

jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node);

size_t get_inputs_count() const override;

static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

private:
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_mul_add_emitter : public jit_emitter {
public:
jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ bool JitEltwiseExecutor::isSupported(
Algorithm::EltwiseMaximum,
Algorithm::EltwiseMinimum,
Algorithm::EltwiseMish,
Algorithm::EltwiseMod,
Algorithm::EltwiseMultiply,
Algorithm::EltwiseMulAdd,
Algorithm::EltwisePowerStatic,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseMinimum, ov::intel_cpu::aarch64::jit_minimum_emitter),
OV_CASE(Algorithm::EltwiseMish, ov::intel_cpu::aarch64::jit_mish_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMod, ov::intel_cpu::aarch64::jit_mod_emitter),
OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_static_emitter),
OV_CASE(Algorithm::EltwisePrelu, ov::intel_cpu::aarch64::jit_prelu_emitter),
Expand Down Expand Up @@ -806,6 +807,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseMaximum, jit_maximum_emitter),
OV_CASE(Algorithm::EltwiseMinimum, jit_minimum_emitter),
OV_CASE(Algorithm::EltwiseMish, jit_mish_emitter),
OV_CASE(Algorithm::EltwiseMod, jit_mod_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,16 @@ std::string EltwiseLayerCPUTest::getPrimitiveType(const utils::EltwiseTypes& elt
if ((eltwise_type == utils::EltwiseTypes::ADD) ||
(eltwise_type == utils::EltwiseTypes::MULTIPLY) ||
(eltwise_type == utils::EltwiseTypes::SUBTRACT) ||
(eltwise_type == utils::EltwiseTypes::DIVIDE)) {
(eltwise_type == utils::EltwiseTypes::DIVIDE) ||
(eltwise_type == utils::EltwiseTypes::MOD)) {
return "jit";
}
#endif
return "acl";
if (eltwise_type == utils::EltwiseTypes::MOD) {
return "ref";
} else {
return "acl";
}
#else
return CPUTestsBase::getPrimitiveType();
#endif
Expand Down Expand Up @@ -304,6 +309,7 @@ const std::vector<utils::EltwiseTypes>& eltwiseOpTypesBinInp() {
utils::EltwiseTypes::FLOOR_MOD, // TODO: Fix CVS-111875
#endif
utils::EltwiseTypes::SQUARED_DIFF,
utils::EltwiseTypes::MOD,
};
return eltwiseOpTypesBinInp;
}
Expand Down

0 comments on commit d0e6f8f

Please sign in to comment.