Skip to content

Commit

Permalink
Merge pull request #698 from igchor/umf_helpers_provider
Browse files Browse the repository at this point in the history
[umf] add extra poolMakeUnique overload to helpers
  • Loading branch information
igchor authored Jul 12, 2023
2 parents 0d992a0 + 6653648 commit 566169d
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 89 deletions.
132 changes: 80 additions & 52 deletions source/common/umf_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <umf/memory_provider.h>
#include <umf/memory_provider_ops.h>

#include <array>
#include <functional>
#include <memory>
#include <stdexcept>
Expand Down Expand Up @@ -48,6 +49,55 @@ using provider_unique_handle_t =
} \
}

namespace detail {
template <typename T, typename ArgsTuple>
umf_result_t initialize(T *obj, ArgsTuple &&args) {
try {
auto ret = std::apply(&T::initialize,
std::tuple_cat(std::make_tuple(obj),
std::forward<ArgsTuple>(args)));
if (ret != UMF_RESULT_SUCCESS) {
delete obj;
}
return ret;
} catch (...) {
delete obj;
return UMF_RESULT_ERROR_UNKNOWN;
}
}

template <typename T, typename ArgsTuple>
umf_memory_pool_ops_t poolMakeUniqueOps() {
umf_memory_pool_ops_t ops;

ops.version = UMF_VERSION_CURRENT;
ops.initialize = [](umf_memory_provider_handle_t *providers,
size_t numProviders, void *params, void **obj) {
try {
*obj = new T;
} catch (...) {
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return detail::initialize<T>(
reinterpret_cast<T *>(*obj),
std::tuple_cat(std::make_tuple(providers, numProviders),
*reinterpret_cast<ArgsTuple *>(params)));
};
ops.finalize = [](void *obj) { delete reinterpret_cast<T *>(obj); };

UMF_ASSIGN_OP(ops, T, malloc, ((void *)nullptr));
UMF_ASSIGN_OP(ops, T, calloc, ((void *)nullptr));
UMF_ASSIGN_OP(ops, T, aligned_malloc, ((void *)nullptr));
UMF_ASSIGN_OP(ops, T, realloc, ((void *)nullptr));
UMF_ASSIGN_OP(ops, T, malloc_usable_size, ((size_t)0));
UMF_ASSIGN_OP_NORETURN(ops, T, free);
UMF_ASSIGN_OP(ops, T, get_last_allocation_error, UMF_RESULT_ERROR_UNKNOWN);

return ops;
}
} // namespace detail

/// @brief creates UMF memory provider based on given T type.
/// T should implement all functions defined by
/// umf_memory_provider_ops_t, except for finalize (it is
Expand All @@ -60,28 +110,15 @@ auto memoryProviderMakeUnique(Args &&...args) {

ops.version = UMF_VERSION_CURRENT;
ops.initialize = [](void *params, void **obj) {
auto *tuple = reinterpret_cast<decltype(argsTuple) *>(params);
T *provider;
try {
provider = new T;
*obj = new T;
} catch (...) {
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

*obj = provider;

try {
auto ret =
std::apply(&T::initialize,
std::tuple_cat(std::make_tuple(provider), *tuple));
if (ret != UMF_RESULT_SUCCESS) {
delete provider;
}
return ret;
} catch (...) {
delete provider;
return UMF_RESULT_ERROR_UNKNOWN;
}
return detail::initialize<T>(
reinterpret_cast<T *>(*obj),
*reinterpret_cast<decltype(argsTuple) *>(params));
};
ops.finalize = [](void *obj) { delete reinterpret_cast<T *>(obj); };

Expand All @@ -108,51 +145,42 @@ auto memoryProviderMakeUnique(Args &&...args) {
template <typename T, typename... Args>
auto poolMakeUnique(umf_memory_provider_handle_t *providers,
size_t numProviders, Args &&...args) {
umf_memory_pool_ops_t ops;
auto argsTuple = std::make_tuple(std::forward<Args>(args)...);
auto ops = detail::poolMakeUniqueOps<T, decltype(argsTuple)>();

ops.version = UMF_VERSION_CURRENT;
ops.initialize = [](umf_memory_provider_handle_t *providers,
size_t numProviders, void *params, void **obj) {
auto *tuple = reinterpret_cast<decltype(argsTuple) *>(params);
T *pool;
umf_memory_pool_handle_t hPool = nullptr;
auto ret = umfPoolCreate(&ops, providers, numProviders, &argsTuple, &hPool);
return std::pair<umf_result_t, pool_unique_handle_t>{
ret, pool_unique_handle_t(hPool, &umfPoolDestroy)};
}

try {
pool = new T;
} catch (...) {
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}
/// @brief creates UMF memory pool based on given T type.
/// This overload takes ownership of memory providers and destroys
/// them after memory pool is destroyed.
template <typename T, size_t N, typename... Args>
auto poolMakeUnique(std::array<provider_unique_handle_t, N> providers,
Args &&...args) {
auto argsTuple = std::make_tuple(std::forward<Args>(args)...);
auto ops = detail::poolMakeUniqueOps<T, decltype(argsTuple)>();

*obj = pool;
std::array<umf_memory_provider_handle_t, N> provider_handles;
for (size_t i = 0; i < N; i++) {
provider_handles[i] = providers[i].release();
}

try {
auto ret = std::apply(
&T::initialize,
std::tuple_cat(std::make_tuple(pool, providers, numProviders),
*tuple));
if (ret != UMF_RESULT_SUCCESS) {
delete pool;
}
return ret;
} catch (...) {
delete pool;
return UMF_RESULT_ERROR_UNKNOWN;
// capture providers and destroy them after the pool is destroyed
auto poolDestructor = [provider_handles](umf_memory_pool_handle_t hPool) {
umfPoolDestroy(hPool);
for (auto &provider : provider_handles) {
umfMemoryProviderDestroy(provider);
}
};
ops.finalize = [](void *obj) { delete reinterpret_cast<T *>(obj); };

UMF_ASSIGN_OP(ops, T, malloc, ((void *)nullptr));
UMF_ASSIGN_OP(ops, T, calloc, ((void *)nullptr));
UMF_ASSIGN_OP(ops, T, aligned_malloc, ((void *)nullptr));
UMF_ASSIGN_OP(ops, T, realloc, ((void *)nullptr));
UMF_ASSIGN_OP(ops, T, malloc_usable_size, ((size_t)0));
UMF_ASSIGN_OP_NORETURN(ops, T, free);
UMF_ASSIGN_OP(ops, T, get_last_allocation_error, UMF_RESULT_ERROR_UNKNOWN);

umf_memory_pool_handle_t hPool = nullptr;
auto ret = umfPoolCreate(&ops, providers, numProviders, &argsTuple, &hPool);
auto ret = umfPoolCreate(&ops, provider_handles.data(),
provider_handles.size(), &argsTuple, &hPool);
return std::pair<umf_result_t, pool_unique_handle_t>{
ret, pool_unique_handle_t(hPool, &umfPoolDestroy)};
ret, pool_unique_handle_t(hPool, std::move(poolDestructor))};
}

template <typename Type> umf_result_t &getPoolLastStatusRef() {
Expand Down
42 changes: 14 additions & 28 deletions test/unified_malloc_framework/memoryPoolAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,41 +133,27 @@ TEST_F(test, retrieveMemoryProviders) {
ASSERT_EQ(retProviders, providers);
}

template <typename Pool>
static auto
makePool(std::function<umf::provider_unique_handle_t()> makeProvider) {
auto providerUnique = makeProvider();
umf_memory_provider_handle_t provider = providerUnique.get();
auto pool = umf::poolMakeUnique<Pool>(&provider, 1).second;
auto dtor = [provider =
providerUnique.release()](umf_memory_pool_handle_t hPool) {
umfPoolDestroy(hPool);
umfMemoryProviderDestroy(provider);
};
return umf::pool_unique_handle_t(pool.release(), std::move(dtor));
}

INSTANTIATE_TEST_SUITE_P(mallocPoolTest, umfPoolTest, ::testing::Values([] {
return makePool<umf_test::malloc_pool>([] {
return umf_test::wrapProviderUnique(
nullProviderCreate());
});
}));
INSTANTIATE_TEST_SUITE_P(
mallocPoolTest, umfPoolTest, ::testing::Values([] {
return umf::poolMakeUnique<umf_test::malloc_pool, 1>(
{umf_test::wrapProviderUnique(nullProviderCreate())})
.second;
}));

INSTANTIATE_TEST_SUITE_P(
mallocProviderPoolTest, umfPoolTest, ::testing::Values([] {
return makePool<umf_test::proxy_pool>([] {
return umf::memoryProviderMakeUnique<umf_test::provider_malloc>()
.second;
});
return umf::poolMakeUnique<umf_test::proxy_pool, 1>(
{umf::memoryProviderMakeUnique<umf_test::provider_malloc>()
.second})
.second;
}));

INSTANTIATE_TEST_SUITE_P(
mallocMultiPoolTest, umfMultiPoolTest, ::testing::Values([] {
return makePool<umf_test::proxy_pool>([] {
return umf::memoryProviderMakeUnique<umf_test::provider_malloc>()
.second;
});
return umf::poolMakeUnique<umf_test::proxy_pool, 1>(
{umf::memoryProviderMakeUnique<umf_test::provider_malloc>()
.second})
.second;
}));

////////////////// Negative test cases /////////////////
Expand Down
13 changes: 4 additions & 9 deletions test/unified_malloc_framework/umf_pools/disjoint_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,13 @@ static usm::DisjointPool::Config poolConfig() {
}

static auto makePool() {
auto [ret, providerUnique] =
auto [ret, provider] =
umf::memoryProviderMakeUnique<umf_test::provider_malloc>();
EXPECT_EQ(ret, UMF_RESULT_SUCCESS);
auto provider = providerUnique.release();
auto [retp, pool] =
umf::poolMakeUnique<usm::DisjointPool>(&provider, 1, poolConfig());
auto [retp, pool] = umf::poolMakeUnique<usm::DisjointPool, 1>(
{std::move(provider)}, poolConfig());
EXPECT_EQ(retp, UMF_RESULT_SUCCESS);
auto dtor = [provider = provider](umf_memory_pool_handle_t hPool) {
umfPoolDestroy(hPool);
umfMemoryProviderDestroy(provider);
};
return umf::pool_unique_handle_t(pool.release(), std::move(dtor));
return std::move(pool);
}

INSTANTIATE_TEST_SUITE_P(disjointPoolTests, umfPoolTest,
Expand Down

0 comments on commit 566169d

Please sign in to comment.