Skip to content

Commit

Permalink
[ARM CPU] Fix tests for eltwise layer (openvinotoolkit#16917)
Browse files Browse the repository at this point in the history
  • Loading branch information
allnes authored Apr 20, 2023
1 parent 5bded05 commit d00731c
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 38 deletions.
112 changes: 75 additions & 37 deletions src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
namespace ov {
namespace intel_cpu {

using namespace InferenceEngine;

class AclEltwiseExecutor : public EltwiseExecutor {
public:
AclEltwiseExecutor(const ExecutorContext::CPtr context);
explicit AclEltwiseExecutor(const ExecutorContext::CPtr context);

bool init(const EltwiseAttrs& eltwiseAttrs,
const std::vector<MemoryDescPtr>& srcDescs,
Expand All @@ -39,62 +41,98 @@ class AclEltwiseExecutorBuilder : public EltwiseExecutorBuilder {
bool isSupported(const EltwiseAttrs& eltwiseAttrs,
const std::vector<MemoryDescPtr>& srcDescs,
const std::vector<MemoryDescPtr>& dstDescs) const override {
auto checkPrecision = [&srcDescs, &dstDescs](std::vector<Precision> srcVecPrc, Precision dstPrc) -> bool {
for (int i = 0; i < srcDescs.size(); i++) {
if (srcDescs[i]->getPrecision() != srcVecPrc[i]) return false;
}
if (dstDescs[0]->getPrecision() != dstPrc) { return false; }
return true;
};

switch (eltwiseAttrs.algorithm) {
case Algorithm::EltwiseAdd:
case Algorithm::EltwiseMultiply:
case Algorithm::EltwiseSubtract:
case Algorithm::EltwiseSqrt:
case Algorithm::EltwiseDivide:
case Algorithm::EltwiseMaximum:
case Algorithm::EltwiseMinimum:
case Algorithm::EltwiseSquaredDifference:
case Algorithm::EltwisePowerDynamic:
case Algorithm::EltwiseEqual:
case Algorithm::EltwiseNotEqual:
case Algorithm::EltwiseGreater:
case Algorithm::EltwiseGreaterEqual:
case Algorithm::EltwiseLess:
case Algorithm::EltwiseLessEqual:
case Algorithm::EltwiseRelu:
case Algorithm::EltwiseGeluErf:
case Algorithm::EltwiseElu:
case Algorithm::EltwiseTanh:
case Algorithm::EltwiseSigmoid:
case Algorithm::EltwiseAbs:
case Algorithm::EltwiseSqrt:
// case Algorithm::EltwisePowerDynamic: // TODO: ACL version doesn't work https://github.com/ARM-software/ComputeLibrary/issues/1047
case Algorithm::EltwiseSoftRelu:
case Algorithm::EltwiseExp:
case Algorithm::EltwiseClamp:
case Algorithm::EltwiseSwish:
case Algorithm::EltwisePrelu:
case Algorithm::EltwiseHswish:
if (!(checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
return false;
}
break;
case Algorithm::EltwiseAbs:
case Algorithm::EltwiseExp:
case Algorithm::EltwiseLog:
if (!(checkPrecision({Precision::I32, Precision::I32}, Precision::I32) ||
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
return false;
}
break;
case Algorithm::EltwiseMaximum:
case Algorithm::EltwiseMinimum:
case Algorithm::EltwiseSquaredDifference:
if (!(checkPrecision({Precision::I16, Precision::I16}, Precision::I16) ||
checkPrecision({Precision::I32, Precision::I32}, Precision::I32) ||
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
return false;
}
break;
case Algorithm::EltwiseAdd:
case Algorithm::EltwiseSubtract:
if (!(checkPrecision({Precision::U8, Precision::U8}, Precision::U8) ||
checkPrecision({Precision::I16, Precision::I16}, Precision::I16) ||
checkPrecision({Precision::I32, Precision::I32}, Precision::I32) ||
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
return false;
}
break;
case Algorithm::EltwiseMultiply:
if (!(checkPrecision({Precision::U8, Precision::U8}, Precision::U8) ||
checkPrecision({Precision::U8, Precision::U8}, Precision::I16) ||
checkPrecision({Precision::U8, Precision::I16}, Precision::I16) ||
checkPrecision({Precision::I16, Precision::U8}, Precision::I16) ||
checkPrecision({Precision::I16, Precision::I16}, Precision::I16) ||
checkPrecision({Precision::FP16, Precision::FP16}, Precision::FP16) ||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::FP32))) {
return false;
}
break;
// ACL supports only U8 precision on output for comparison operations
case Algorithm::EltwiseEqual:
case Algorithm::EltwiseNotEqual:
case Algorithm::EltwiseGreater:
case Algorithm::EltwiseGreaterEqual:
case Algorithm::EltwiseLess:
case Algorithm::EltwiseLessEqual:
if (!(checkPrecision({Precision::U8, Precision::U8}, Precision::U8) ||
checkPrecision({Precision::I16, Precision::I16}, Precision::U8) ||
checkPrecision({Precision::I32, Precision::I32}, Precision::U8) ||
checkPrecision({Precision::FP16, Precision::FP16}, Precision::U8) ||
checkPrecision({Precision::FP32, Precision::FP32}, Precision::U8))) {
return false;
}
break;
default:
return false;
}

// ACL supports only U8 precision on output for comparison operations
if (one_of(eltwiseAttrs.algorithm, Algorithm::EltwiseEqual, Algorithm::EltwiseNotEqual, Algorithm::EltwiseGreater,
Algorithm::EltwiseGreaterEqual, Algorithm::EltwiseLess, Algorithm::EltwiseLessEqual)) {
if (dstDescs[0]->getPrecision() != InferenceEngine::Precision::U8) {
for (const auto & srcDesc : srcDescs) {
if (getAclDataLayoutByMemoryDesc(srcDesc) == arm_compute::DataLayout::UNKNOWN)
return false;
}
}
for (const auto &srcD : srcDescs) {
for (const auto &dstD : dstDescs) {
if ((srcD->getPrecision() != InferenceEngine::Precision::FP32 &&
srcD->getPrecision() != InferenceEngine::Precision::FP16) ||
srcD->getPrecision() != dstD->getPrecision())
return false;
}
}

for (int i = 0; i < srcDescs.size(); i++) {
if (getAclDataLayoutByMemoryDesc(srcDescs[i]) == arm_compute::DataLayout::UNKNOWN)
return false;
}
for (int i = 0; i < dstDescs.size(); i++) {
if (getAclDataLayoutByMemoryDesc(dstDescs[i]) == arm_compute::DataLayout::UNKNOWN)
for (const auto & dstDesc : dstDescs) {
if (getAclDataLayoutByMemoryDesc(dstDesc) == arm_compute::DataLayout::UNKNOWN)
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ std::vector<std::string> disabledTestPatterns() {
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckSecondaryPropertiesTest.*)");
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckWithSecondaryPropertiesDoubleTest.*)");
}
retVector.emplace_back(R"(smoke_Decomposition_(3|4)D/Mvn6LayerTest.CompareWithRefs.*)");
retVector.emplace_back(R"(smoke_AvgPool_ExplicitPad_CeilRounding/PoolingLayerTest.CompareWithRefs.*)");
retVector.emplace_back(R"(smoke_TestsDFT_(1|2|3|4)d/DFTLayerTest.CompareWithRefs.*)");
retVector.emplace_back(R"(MultipleLSTMCellTest/MultipleLSTMCellTest.CompareWithRefs.*)");
Expand Down

0 comments on commit d00731c

Please sign in to comment.