Skip to content

Commit

Permalink
fix possible memleak in SetOneDnnLayout (#2683)
Browse files Browse the repository at this point in the history
  • Loading branch information
jianyizh authored May 6, 2024
1 parent 47c0672 commit 4821154
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 32 deletions.
5 changes: 5 additions & 0 deletions itex/core/utils/onednn/onednn_layout_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ void OneDnnShape::SerializeOneDnnShape(unsigned char* buf,
ITEX_CHECK(buf_size >= GetSerializeBufferSize())
<< "Buffer size is too small to SerializeOneDnnShape";
*reinterpret_cast<OneDnnShapeData*>(buf) = data_;
std::memcpy(buf + sizeof(OneDnnShapeData), md_.data(), data_.md_size_);
}

void OneDnnShape::DeSerializeOneDnnShape(const unsigned char* buf,
Expand All @@ -137,8 +138,12 @@ void OneDnnShape::DeSerializeOneDnnShape(const unsigned char* buf,
ITEX_CHECK(buf_size >= GetSerializeBufferSize())
<< "Buffer size is too small in DeSerializeOneDnnShape";
data_ = *reinterpret_cast<const OneDnnShapeData*>(buf);
md_.resize(data_.md_size_);
std::memcpy(md_.data(), buf + sizeof(OneDnnShapeData), data_.md_size_);
} else {
data_.is_onednn_tensor_ = false;
data_.md_size_ = 0;
md_.clear();
}
}

Expand Down
31 changes: 17 additions & 14 deletions itex/core/utils/onednn/onednn_layout_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ class OneDnnShape {
// Flag to indicate if the tensor is an OneDnn tensor or not
bool is_onednn_tensor_ = false;
OneDnnTensorFormat tf_data_format_ = OneDnnTensorFormat::FORMAT_INVALID;
// OneDnn layout
dnnl_memory_desc_t md_;
// TF dimension corresponding to this OneDnn dimension
dnnl_dims_t map_;
// TODO(itex): For Tensorflow, oneDNN Graph shape and stride are actually
Expand All @@ -64,8 +62,10 @@ class OneDnnShape {
dnnl_dims_t stride_;
// layout_id for OneDnn Graph logical tensor
int64_t layout_id_ = INVALID_LLGA_ID;
size_t md_size_ = 0;
} OneDnnShapeData;
OneDnnShapeData data_;
std::vector<uint8_t> md_;

public:
OneDnnShape() {
Expand Down Expand Up @@ -96,32 +96,35 @@ class OneDnnShape {
// OneDnnShape object.
inline dnnl::memory::dims GetSizesAsOneDnnDims() const {
  ITEX_CHECK_EQ(data_.is_onednn_tensor_, true);
  // Rebuild the memory descriptor from the serialized blob instead of
  // keeping a live dnnl_memory_desc_t handle in OneDnnShapeData (the old
  // handle-based code leaked and is removed here).
  dnnl_memory_desc_t tmp;
  dnnl_memory_desc_create_with_blob(&tmp, md_.data());
  // The C++ desc takes ownership of `tmp` and destroys it on scope exit;
  // get_dims() returns an empty vector for a zero-dim descriptor.
  return dnnl::memory::desc(tmp).get_dims();
}

// Get DataType
inline dnnl::memory::data_type GetElemType() const {
  // Reconstruct the memory desc from the serialized blob; the C++ desc
  // wrapper takes ownership of `tmp` and destroys it when it goes out of
  // scope, so no handle is leaked.
  dnnl_memory_desc_t tmp;
  dnnl_memory_desc_create_with_blob(&tmp, md_.data());
  return dnnl::memory::desc(tmp).get_data_type();
}

// Return TensorShape that describes the Tensorflow shape of the tensor
// represented by this OneDnnShape.
TensorShape GetTfShape() const;
// Store `md` as an opaque byte blob in md_. Serializing (rather than
// cloning a dnnl_memory_desc_t handle) avoids the memory leak where the
// cloned C handle was never destroyed.
inline void SetOneDnnLayout(const dnnl::memory::desc& md) {
  // First call with a null buffer queries the required blob size; the
  // second call fills the resized buffer.
  dnnl_memory_desc_get_blob(nullptr, &data_.md_size_, md.get());
  md_.resize(data_.md_size_);
  dnnl_memory_desc_get_blob(md_.data(), &data_.md_size_, md.get());
}

// Get memory desc for OneDnn layout
inline const dnnl::memory::desc GetOneDnnLayout() const {
  dnnl_memory_desc_t tmp;
  dnnl_memory_desc_create_with_blob(&tmp, md_.data());
  // According to oneDNN, this will not cause memory leak if we
  // construct a memory descriptor from a C API ::dnnl_memory_desc_t
  // handle. The resulting handle is not weak and the C handle will be
  // destroyed during the destruction of the C++ object.
  return dnnl::memory::desc(tmp);
}
// Get memory desc for TF layout, only used in onednntotf op
Expand Down Expand Up @@ -159,7 +162,7 @@ class OneDnnShape {

// Get Size of OneDnnShapeData, it is used to allocate buffer for meta tensor
inline size_t GetSerializeBufferSize() const {
  // Serialized layout = fixed-size OneDnnShapeData header followed by the
  // variable-length memory-descriptor blob. Returning only
  // sizeof(OneDnnShapeData) would under-allocate the meta tensor and
  // overflow in SerializeOneDnnShape's memcpy of md_.
  return sizeof(OneDnnShapeData) + data_.md_size_;
}

// Set shape of logical tensor.
Expand Down
32 changes: 16 additions & 16 deletions itex/core/utils/onednn/onednn_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@ void ReorderMemory(const OpKernelContext& context,
ReorderMemoryInternal(src_memory, reorder_memory, onednn_stream);
}

// TF datatype and shape is meaningless for some tensors, such as scratchpad
// tensor and memory desc tensor in weight cache. These tensors are only used
// in OneDnn primitive, not related to Tensorflow. We only need to choose a
// short length datatype, ensure the it is divisible by allocated buffer.
using ShortDT = uint8;

template <typename T>
bool WeightCacheManager<T>::IsEmpty() TF_LOCKS_EXCLUDED(mu_) {
tf_shared_lock lock(&mu_);
Expand Down Expand Up @@ -86,20 +80,23 @@ void WeightCacheManager<T>::SetCache(
ReorderMemory(*context, &weight_mem, &weight_reorder_mem, onednn_engine);

// Cache the memory descriptor
size_t blob_size;
dnnl_memory_desc_get_blob(nullptr, &blob_size, weight_expected_md.get());
std::vector<uint8_t> weight_expected_md_blob(blob_size);
dnnl_memory_desc_get_blob(weight_expected_md_blob.data(), &blob_size,
weight_expected_md.get());
Tensor* weight_md_cached_tensor = nullptr;
TensorShape weight_md_tf_shape;
weight_md_tf_shape.AddDim(sizeof(weight_expected_md) / sizeof(ShortDT));
weight_md_tf_shape.AddDim(blob_size);

AllocatorAttributes alloc_attr;
alloc_attr.set_on_host(true);
OP_REQUIRES_OK(context,
context->allocate_persistent(
DataTypeToEnum<ShortDT>::value, weight_md_tf_shape,
DataTypeToEnum<uint8_t>::value, weight_md_tf_shape,
&weight_cached_md_, &weight_md_cached_tensor, alloc_attr));
dnnl_memory_desc_t c_weight_expected_md;
dnnl_memory_desc_clone(&c_weight_expected_md, weight_expected_md.get());
*reinterpret_cast<dnnl_memory_desc_t*>(
weight_md_cached_tensor->flat<ShortDT>().data()) = c_weight_expected_md;
std::memcpy(weight_md_cached_tensor->flat<uint8_t>().data(),
weight_expected_md_blob.data(), weight_expected_md_blob.size());
}

template <typename T>
Expand All @@ -112,10 +109,13 @@ T* WeightCacheManager<T>::GetCache(OpKernelContext* context,

// Check if the memory descriptor of the cached weight is same as
// expected_md. if so use the cached memory, else return nullptr
if (weight_cached_md->flat<ShortDT>().size()) {
dnnl::memory::desc* cached_md = reinterpret_cast<dnnl::memory::desc*>(
const_cast<ShortDT*>(weight_cached_md->flat<ShortDT>().data()));
if (*cached_md == expected_md) {
if (weight_cached_md->flat<uint8_t>().size()) {
std::vector<uint8_t> weight_cached_md_blob(
weight_cached_md->flat<uint8_t>().data(),
weight_cached_md->flat<uint8_t>().data() +
weight_cached_md->flat<uint8_t>().size());
dnnl::memory::desc cached_md = dnnl::memory::desc(weight_cached_md_blob);
if (cached_md == expected_md) {
return reinterpret_cast<T*>(
const_cast<T*>(weight_cached_data->flat<T>().data()));
} else {
Expand Down
4 changes: 2 additions & 2 deletions itex/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def itex_workspace(path_prefix = "", tf_repo_name = ""):
system_build_file = clean_dep("//third_party/systemlibs:pybind11.BUILD"),
)

# v3.3
_ONEDNN_CPU_COMMIT = "08fea71aff4c273e34579e86396405f95d34aa74"
# main 20240329
_ONEDNN_CPU_COMMIT = "242d4d9"

new_git_repository(
name = "onednn_cpu",
Expand Down
1 change: 1 addition & 0 deletions third_party/onednn/onednn_cpu.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ _DNNL_CPU_COMMON = {
"#cmakedefine01 BUILD_GEMM_SSE41": "#define BUILD_GEMM_SSE41 0",
"#cmakedefine01 BUILD_GEMM_AVX2": "#define BUILD_GEMM_AVX2 0",
"#cmakedefine01 BUILD_GEMM_AVX512": "#define BUILD_GEMM_AVX512 0",
"#cmakedefine01 BUILD_XE2": "#define BUILD_XE2 0",
}

_DNNL_RUNTIME_TBB = {
Expand Down
1 change: 1 addition & 0 deletions third_party/onednn/onednn_cpu_eigen.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ _DNNL_CPU_COMMON = {
"#cmakedefine01 BUILD_GEMM_SSE41": "#define BUILD_GEMM_SSE41 0",
"#cmakedefine01 BUILD_GEMM_AVX2": "#define BUILD_GEMM_AVX2 0",
"#cmakedefine01 BUILD_GEMM_AVX512": "#define BUILD_GEMM_AVX512 0",
"#cmakedefine01 BUILD_XE2": "#define BUILD_XE2 0",
}

_DNNL_RUNTIME_TBB = {
Expand Down

0 comments on commit 4821154

Please sign in to comment.