Skip to content

Commit

Permalink
[CPU] Enable dnnl matmul primitive executor for FullyConnected node (openvinotoolkit#25416)
Browse files Browse the repository at this point in the history

Enables running experiments with the dnnl matmul primitive, since the long-term goal is to completely replace the oneDNN inner_product primitive with matmul.

### Todo:
 - [x] disable the executor
 - [x] uncomment CheckPluginRelatedResults in tests

### Tickets:
 - *ticket-id*
  • Loading branch information
EgorDuplensky authored Aug 12, 2024
1 parent 47656c8 commit f3546c7
Show file tree
Hide file tree
Showing 13 changed files with 621 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "memory_desc/dnnl_memory_desc.h"
#include "nodes/executors/convolution_config.hpp"
#include "nodes/executors/dnnl/dnnl_aliases.hpp"
#include "nodes/executors/dnnl/dnnl_fullyconnected_primitive.hpp"
#include "nodes/executors/dnnl/dnnl_shape_agnostic_data.hpp"
#include "nodes/executors/executor.hpp"
#include "nodes/executors/fullyconnected_config.hpp"
Expand Down Expand Up @@ -90,12 +91,12 @@ static dnnl::convolution_forward::primitive_desc createDescriptorInternal(const

// @todo create general mapping from node configuration to backend configuration
static const std::map<memory::data_type, memory::data_type> weightsTypeByInputType{
// input data type weights data type
{memory::data_type::f32, memory::data_type::f32 },
{memory::data_type::f16, memory::data_type::f16 },
// input data type weights data type
{memory::data_type::f32, memory::data_type::f32},
{memory::data_type::f16, memory::data_type::f16},
{memory::data_type::bf16, memory::data_type::bf16},
{memory::data_type::u8, memory::data_type::s8 },
{memory::data_type::s8, memory::data_type::s8 },
{memory::data_type::u8, memory::data_type::s8},
{memory::data_type::s8, memory::data_type::s8},
};

// make a fake shape: OC, IC, 1
Expand Down Expand Up @@ -156,15 +157,8 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const ConvAttrs& attrs,
one_of(srcDesc->getPrecision(), ov::element::u8, ov::element::i8) && weiDesc->getPrecision() == ov::element::i8;
auto outputDataType = DnnlExtensionUtils::ElementTypeToDataType(dstDesc->getPrecision());

DnnlPostOpsComposer dnnlpoc(postOps,
context->getEngine(),
dims,
1,
isINT8,
1 << 0,
{},
attrs.withBias,
outputDataType);
DnnlPostOpsComposer
dnnlpoc(postOps, context->getEngine(), dims, 1, isINT8, 1 << 0, {}, attrs.withBias, outputDataType);

return dnnlpoc.compose();
}
Expand Down Expand Up @@ -210,6 +204,12 @@ std::shared_ptr<DnnlConvolutionPrimitive> DnnlConvolutionPrimitive::create(
return primitive;
}

// Builds a weight descriptor with transposed layout for the case when a 1x1
// Convolution is used as a FullyConnected executor. The layout handling is
// identical to the FullyConnected case, so this simply delegates to
// DnnlFCPrimitive::makeTransposedWeightDescriptor.
// srcDesc              - original weights descriptor
// dstDesc              - descriptor whose dims define the target shape
// weightsNonTransposed - when false, srcDesc is presumably returned unchanged
//                        (behavior defined by the delegate; confirm there)
DnnlMemoryDescPtr DnnlConvolutionPrimitive::makeTransposedWeightDescriptor(const DnnlMemoryDescPtr srcDesc,
                                                                           const DnnlMemoryDescPtr dstDesc,
                                                                           bool weightsNonTransposed) {
    return DnnlFCPrimitive::makeTransposedWeightDescriptor(srcDesc, dstDesc, weightsNonTransposed);
}

DnnlConvolutionPrimitive::DnnlConvolutionPrimitive(const Key& key,
const dnnl::engine& engine,
const std::vector<impl_desc_type>& implPriorities)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class DnnlConvolutionPrimitive {
return m_implType;
}

static DnnlMemoryDescPtr makeTransposedWeightDescriptor(const DnnlMemoryDescPtr srcDesc,
const DnnlMemoryDescPtr dstDesc,
bool weightsNonTransposed);

// create shape agnostic data using FC attributes (1x1 Convolution as FC executor)
static DnnlShapeAgnosticDataPtr createShapeAgnosticData(const FCAttrs& attrs,
const PostOps& postOps,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ class DnnlFCExecutor : public Executor {
if (currentPrimitive && currentPrimitive->weightsDesc()->isCompatible(*newPrimMemDesc))
return;

if (m_attrs.weightsNonTransposed)
originalMemDesc = utils::makeTransposedWeightDescriptor(originalMemDesc, newPrimMemDesc);
originalMemDesc = Primitive::makeTransposedWeightDescriptor(originalMemDesc, newPrimMemDesc, m_attrs.weightsNonTransposed);

const auto weiMemory = utils::prepareWeightsMemory(originalMemDesc, newPrimMemDesc, memory, m_context, true);
m_primArgs[DNNL_ARG_WEIGHTS] = weiMemory->getPrimitive();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <common/primitive_attr.hpp>
#include <common/primitive_desc_iface.hpp>
#include <common/primitive_iface.hpp>
#include <iostream>
#include <memory>
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_common.hpp>
Expand All @@ -24,6 +23,7 @@
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include "memory_desc/dnnl_memory_desc.h"
#include "nodes/executors/dnnl/dnnl_shape_agnostic_data.hpp"
#include "nodes/executors/dnnl/dnnl_utils.hpp"
#include "nodes/executors/executor.hpp"
#include "nodes/executors/fullyconnected_config.hpp"
#include "nodes/executors/memory_arguments.hpp"
Expand Down Expand Up @@ -51,7 +51,6 @@ size_t DnnlFCPrimitive::Key::hash() const {

seed = hash_combine(seed, get_attr_hash(*attr.get()));
seed = hash_combine(seed, sparseWeights);
seed = hash_combine(seed, transposedWeights);
seed = hash_combine(seed, modelType);

return seed;
Expand All @@ -73,8 +72,9 @@ bool DnnlFCPrimitive::Key::operator==(const Key& rhs) const {
result = result && dst && rhs.dst && dst->getDnnlDesc() == rhs.dst->getDnnlDesc();
}

result = result && *attr.get() == *rhs.attr.get() && sparseWeights == rhs.sparseWeights &&
transposedWeights == rhs.transposedWeights && modelType == rhs.modelType;
result = result && *attr.get() == *rhs.attr.get() &&
sparseWeights == rhs.sparseWeights &&
modelType == rhs.modelType;

return result;
}
Expand All @@ -88,16 +88,13 @@ std::shared_ptr<DnnlFCPrimitive> DnnlFCPrimitive::create(const MemoryArgs& memor
const auto& biaDesc = MemoryDescUtils::convertToDnnlMemoryDesc(memory.at(ARG_BIAS)->getDescPtr());
const auto& dstDesc = MemoryDescUtils::convertToDnnlMemoryDesc(memory.at(ARG_DST)->getDescPtr());

Key dnnlFCKey{
srcDesc,
weiDesc,
biaDesc,
dstDesc,
shapeAgnosticData->primAttrs.attr,
attrs.sparseWeights,
attrs.weightsNonTransposed,
attrs.modelType
};
Key dnnlFCKey{srcDesc,
weiDesc,
biaDesc,
dstDesc,
shapeAgnosticData->primAttrs.attr,
attrs.sparseWeights,
attrs.modelType};

auto builder = [&context](const Key& dnnlKey) {
return std::make_shared<DnnlFCPrimitive>(dnnlKey, context->getEngine(), context->getImplPriorities());
Expand All @@ -111,6 +108,20 @@ std::shared_ptr<DnnlFCPrimitive> DnnlFCPrimitive::create(const MemoryArgs& memor
return primitive;
}

// Produces a weights descriptor describing the original buffer in transposed
// ("ba") order and reshaped to the destination dims.
// srcDesc              - descriptor of the weights as provided by the model
// dstDesc              - descriptor whose dnnl dims define the resulting shape
// weightsNonTransposed - when false the weights already match the expected
//                        layout and srcDesc is returned as-is
DnnlMemoryDescPtr DnnlFCPrimitive::makeTransposedWeightDescriptor(const DnnlMemoryDescPtr srcDesc,
                                                                  const DnnlMemoryDescPtr dstDesc,
                                                                  bool weightsNonTransposed) {
    if (!weightsNonTransposed) {
        // Nothing to do: keep the original descriptor untouched.
        return srcDesc;
    }

    const auto& originalDesc = srcDesc->getDnnlDesc();
    // Reinterpret the same dims/data type with swapped (ba) strides, then
    // reshape the result to the dims expected by the primitive.
    const dnnl::memory::desc swappedDesc{originalDesc.get_dims(),
                                         originalDesc.get_data_type(),
                                         dnnl::memory::format_tag::ba};

    return DnnlExtensionUtils::makeDescriptor(swappedDesc.reshape(dstDesc->getDnnlDesc().get_dims()));
}

bool DnnlFCPrimitive::useWeightsDecompressionImpl(const ov::element::Type inputType,
const ov::element::Type weightsType,
const ov::intel_cpu::Config::ModelType modelType) {
Expand All @@ -129,8 +140,12 @@ bool DnnlFCPrimitive::useWeightsDecompressionImpl(const ov::element::Type inputT
return false;
}

bool DnnlFCPrimitive::useDynamicQuantizationImpl(size_t dqGroupSize, const MemoryDescPtr srcDesc, const MemoryDescPtr weightsDesc,
MemoryCPtr scalesPtr, MemoryCPtr zpPtr, bool needTranspose) {
bool DnnlFCPrimitive::useDynamicQuantizationImpl(size_t dqGroupSize,
const MemoryDescPtr srcDesc,
const MemoryDescPtr weightsDesc,
MemoryCPtr scalesPtr,
MemoryCPtr zpPtr,
bool needTranspose) {
if (dqGroupSize == 0)
return false;

Expand Down Expand Up @@ -232,7 +247,9 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const FCAttrs& attrs,
uint8_t zp_value = (wei_precision == ov::element::i8) ? 128 : 8;
DnnlBlockedMemoryDesc zpMemoryDesc(ov::element::u8, Shape({1}));
auto decompressionSubtractPtr = std::make_shared<Memory>(context->getEngine(), zpMemoryDesc, &zp_value);
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtractPtr, !attrs.weightsNonTransposed, ov::element::u8);
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtractPtr,
!attrs.weightsNonTransposed,
ov::element::u8);
}
dnnlpoc.setDynamicQuantizationParams(attrs.dynamicQuantizationGroupSize);
}
Expand Down Expand Up @@ -364,10 +381,15 @@ DnnlShapeAgnosticDataPtr DnnlFCPrimitive::createShapeAgnosticData(const FCAttrs&
const auto& biasDesc = memory.at(ARG_BIAS)->getDescPtr();
auto dstDesc = memory.at(ARG_DST)->getDescPtr();

const auto useWeightsDecompression = useWeightsDecompressionImpl(srcDesc->getPrecision(), weiDesc->getPrecision(), attrs.modelType);
const auto useDynamicQuantization = useWeightsDecompression &&
useDynamicQuantizationImpl(attrs.dynamicQuantizationGroupSize, srcDesc, weiDesc,
attrs.decompressionMultiplyPtr, attrs.decompressionSubtractPtr, !attrs.weightsNonTransposed);
const auto useWeightsDecompression =
useWeightsDecompressionImpl(srcDesc->getPrecision(), weiDesc->getPrecision(), attrs.modelType);
const auto useDynamicQuantization =
useWeightsDecompression && useDynamicQuantizationImpl(attrs.dynamicQuantizationGroupSize,
srcDesc,
weiDesc,
attrs.decompressionMultiplyPtr,
attrs.decompressionSubtractPtr,
!attrs.weightsNonTransposed);

const auto postOpData = createPrimitiveAttrs(attrs, postOps, memory, context, useDynamicQuantization);

Expand Down Expand Up @@ -402,8 +424,8 @@ DnnlShapeAgnosticDataPtr DnnlFCPrimitive::createShapeAgnosticData(const FCAttrs&

const auto weightsDesc = DnnlExtensionUtils::makeDescriptor(primDesc.weights_desc());
auto originalWeightsDesc = MemoryDescUtils::convertToDnnlMemoryDesc(weiDesc);
if (attrs.weightsNonTransposed)
originalWeightsDesc = utils::makeTransposedWeightDescriptor(originalWeightsDesc, weightsDesc);

originalWeightsDesc = makeTransposedWeightDescriptor(originalWeightsDesc, weightsDesc, attrs.weightsNonTransposed);

// ignore the result since we just need to put the packed weights into the cache
(void)utils::prepareWeightsMemory(originalWeightsDesc,
Expand All @@ -429,15 +451,16 @@ DnnlFCPrimitive::DnnlFCPrimitive(const Key& key,
const dnnl::engine& engine,
const std::vector<impl_desc_type>& implPriorities)
: m_stream(dnnl::stream(engine)),
m_primDesc(createPrimitiveDesc(key.src->getDnnlDesc(),
key.wei->getDnnlDesc(),
key.bias->getDnnlDesc(),
key.dst->getDnnlDesc(),
key.attr,
engine,
implPriorities,
key.sparseWeights,
useWeightsDecompressionImpl(key.src->getPrecision(), key.wei->getPrecision(), key.modelType))),
m_primDesc(createPrimitiveDesc(
key.src->getDnnlDesc(),
key.wei->getDnnlDesc(),
key.bias->getDnnlDesc(),
key.dst->getDnnlDesc(),
key.attr,
engine,
implPriorities,
key.sparseWeights,
useWeightsDecompressionImpl(key.src->getPrecision(), key.wei->getPrecision(), key.modelType))),
m_implType(implTypeFromPrimDesc(m_primDesc)),
m_srcDesc(DnnlExtensionUtils::makeDescriptor(m_primDesc.src_desc())),
m_weiDesc(DnnlExtensionUtils::makeDescriptor(m_primDesc.weights_desc())),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "cpu_memory.h"
#include "memory_desc/dnnl_memory_desc.h"
#include "nodes/executors/dnnl/dnnl_shape_agnostic_data.hpp"
#include "nodes/executors/dnnl/dnnl_utils.hpp"
#include "nodes/executors/executor.hpp"
#include "nodes/executors/fullyconnected_config.hpp"

Expand All @@ -25,7 +24,6 @@ class DnnlFCPrimitive {
DnnlMemoryDescCPtr dst;
dnnl::primitive_attr attr;
bool sparseWeights;
bool transposedWeights;
Config::ModelType modelType;

size_t hash() const;
Expand Down Expand Up @@ -63,16 +61,26 @@ class DnnlFCPrimitive {
const ExecutorContext::CPtr context,
const bool cacheWeights);

static bool useWeightsDecompressionImpl(const ov::element::Type inputType, const ov::element::Type weightsType, const Config::ModelType modelType);
static bool useWeightsDecompressionImpl(const ov::element::Type inputType,
const ov::element::Type weightsType,
const Config::ModelType modelType);

static DnnlMemoryDescPtr makeTransposedWeightDescriptor(const DnnlMemoryDescPtr srcDesc,
const DnnlMemoryDescPtr dstDesc,
bool weightsNonTransposed);

static std::shared_ptr<DnnlFCPrimitive> create(const MemoryArgs& memory,
const FCAttrs& attrs,
const ExecutorContext::CPtr context,
const DnnlShapeAgnosticDataPtr& shapeAgnosticData);

private:
static bool useDynamicQuantizationImpl(size_t dqGroupSize, const MemoryDescPtr srcDesc, const MemoryDescPtr weightsDesc,
MemoryCPtr scalesPtr, MemoryCPtr zpPtr, bool needTranspose);
static bool useDynamicQuantizationImpl(size_t dqGroupSize,
const MemoryDescPtr srcDesc,
const MemoryDescPtr weightsDesc,
MemoryCPtr scalesPtr,
MemoryCPtr zpPtr,
bool needTranspose);

dnnl::stream m_stream;
dnnl::primitive_desc m_primDesc;
Expand Down
Loading

0 comments on commit f3546c7

Please sign in to comment.