[CPU][ARM] Reorder FullyConnected weights in scope of compile_model (o…
EgorDuplensky authored and alvoron committed Nov 6, 2023
1 parent 0318012 commit b7ac034
Showing 2 changed files with 163 additions and 121 deletions.
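The change factors primitive-descriptor creation into a reusable helper and, on ACL builds, calls it during compile_model with dummy shapes so that the constant FullyConnected weights are reordered and cached before the first inference. The underlying oneDNN pattern is: describe the weights with format_tag::any, query the layout the selected implementation prefers, and reorder the plain weights into it once. A minimal, self-contained sketch of that pattern (plain oneDNN with made-up dimensions; not the plugin code in the diff below):

#include <oneapi/dnnl/dnnl.hpp>

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    // Dummy dimensions stand in for the real runtime shapes.
    const memory::dim N = 8, IC = 64, OC = 32;
    memory::desc src_md({N, IC}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc wei_any({OC, IC}, memory::data_type::f32, memory::format_tag::any);
    memory::desc dst_md({N, OC}, memory::data_type::f32, memory::format_tag::ab);

    // Let oneDNN pick the weights layout preferred by the chosen implementation.
    auto pd = inner_product_forward::primitive_desc(
        eng, prop_kind::forward_inference, src_md, wei_any, dst_md);

    // Plain-layout weights as they come from the model, and a buffer in the
    // implementation-preferred layout.
    memory::desc plain_wei_md({OC, IC}, memory::data_type::f32, memory::format_tag::ab);
    memory plain_wei(plain_wei_md, eng);
    memory packed_wei(pd.weights_desc(), eng);

    // One-time reorder; in the plugin the result goes into the weights cache so
    // prepareParams() and the first inference do not pay for it.
    reorder(plain_wei, packed_wei).execute(strm, plain_wei, packed_wei);
    strm.wait();
    return 0;
}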
278 changes: 159 additions & 119 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
@@ -9,6 +9,7 @@
#include "fake_quantize.h"
#include "input.h"
#include "memory_desc/blocked_memory_desc.h"
#include "memory_desc/dnnl_memory_desc.h"
#include "reorder.h"
#include "transformations/cpu_opset/common/op/fully_connected.hpp"
#include "ngraph/opsets/opset1.hpp"
@@ -329,6 +330,156 @@ void FullyConnected::prepackMLASWeight() {
}
#endif

static dnnl::convolution_forward::primitive_desc
createDescriptorInternalForConv(DnnlMemoryDescCPtr inputDescPtr,
DnnlMemoryDescCPtr weightDescPtr,
DnnlMemoryDescCPtr biasDescPtr,
DnnlMemoryDescCPtr outputDescPtr,
const dnnl::primitive_attr& attr,
const dnnl::engine& engine) {
const dnnl::memory::desc &inputDesc = inputDescPtr->getDnnlDesc();
const dnnl::memory::desc &outputDesc = outputDescPtr->getDnnlDesc();
const dnnl::memory::desc &weightDesc = weightDescPtr->getDnnlDesc();
// make a fake shape: N, IC, W
auto inDims = inputDesc.get_dims();
dnnl::memory::dims normalizedInDims;
if (inDims.size() == 3) {
normalizedInDims = {inDims[0], inDims[2], inDims[1]};
} else if (inDims.size() == 2) {
normalizedInDims = {dnnl::memory::dim{1}, inDims[1], inDims[0]};
}
auto convInDesc = dnnl::memory::desc(normalizedInDims, inputDesc.get_data_type(), memory::format_tag::nwc);

// make a fake shape: N, OC, W
const auto& outDims = outputDesc.get_dims();
dnnl::memory::dims normalizedOutDims;
if (outDims.size() == 3) {
normalizedOutDims = { outDims[0], outDims[2], outDims[1]};
} else if (outDims.size() == 2) {
normalizedOutDims = { dnnl::memory::dim{1}, outDims[1], outDims[0]};
}
auto convOutDesc = dnnl::memory::desc(normalizedOutDims, outputDesc.get_data_type(), memory::format_tag::nwc);

// make a fake shape: OC, IC, 1
auto weightDims = weightDesc.get_dims();
dnnl::memory::dims normalizedWeightDims;
normalizedWeightDims = {static_cast<dnnl::memory::dim>(weightDims[0]),
static_cast<dnnl::memory::dim>(weightDims[1]),
dnnl::memory::dim{1}};
auto convWeightDescAny = dnnl::memory::desc(normalizedWeightDims, weightDesc.get_data_type(), dnnl::memory::format_tag::any);

if (biasDescPtr) {
return dnnl::convolution_forward::primitive_desc(
engine,
prop_kind::forward_inference,
dnnl::algorithm::convolution_direct,
convInDesc, convWeightDescAny, biasDescPtr->getDnnlDesc(), convOutDesc,
dnnl::memory::dims{1}, // stride
dnnl::memory::dims{0}, // dilation
dnnl::memory::dims{0}, // paddingL
dnnl::memory::dims{0}, // paddingR
attr);
} else {
return dnnl::convolution_forward::primitive_desc(
engine,
prop_kind::forward_inference, dnnl::algorithm::convolution_direct,
convInDesc, convWeightDescAny, convOutDesc,
dnnl::memory::dims{1}, // stride
dnnl::memory::dims{0}, // dilation
dnnl::memory::dims{0}, // paddingL
dnnl::memory::dims{0}, // paddingR
attr);
}
}

static dnnl::primitive_desc createPrimitiveDesc(const FCKey& key, const dnnl::engine& engine) {
// use conv1x1 primitive for computation
if (key.useConv1x1) {
auto prim_desc = createDescriptorInternalForConv(key.inp0, key.inp1, key.bias, key.out, key.attr, engine);
const bool found = DnnlExtensionUtils::find_implementation(prim_desc, brgconv_avx512_1x1);

if (found)
return std::move(prim_desc);
}

// fallback to normal inner product primitive
auto inDesc = key.inp0->getDnnlDesc();
const auto& inDims = inDesc.get_dims(); // @TODO query + copy might be slow
if (inDims.size() == 3) {
auto normalizedInDims = {inDims[0] * inDims[1], inDims[2]};
inDesc = inDesc.reshape(normalizedInDims);
}
auto outDesc = key.out->getDnnlDesc();
const auto& outDims = outDesc.get_dims(); // @TODO query + copy might be slow

if (outDims.size() == 3) {
auto normalizedOutDims = { outDims[0] * outDims[1], outDims[2] };
outDesc = outDesc.reshape(normalizedOutDims);
}
auto wghDescAny = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.inp1->getShape().getStaticDims()),
key.inp1->getDataType(), memory::format_tag::any);
dnnl::inner_product_forward::primitive_desc prim_desc;
if (key.bias) {
prim_desc = dnnl::inner_product_forward::primitive_desc(
engine,
dnnl::prop_kind::forward_inference,
inDesc,
wghDescAny,
key.bias->getDnnlDesc(),
outDesc,
key.attr);
} else {
prim_desc = dnnl::inner_product_forward::primitive_desc(
engine,
dnnl::prop_kind::forward_inference,
inDesc,
wghDescAny,
outDesc,
key.attr);
}
auto first_desc = dnnl::inner_product_forward::primitive_desc(prim_desc.get());
const bool found = DnnlExtensionUtils::find_implementation(prim_desc, key.implType);

if (found)
return std::move(prim_desc);

return std::move(first_desc);
}

#if defined(OV_CPU_WITH_ACL)
/**
* Do not wait till prepareParams to reorder the weights
* Do it in scope of compile_model using dummy shapes
*/
void FullyConnected::prepareWeightsUsingDummyShape() {
NodeDesc *selected_pd = getSelectedPrimitiveDescriptor();
if (selected_pd == nullptr)
IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << ".";

auto inDesc = MemoryDescUtils::convertToDnnlMemoryDesc(MemoryDescUtils::makeDummyDesc(*getBaseMemDescAtInputPort(DATA_ID)));
auto weightDesc = MemoryDescUtils::convertToDnnlMemoryDesc(weightDescIP);
auto biasDesc = withBiases ? MemoryDescUtils::convertToDnnlMemoryDesc(getBaseMemDescAtInputPort(BIAS_ID)) : nullptr;
auto outDesc = MemoryDescUtils::convertToDnnlMemoryDesc(MemoryDescUtils::makeDummyDesc(*getBaseMemDescAtOutputPort(0)));

const FCKey key = {inDesc,
weightDesc,
biasDesc,
outDesc,
attr,
selected_pd->getImplementationType(),
false};

auto prim_desc = createPrimitiveDesc(key, getEngine());
auto weights = DnnlExtensionUtils::makeDescriptor(prim_desc.weights_desc());
// ignore the result since we just need to put the reordered weights into the cache
if (weightsNonTransposed) {
(void) prepareWeightMemory(weights, makeTransposedWeightDescriptor(weights));
} else {
(void) prepareWeightMemory(weights);
}
}
#endif

void FullyConnected::createPrimitive() {
#ifdef OV_CPU_WITH_MLAS
if (useMlas) {
@@ -341,6 +492,9 @@ void FullyConnected::createPrimitive() {
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
Node::createPrimitive();
appendPostOpArgs(attr, primArgs, postOpsArgs);
#if defined(OV_CPU_WITH_ACL)
prepareWeightsUsingDummyShape();
#endif
}

void FullyConnected::prepareParams() {
@@ -393,60 +547,7 @@ void FullyConnected::prepareParams() {
auto& engine = getEngine();

auto builder = [&engine](const FCKey& key) -> executorPtr {
// use conv1x1 primitive for computation
if (key.useConv1x1) {
auto prim_desc = createDescriptorInternalForConv(key.inp0, key.inp1, key.bias, key.out, key.attr, engine);
const bool found = DnnlExtensionUtils::find_implementation(prim_desc, brgconv_avx512_1x1);

if (found)
return std::make_shared<DnnlExecutor>(prim_desc);
}

// fallback to normal inner product primitive
auto inDesc = key.inp0->getDnnlDesc();
const auto& inDims = inDesc.get_dims(); // @TODO query + copy might be slow
if (inDims.size() == 3) {
auto normalizedInDims = {inDims[0] * inDims[1], inDims[2]};
inDesc = inDesc.reshape(normalizedInDims);
}
auto outDesc = key.out->getDnnlDesc();
const auto& outDims = outDesc.get_dims(); // @TODO query + copy might be slow

if (outDims.size() == 3) {
auto normalizedOutDims = { outDims[0] * outDims[1], outDims[2] };
outDesc = outDesc.reshape(normalizedOutDims);
}
auto wghDescAny = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.inp1->getShape().getStaticDims()),
key.inp1->getDataType(), memory::format_tag::any);
dnnl::inner_product_forward::primitive_desc prim_desc;
if (key.bias) {
prim_desc = dnnl::inner_product_forward::primitive_desc(
engine,
dnnl::prop_kind::forward_inference,
inDesc,
wghDescAny,
key.bias->getDnnlDesc(),
outDesc,
key.attr);
} else {
prim_desc = dnnl::inner_product_forward::primitive_desc(
engine,
dnnl::prop_kind::forward_inference,
inDesc,
wghDescAny,
outDesc,
key.attr);
}
auto first_desc = dnnl::inner_product_forward::primitive_desc(prim_desc.get());
const bool found = DnnlExtensionUtils::find_implementation(prim_desc, key.implType);

if (found)
return std::make_shared<DnnlExecutor>(prim_desc);

// For dynamic shape, the expected implement type kernel can support with dummy shape but
// not the run time inference shape. In this case, the implementation type will be
// ignored and the first available primitive descriptor will be chosen
return std::make_shared<DnnlExecutor>(first_desc);
return std::make_shared<DnnlExecutor>(createPrimitiveDesc(key, engine));
};

auto cache = context->getParamsCache();
@@ -487,7 +588,8 @@ void FullyConnected::prepareParams() {
}
#endif
if (weightsNonTransposed) {
primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(execPtr->getWeightDesc(), makeTransposedWeightDescriptor())->getPrimitive();
primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(execPtr->getWeightDesc(),
makeTransposedWeightDescriptor(execPtr->getWeightDesc()))->getPrimitive();
} else {
primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(execPtr->getWeightDesc())->getPrimitive();
}
@@ -912,68 +1014,6 @@ void FullyConnected::initOptimalPrimitiveDescriptor() {
selectedPD->setConfig(config);
}

dnnl::convolution_forward::primitive_desc
FullyConnected::createDescriptorInternalForConv(DnnlMemoryDescCPtr inputDescPtr,
DnnlMemoryDescCPtr weightDescPtr,
DnnlMemoryDescCPtr biasDescPtr,
DnnlMemoryDescCPtr outputDescPtr,
const dnnl::primitive_attr& attr,
const dnnl::engine& engine) {
const dnnl::memory::desc &inputDesc = inputDescPtr->getDnnlDesc();
const dnnl::memory::desc &outputDesc = outputDescPtr->getDnnlDesc();
const dnnl::memory::desc &weightDesc = weightDescPtr->getDnnlDesc();
// make a fake shape: N, IC, W
auto inDims = inputDesc.get_dims();
dnnl::memory::dims normalizedInDims;
if (inDims.size() == 3) {
normalizedInDims = {inDims[0], inDims[2], inDims[1]};
} else if (inDims.size() == 2) {
normalizedInDims = {dnnl::memory::dim{1}, inDims[1], inDims[0]};
}
auto convInDesc = dnnl::memory::desc(normalizedInDims, inputDesc.get_data_type(), memory::format_tag::nwc);

// make a fake shape: N, OC, W
const auto& outDims = outputDesc.get_dims();
dnnl::memory::dims normalizedOutDims;
if (outDims.size() == 3) {
normalizedOutDims = { outDims[0], outDims[2], outDims[1]};
} else if (outDims.size() == 2) {
normalizedOutDims = { dnnl::memory::dim{1}, outDims[1], outDims[0]};
}
auto convOutDesc = dnnl::memory::desc(normalizedOutDims, outputDesc.get_data_type(), memory::format_tag::nwc);

// make a fake shape: OC, IC, 1
auto weightDims = weightDesc.get_dims();
dnnl::memory::dims normalizedWeightDims;
normalizedWeightDims = {static_cast<dnnl::memory::dim>(weightDims[0]),
static_cast<dnnl::memory::dim>(weightDims[1]),
dnnl::memory::dim{1}};
auto convWeightDescAny = dnnl::memory::desc(normalizedWeightDims, weightDesc.get_data_type(), dnnl::memory::format_tag::any);

if (biasDescPtr) {
return dnnl::convolution_forward::primitive_desc(
engine,
prop_kind::forward_inference,
dnnl::algorithm::convolution_direct,
convInDesc, convWeightDescAny, biasDescPtr->getDnnlDesc(), convOutDesc,
dnnl::memory::dims{1}, // stride
dnnl::memory::dims{0}, // dilation
dnnl::memory::dims{0}, // paddingL
dnnl::memory::dims{0}, // paddingR
attr);
} else {
return dnnl::convolution_forward::primitive_desc(
engine,
prop_kind::forward_inference, dnnl::algorithm::convolution_direct,
convInDesc, convWeightDescAny, convOutDesc,
dnnl::memory::dims{1}, // stride
dnnl::memory::dims{0}, // dilation
dnnl::memory::dims{0}, // paddingL
dnnl::memory::dims{0}, // paddingR
attr);
}
}

bool FullyConnected::canBeExecutedInConv1x1() const {
bool retVal = false;
const auto inRank = getInputShapeAtPort(DATA_ID).getRank();
@@ -1106,7 +1146,7 @@ void FullyConnected::fuseDecompressionConstant(const NodePtr& constData, std::ve
elementsCount);
}

DnnlMemoryDescPtr FullyConnected::makeTransposedWeightDescriptor() {
DnnlMemoryDescPtr FullyConnected::makeTransposedWeightDescriptor(DnnlMemoryDescPtr desc) {
if (!getParentEdgeAt(1)->getParent()->isConstant())
IE_THROW() << "Weight input is not const for node " << getName() << ".";
auto edgeMem = getParentEdgeAt(1)->getMemoryPtr();
@@ -1116,7 +1156,7 @@ DnnlMemoryDescPtr FullyConnected::makeTransposedWeightDescriptor() {
auto constDnnlMemOutDesc = edgeMem->getDescWithType<DnnlMemoryDesc>();
auto weightSrcDesc = constDnnlMemOutDesc->getDnnlDesc();
weightSrcDesc = {weightSrcDesc.get_dims(), weightSrcDesc.get_data_type(), memory::format_tag::ba};
weightSrcDesc = weightSrcDesc.reshape(execPtr->getWeightDesc()->getDnnlDesc().get_dims());
weightSrcDesc = weightSrcDesc.reshape(desc->getDnnlDesc().get_dims());

return DnnlExtensionUtils::makeDescriptor(weightSrcDesc);
}
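makeTransposedWeightDescriptor() now receives the target weight descriptor as an argument instead of reading it from execPtr, so it can also be used by prepareWeightsUsingDummyShape(), where no executor exists yet. The descriptor manipulation itself is small: view the constant weights as transposed ("ba") and reshape that view to the dims expected by the selected primitive. A standalone sketch of just that step (plain oneDNN descriptors, hypothetical free function, not the plugin method):

#include <oneapi/dnnl/dnnl.hpp>

// Given the descriptor of the constant (non-transposed) weights and the weights
// descriptor expected by the primitive, build the source descriptor for the reorder.
dnnl::memory::desc transposed_weight_src(const dnnl::memory::desc& const_wei,
                                         const dnnl::memory::desc& target_wei) {
    // Reinterpret the constant weights with swapped strides ("ba" instead of "ab") ...
    dnnl::memory::desc src(const_wei.get_dims(), const_wei.get_data_type(),
                           dnnl::memory::format_tag::ba);
    // ... and reshape to the target dims, mirroring
    // weightSrcDesc.reshape(desc->getDnnlDesc().get_dims()) above.
    return src.reshape(target_wei.get_dims());
}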
6 changes: 4 additions & 2 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.h
@@ -117,14 +117,16 @@ class FullyConnected : public Node {
void executeMLAS();
void prepackMLASWeight();
#endif

#if defined(OV_CPU_WITH_ACL)
void prepareWeightsUsingDummyShape();
#endif
bool useWeightsDecompressionImpl = false;
std::vector<float> decompressionSubtract;
std::vector<float> decompressionMultiply;

// FC with transposed weights
bool weightsNonTransposed = false;
DnnlMemoryDescPtr makeTransposedWeightDescriptor();
DnnlMemoryDescPtr makeTransposedWeightDescriptor(DnnlMemoryDescPtr desc);
};

} // namespace node
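The "dummy shapes" used by prepareWeightsUsingDummyShape() come from MemoryDescUtils::makeDummyDesc(), which replaces dynamic dimensions with fixed placeholder values so a concrete descriptor can be built before the real input shapes are known. A simplified, hypothetical illustration of that idea (make_dummy_dims and the placeholder value are assumptions for the sketch, not the plugin's actual helper or constant):

#include <cstdint>
#include <vector>

// Replace every dynamic dimension (marked as -1 in this sketch) with a fixed
// placeholder so that a fully static shape is available at compile time.
std::vector<int64_t> make_dummy_dims(const std::vector<int64_t>& dims,
                                     int64_t placeholder = 64) {
    std::vector<int64_t> result(dims);
    for (auto& d : result)
        if (d < 0)
            d = placeholder;
    return result;
}

Only the activation descriptors need dummy values here; the weight dimensions are already static, so the reordered weights can be produced and cached at compile_model time.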
