[ARM CPU] Avg Pooling, ROI Pooling fix for fp16 precision (openvinoto…
allnes authored Oct 27, 2023
1 parent 4e41678 commit 9decbb5
Showing 3 changed files with 16 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/plugins/intel_cpu/src/nodes/pooling.cpp
@@ -322,15 +322,15 @@ void Pooling::getSupportedDescriptors() {
 
     // WA: LPT transformation has WA which allows average pooling has I8/U8 output precision instead of FP32,
     // so we explicitly set output precision as FP32
-    if (outputPrecision != Precision::I8 && inputPrecision != Precision::BF16) {
+    if (!one_of(outputPrecision, Precision::I8, Precision::BF16, Precision::FP16)) {
        if (getAlgorithm() == Algorithm::PoolingMax) {
            // oneDNN supports only equal precisions for input and output
            outputPrecision = inputPrecision;
        } else if (getAlgorithm() == Algorithm::PoolingAvg) {
            outputPrecision = Precision::FP32;
        }
    }
-    if (inputPrecision == Precision::BF16) {
+    if (one_of(inputPrecision, Precision::BF16, Precision::FP16)) {
        outputPrecision = inputPrecision;
    }
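The rewritten conditions lean on OpenVINO's one_of helper, which replaces chains of ==/!= comparisons with a single variadic membership test. A minimal, self-contained sketch of such a helper (illustrative only; the simplified Precision enum here is a stand-in for OpenVINO's real type):

// one_of.cpp -- minimal sketch of a variadic membership test (C++17 fold).
#include <iostream>

template <typename T, typename... Args>
bool one_of(T value, Args... candidates) {
    return ((value == candidates) || ...);  // true if value equals any candidate
}

enum class Precision { FP32, FP16, BF16, I8, U8 };

int main() {
    Precision in = Precision::FP16;
    // Mirrors the new check: bf16 and f16 are now handled as one group.
    std::cout << std::boolalpha
              << one_of(in, Precision::BF16, Precision::FP16) << '\n';  // true
}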
@@ -351,7 +351,7 @@ void Pooling::getSupportedDescriptors() {
 
     if (inputPrecision == Precision::I8 || inputPrecision == Precision::U8) {
         // We have to extend i8i8_pooling_fwd_t from oneDNN to support BF16 output data type
-        if (outputDataType == memory::data_type::bf16)
+        if (one_of(outputDataType, memory::data_type::bf16, memory::data_type::f16))
            outputDataType = memory::data_type::f32;
        // i8 layers supports only ndhwc and nhwc layouts
        const auto in_candidate = std::make_shared<DnnlBlockedMemoryDesc>(parentShape, inputDataType, inputRank == 3 ?
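Taken together, the two pooling hunks make FP16 follow the same rules BF16 already followed: half-precision inputs keep their precision, and the i8 pooling path clamps any 16-bit float output request back to f32. A compact restatement of the decision logic, reusing the hypothetical Precision enum and one_of helper from the sketch above (not the node's actual interface):

// Illustrative restatement of the new pooling output-precision rules.
Precision pooling_output_precision(Precision in, Precision out, bool is_avg) {
    // Unless the requested output is already i8/bf16/f16, pick a safe default:
    // avg pooling emits f32; max pooling mirrors its input, since oneDNN
    // requires equal input/output precisions for max pooling.
    if (!one_of(out, Precision::I8, Precision::BF16, Precision::FP16))
        out = is_avg ? Precision::FP32 : in;
    // Half-precision inputs now propagate to the output unchanged.
    if (one_of(in, Precision::BF16, Precision::FP16))
        out = in;
    return out;
}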
21 changes: 13 additions & 8 deletions src/plugins/intel_cpu/src/nodes/roi_pooling.cpp
@@ -434,13 +434,6 @@ void ROIPooling::initSupportedPrimitiveDescriptors() {
     if (!supportedPrimitiveDescriptors.empty())
         return;
 
-    refParams.src_prc = getOriginalInputPrecisionAtPort(0);
-
-    if (!mayiuse(avx512_core)) {
-        if (refParams.src_prc == Precision::BF16)
-            refParams.src_prc = Precision::FP32;
-    }
-
     auto format = mayiuse(avx512_core) ? LayoutType::nCsp16c : LayoutType::nCsp8c;
     impl_desc_type impl_type;
     if (mayiuse(cpu::x64::avx512_core)) {
@@ -453,6 +446,17 @@ void ROIPooling::initSupportedPrimitiveDescriptors() {
         impl_type = impl_desc_type::ref;
     }
 
+    refParams.src_prc = getOriginalInputPrecisionAtPort(0);
+
+    if (!mayiuse(avx512_core)) {
+        if (refParams.src_prc == Precision::BF16)
+            refParams.src_prc = Precision::FP32;
+    }
+
+    if (impl_type != impl_desc_type::ref && refParams.src_prc == Precision::FP16) {
+        refParams.src_prc = Precision::FP32;
+    }
+
     addSupportedPrimDesc({{format, refParams.src_prc},
                           {LayoutType::ncsp, refParams.src_prc}},
                          {{format, refParams.src_prc}},
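The ROI Pooling hunks are mostly a reordering: refParams.src_prc used to be chosen before impl_type existed, so the code could not express "fp16 is only supported on the reference path". Moving the precision fallback after implementation selection makes that dependency explicit. A hedged sketch of the resulting order, reusing the Precision enum from the first sketch and stand-in names (ImplType, has_avx512_core) rather than OpenVINO's real API:

// Sketch: choose the implementation first, then clamp the source precision
// to what that implementation can actually run.
enum class ImplType { jit_avx512, jit_avx2, jit_sse42, ref };

Precision select_src_precision(Precision requested, ImplType impl, bool has_avx512_core) {
    if (!has_avx512_core && requested == Precision::BF16)
        return Precision::FP32;  // bf16 kernels require avx512_core
    if (impl != ImplType::ref && requested == Precision::FP16)
        return Precision::FP32;  // only the reference executor handles f16
    return requested;
}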
@@ -826,7 +830,8 @@ std::shared_ptr<ROIPooling::ROIPoolingExecutor> ROIPooling::ROIPoolingExecutor::
 
     OV_SWITCH(intel_cpu, ROIPoolingExecutorCreation, ctx, jpp.src_prc,
               OV_CASE(Precision::FP32, float),
-              OV_CASE(Precision::BF16, bfloat16_t))
+              OV_CASE(Precision::BF16, bfloat16_t),
+              OV_CASE(Precision::FP16, float16_t))
 
     return ctx.executor;
 }
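OV_SWITCH/OV_CASE expand into a chain of runtime precision checks, each instantiating the executor template with the matching C++ element type; the commit simply adds an FP16 case. A simplified stand-in for that dispatch, reusing the Precision enum from the first sketch (the placeholder bfloat16_t/float16_t structs below are hypothetical; the real types come from the plugin and oneDNN):

#include <cstdint>
#include <memory>
#include <stdexcept>

struct bfloat16_t { uint16_t raw; };  // placeholder 16-bit storage types
struct float16_t  { uint16_t raw; };

struct Executor { virtual ~Executor() = default; };
template <typename T> struct TypedExecutor : Executor { /* kernels for T */ };

// Runtime tag -> compile-time type, mirroring what OV_SWITCH/OV_CASE do.
std::unique_ptr<Executor> make_executor(Precision prc) {
    switch (prc) {
    case Precision::FP32: return std::make_unique<TypedExecutor<float>>();
    case Precision::BF16: return std::make_unique<TypedExecutor<bfloat16_t>>();
    case Precision::FP16: return std::make_unique<TypedExecutor<float16_t>>();  // added by this commit
    default: throw std::runtime_error("unsupported precision");
    }
}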
6 changes: 0 additions & 6 deletions src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp
@@ -229,15 +229,9 @@ std::vector<std::string> disabledTestPatterns() {
 
 #if defined(OV_CPU_ARM_ENABLE_FP16)
     // Issue: 123019
-    retVector.emplace_back(R"(smoke_AvgPool_ExplicitPad_CeilRounding.*modelType=f16.*)");
-    retVector.emplace_back(R"(smoke_AvgPool_ExplicitPad_FloorRounding_5Dinput/PoolingLayerTest.*modelType=f16.*)");
-    retVector.emplace_back(R"(smoke_AvgPool_SameUpperPad_FloorRounding_5Dinput/PoolingLayerTest.*modelType=f16.*)");
-    retVector.emplace_back(R"(smoke_AvgPool_SameLowerPad_CeilRounding_5Dinput/PoolingLayerTest.*modelType=f16.*)");
     retVector.emplace_back(R"(smoke_CompareWithRefs_Mvn.*INFERENCE_PRECISION_HINT=f16.*)");
     retVector.emplace_back(R"(smoke_staticShapes4D.*INFERENCE_PRECISION_HINT=f16.*)");
     retVector.emplace_back(R"(smoke_dynamicShapes4D.*INFERENCE_PRECISION_HINT=f16.*)");
-    // Issue: 123064
-    retVector.emplace_back(R"(smoke_TestsROIPooling_.*/ROIPoolingLayerTest.*modelType=f16.*)");
 #endif
 
 #endif
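With the ARM fp16 fixes in place, the AvgPool and ROIPooling skip patterns can be dropped; the remaining entries still guard the unrelated Issue 123019 cases. The strings in disabledTestPatterns() are regular expressions matched against full test names, roughly as in this sketch (not the test framework's actual code):

#include <regex>
#include <string>
#include <vector>

// Skip a test if its full name matches any disabled pattern.
bool should_skip(const std::string& test_name, const std::vector<std::string>& patterns) {
    for (const auto& p : patterns)
        if (std::regex_search(test_name, std::regex(p)))
            return true;
    return false;
}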