Implement new DNNL 1.x MatMul primitive cache #17

Merged 1 commit on Apr 29, 2020
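
This change replaces the per-call MatMul primitive creation in dnnl_gemm and dnnl_gemm_batch with a primitive cache keyed on shapes, strides, and data type (MklMatMulPrimitiveFactory). Below is a minimal sketch of how a caller is expected to use the cached path, assuming the TensorFlow MKL build and the mkl_matmul_ops_common.h changes in this PR; CachedMatMulExample and its row-major shapes are illustrative, not part of the PR.

#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"

namespace tensorflow {

// Multiplies a row-major m x k matrix `a` by a row-major k x n matrix `b`
// into a row-major m x n matrix `c`, reusing a cached MatMul primitive when
// one with identical dims/strides/type already exists in the factory.
template <typename T>
void CachedMatMulExample(const T* a, const T* b, T* c, int64 m, int64 n,
                         int64 k) {
  using dims = mkldnn::memory::dims;
  MklMatMulParams params(/*a_dims=*/dims{m, k}, /*b_dims=*/dims{k, n},
                         /*c_dims=*/dims{m, n}, /*a_strides=*/dims{k, 1},
                         /*b_strides=*/dims{n, 1}, /*c_strides=*/dims{n, 1});
  // false -> look the primitive up in (or add it to) the process-wide cache.
  MklMatMulPrimitive<T>* prim =
      MklMatMulPrimitiveFactory<T>::Get(params, /*do_not_cache=*/false);
  prim->Execute(a, b, c);
}

}  // namespace tensorflow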
9 changes: 4 additions & 5 deletions tensorflow/core/kernels/mkl_matmul_op.cc
@@ -184,14 +184,13 @@ class MklMatMulOp : public OpKernel {
const int index_transa = transa ? 1 : 0;
const int index_transb = transb ? 1 : 0;

Tensor c_float;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
#ifdef ENABLE_MKLDNN_V1
const char ftrans[] = {'N', 'T', 'C'};
dnnl_gemm<bfloat16>(ftrans[index_transa], ftrans[index_transb], m, n, k,
alpha, a, lda, b, ldb, beta,
c_float.flat<float>().data(), ldc);
alpha, a, lda, b, ldb, beta, c, ldc);
#else
Tensor c_float;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
const char* const ftrans[] = {"N", "T", "C"};

// MKL-DNN only supports the Fortran API and requires column major while
@@ -201,8 +200,8 @@ class MklMatMulOp : public OpKernel {
reinterpret_cast<const mkldnn_bfloat16_t*>(b), &ldb,
reinterpret_cast<const mkldnn_bfloat16_t*>(a), &lda,
&beta, c_float.flat<float>().data(), &ldc);
#endif // ENABLE_MKLDNN_V1
FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
#endif // ENABLE_MKLDNN_V1
}
#endif // ENABLE_INTEL_MKL_BFLOAT16

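The truncated comment above ("MKL-DNN only supports the Fortran API and requires column major ...") refers to the standard trick for driving a column-major (Fortran-style) GEMM with row-major buffers: a row-major buffer reinterpreted as column-major is the transposed matrix, and since C^T = B^T * A^T, swapping the operands and the m/n sizes yields the row-major product with no data movement. A small self-contained illustration using a plain triple-loop GEMM (not MKL-DNN; the function and values are made up for this example):

#include <cassert>
#include <cstdio>
#include <vector>

// Plain column-major GEMM: C (MxN) = A (MxK) * B (KxN), no transposes,
// with leading dimensions lda/ldb/ldc as in the Fortran BLAS convention.
static void gemm_colmajor(int M, int N, int K, const float* A, int lda,
                          const float* B, int ldb, float* C, int ldc) {
  for (int j = 0; j < N; ++j)
    for (int i = 0; i < M; ++i) {
      float acc = 0.f;
      for (int p = 0; p < K; ++p) acc += A[i + p * lda] * B[p + j * ldb];
      C[i + j * ldc] = acc;
    }
}

int main() {
  // Row-major inputs: a is 2x3, b is 3x2, so c = a*b is 2x2 row-major.
  const int m = 2, k = 3, n = 2;
  std::vector<float> a = {1, 2, 3,
                          4, 5, 6};   // 2x3, row-major
  std::vector<float> b = {7,  8,
                          9,  10,
                          11, 12};    // 3x2, row-major
  std::vector<float> c(m * n, 0.f);

  // Row-major buffers read as column-major are the transposed matrices,
  // so compute C^T = B^T * A^T: swap the operands and the m/n sizes.
  gemm_colmajor(/*M=*/n, /*N=*/m, /*K=*/k, b.data(), /*lda=*/n, a.data(),
                /*ldb=*/k, c.data(), /*ldc=*/n);

  // c now holds the row-major product {58, 64, 139, 154}.
  assert(c[0] == 58 && c[1] == 64 && c[2] == 139 && c[3] == 154);
  std::printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]);
  return 0;
}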
268 changes: 214 additions & 54 deletions tensorflow/core/kernels/mkl_matmul_ops_common.h
@@ -516,33 +516,191 @@ class MklDnnMatMulOpBase : public OpKernel {

// MatMul support for bfloat16 and int8 types is introduced in DNNLv1.2.
#ifdef ENABLE_MKLDNN_V1

using mkldnn::matmul;

namespace {

void dnnl_gemm_exec(const memory::desc& a_md, const memory::desc& b_md,
const memory::desc& c_md, const void* a, const void* b,
void* c, const primitive_attr& attr) {
// Create a MatMul primitive
mkldnn::engine cpu_engine = mkldnn::engine(ENGINE_CPU, 0);
mkldnn::matmul::desc matmul_desc(a_md, b_md, c_md);
mkldnn::matmul::primitive_desc matmul_pd(matmul_desc, attr, cpu_engine);
mkldnn::matmul matmul_prim(matmul_pd);
// Wrap raw pointers into DNNL memory objects
mkldnn::memory a_memory(a_md, cpu_engine, const_cast<void*>(a));
mkldnn::memory b_memory(b_md, cpu_engine, const_cast<void*>(b));
mkldnn::memory c_memory(c_md, cpu_engine, c);
// Execute the MatMul primitive.
// Since here all shapes and parameters are static, please note that we
// don't need to pass alpha (scales) again, as they are already hard-coded
// in the primitive descriptor. Also, we are not allowed to change the
// shapes of matrices A, B, and C -- they should exactly match
// the memory descriptors passed to MatMul operation descriptor.
mkldnn::stream s(cpu_engine);
matmul_prim.execute(s, {{DNNL_ARG_SRC, a_memory},
                        {DNNL_ARG_WEIGHTS, b_memory},
                        {DNNL_ARG_DST, c_memory}});
s.wait();
}
struct MklMatMulParams {
memory::dims a_dims;
memory::dims b_dims;
memory::dims c_dims;
memory::dims a_strides;
memory::dims b_strides;
memory::dims c_strides;

MklMatMulParams(memory::dims a_dims, memory::dims b_dims, memory::dims c_dims,
memory::dims a_strides, memory::dims b_strides,
memory::dims c_strides)
: a_dims(a_dims),
b_dims(b_dims),
c_dims(c_dims),
a_strides(a_strides),
b_strides(b_strides),
c_strides(c_strides) {}
};

template <typename T>
class MklMatMulPrimitive : public MklPrimitive {
public:
explicit MklMatMulPrimitive(const MklMatMulParams& params)
: cpu_engine_(ENGINE_CPU, 0) {
context_.stream.reset(new CPU_STREAM(cpu_engine_));
// Create matmul primitive
if (context_.matmul_prim == nullptr) {
Setup(params);
}
}

~MklMatMulPrimitive() {}

void Execute(const T* a_data, const T* b_data, T* c_data) {
context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));

execute_primitives(context_.matmul_primitives, context_.stream,
context_.net_args);

// After execution, set data handle back
context_.a_mem->set_data_handle(DummyData);
context_.b_mem->set_data_handle(DummyData);
context_.c_mem->set_data_handle(DummyData);
}

private:
// Primitive reuse context for MatMul op
struct MklMatMulContext {
// MKL-DNN memory.
std::shared_ptr<mkldnn::memory> a_mem;
std::shared_ptr<mkldnn::memory> b_mem;
std::shared_ptr<mkldnn::memory> c_mem;

// Descriptor and primitive-descriptor for MatMul.
std::shared_ptr<matmul::desc> desc;
std::shared_ptr<matmul::primitive_desc> prim_desc;

// Memory descriptors.
std::shared_ptr<mkldnn::memory::desc> a_md;
std::shared_ptr<mkldnn::memory::desc> b_md;
std::shared_ptr<mkldnn::memory::desc> c_md;

// MatMul primitive.
std::shared_ptr<mkldnn::primitive> matmul_prim;
std::shared_ptr<mkldnn::stream> stream;
std::vector<mkldnn::primitive> matmul_primitives;
std::vector<std::unordered_map<int, memory>> net_args;

MklMatMulContext()
: a_mem(nullptr),
b_mem(nullptr),
c_mem(nullptr),
desc(nullptr),
prim_desc(nullptr),
a_md(nullptr),
b_md(nullptr),
c_md(nullptr),
matmul_prim(nullptr),
stream(nullptr) {}
};

void Setup(const MklMatMulParams& params) {
// Create MatMul descriptor and primitive descriptor.
context_.a_md.reset(
new memory::desc({params.a_dims}, MklDnnType<T>(), params.a_strides));

context_.b_md.reset(
new memory::desc({params.b_dims}, MklDnnType<T>(), params.b_strides));

context_.c_md.reset(
new memory::desc({params.c_dims}, MklDnnType<T>(), params.c_strides));

// Create matmul.
context_.desc.reset(
new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md));
context_.prim_desc.reset(
new matmul::primitive_desc(*context_.desc, cpu_engine_));

// Create memory primitive based on dummy data.
context_.a_mem.reset(
new mkldnn::memory(*context_.a_md, cpu_engine_, DummyData));
context_.b_mem.reset(
new mkldnn::memory(*context_.b_md, cpu_engine_, DummyData));
context_.c_mem.reset(
new mkldnn::memory(*context_.c_md, cpu_engine_, DummyData));

// Create matmul primitive.
context_.matmul_prim.reset(new mkldnn::matmul(*context_.prim_desc));
context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.a_mem},
{MKLDNN_ARG_WEIGHTS, *context_.b_mem},
{MKLDNN_ARG_DST, *context_.c_mem}});

context_.matmul_primitives.push_back(*context_.matmul_prim);
return;
}

struct MklMatMulContext context_;
engine cpu_engine_;
};

template <typename T>
class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklMatMulPrimitive<T>* Get(const MklMatMulParams& params,
bool do_not_cache) {
MklMatMulPrimitive<T>* matmul_prim = nullptr;

if (do_not_cache) {
// Always create new primitive
matmul_prim = new MklMatMulPrimitive<T>(params);
} else {
// Try to find a suitable one in pool
matmul_prim = dynamic_cast<MklMatMulPrimitive<T>*>(
MklMatMulPrimitiveFactory<T>::GetInstance().GetMklMatMul(params));
if (matmul_prim == nullptr) {
matmul_prim = new MklMatMulPrimitive<T>(params);
MklMatMulPrimitiveFactory<T>::GetInstance().SetMklMatMul(params,
matmul_prim);
}
}

return matmul_prim;
}

private:
MklMatMulPrimitiveFactory() {}
~MklMatMulPrimitiveFactory() {}

static MklMatMulPrimitiveFactory& GetInstance() {
static MklMatMulPrimitiveFactory instance_;
return instance_;
}

static string CreateKey(const MklMatMulParams& params) {
string prefix = "matmul_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(params.a_dims);
key_creator.AddAsKey(params.b_dims);
key_creator.AddAsKey(params.c_dims);
key_creator.AddAsKey(params.a_strides);
key_creator.AddAsKey(params.b_strides);
key_creator.AddAsKey(params.c_strides);
key_creator.AddAsKey(typeid(T).name());

return key_creator.GetKey();
}

MklPrimitive* GetMklMatMul(const MklMatMulParams& params) {
string key = CreateKey(params);
return this->GetOp(key);
}

void SetMklMatMul(const MklMatMulParams& params, MklPrimitive* op) {
string key = CreateKey(params);
this->SetOp(key, op);
}
};

template <typename T>
void dnnl_gemm_batch(const std::vector<bool>& transa,
@@ -589,45 +747,47 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
!transb[0] ? dims{k[0] * n[0], n[0], 1} : dims{n[0] * k[0], 1, k[0]};
dims c_strides = dims{m[0] * n[0], n[0], 1};

// Prepare memory descriptors
memory::desc a_md(a_sizes, MklDnnType<T>(), a_strides);
memory::desc b_md(b_sizes, MklDnnType<T>(), b_strides);
memory::desc c_md(c_sizes, MklDnnType<T>(), c_strides);
// Create attributes (to handle alpha and beta if necessary)
mkldnn::primitive_attr attr;
if (alpha[0] != 1.f) attr.set_output_scales(/* mask */ 0, {alpha[0]});
if (beta[0] != 0.f) {
mkldnn::post_ops po;
po.append_sum(beta[0]);
attr.set_post_ops(po);
}
dnnl_gemm_exec(a_md, b_md, c_md, static_cast<const void*>(a),
static_cast<const void*>(b), static_cast<void*>(c), attr);
// The cached MatMul primitive is created without output scales or post-ops,
// so only alpha == 1.0 and beta == 0.0 are supported; verify that here.
DCHECK_EQ(alpha[0], 1.0f);
DCHECK_EQ(beta[0], 0.f);

MklMatMulParams params(a_sizes, b_sizes, c_sizes, a_strides, b_strides,
c_strides);
MklMatMulPrimitive<T>* matmul_prim =
    MklMatMulPrimitiveFactory<T>::Get(params, /*do_not_cache=*/false);

// Execute matmul primitive.
matmul_prim->Execute(a, b, c);
}

template <typename T>
void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
float beta, float* c, int64_t ldc) {
float beta, T* c, int64_t ldc) {
using dims = mkldnn::memory::dims;

// Prepare strides based on the transa and transb flags: transposed
// matrices have strides swapped
dims a_dims = dims{m, k};
dims b_dims = dims{k, n};
dims c_dims = dims{m, n};
dims a_strides = tolower(transa) == 'n' ? dims{lda, 1} : dims{1, lda};
dims b_strides = tolower(transb) == 'n' ? dims{ldb, 1} : dims{1, ldb};
// Prepare memory descriptors
memory::desc a_md({m, k}, MklDnnType<T>(), a_strides);
memory::desc b_md({k, n}, MklDnnType<T>(), b_strides);
memory::desc c_md({m, n}, MklDnnType<float>(), {ldc, 1});
// Create attributes (to handle alpha and beta if necessary)
mkldnn::primitive_attr attr;
if (alpha != 1.f) attr.set_output_scales(/* mask */ 0, {alpha});
if (beta != 0.f) {
mkldnn::post_ops po;
po.append_sum(beta);
attr.set_post_ops(po);
}
dnnl_gemm_exec(a_md, b_md, c_md, static_cast<const void*>(a),
static_cast<const void*>(b), static_cast<void*>(c), attr);
dims c_strides = dims{ldc, 1};

// The cached MatMul primitive is created without output scales or post-ops,
// so only alpha == 1.0 and beta == 0.0 are supported; verify that here.
DCHECK_EQ(alpha, 1.0f);
DCHECK_EQ(beta, 0.f);

MklMatMulParams params(a_dims, b_dims, c_dims, a_strides, b_strides,
c_strides);
MklMatMulPrimitive<T>* matmul_prim =
    MklMatMulPrimitiveFactory<T>::Get(params, /*do_not_cache=*/false);

// Execute matmul primitive.
matmul_prim->Execute(a, b, c);
}

} // anonymous namespace
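
The stride setup in the new dnnl_gemm ("transposed matrices have strides swapped") relies on the fact that reading a row-major buffer with its two strides exchanged yields the transpose without copying, which is what the {lda, 1} vs. {1, lda} descriptors express. A tiny standalone check of that identity (plain C++, no MKL-DNN; the at() helper is made up for this example):

#include <cassert>
#include <cstdint>
#include <vector>

// Element (i, j) of a logical matrix described by explicit row/column
// strides over a flat buffer, analogous to an MKL-DNN memory::desc built
// from dims and strides.
static float at(const std::vector<float>& buf, int64_t i, int64_t j,
                int64_t stride_row, int64_t stride_col) {
  return buf[i * stride_row + j * stride_col];
}

int main() {
  const int64_t lda = 3;  // a is a 2 x 3 (m x k) matrix, row-major
  std::vector<float> a = {1, 2, 3,
                          4, 5, 6};

  // transa == 'n': the m x k matrix itself, strides {lda, 1}.
  assert(at(a, 1, 2, /*stride_row=*/lda, /*stride_col=*/1) == 6);  // a(1, 2)

  // transa == 't': the same buffer read as its k x m transpose,
  // strides {1, lda} -- no data is moved, only the strides change.
  assert(at(a, 2, 1, /*stride_row=*/1, /*stride_col=*/lda) == 6);  // a^T(2, 1)

  return 0;
}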