Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions cmake/WarningsUtils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ set(ONEMKL_WARNINGS "")
include(CheckCXXCompilerFlag)
macro(add_warning flag)
check_cxx_compiler_flag(${flag} IS_SUPPORTED)
message(STATUS "DBG '${flag}': ${IS_SUPPORTED}")
if(${IS_SUPPORTED})
list(APPEND ONEMKL_WARNINGS ${flag})
else()
Expand All @@ -40,7 +39,7 @@ add_warning("-Wshadow")
add_warning("-Wconversion")
add_warning("-Wpedantic")

message(STATUS "Using warnings: ${ONEMKL_WARNINGS}")
message(VERBOSE "Domains with warnings enabled use: ${ONEMKL_WARNINGS}")

# The onemkl_warnings target can be linked to any other target to enable warnings.
target_compile_options(onemkl_warnings INTERFACE ${ONEMKL_WARNINGS})
Expand Down
87 changes: 60 additions & 27 deletions tests/unit_tests/dft/source/descriptor_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ constexpr std::int64_t default_1d_lengths = 4;
const std::vector<std::int64_t> default_3d_lengths{ 124, 5, 3 };

template <oneapi::mkl::dft::precision precision, oneapi::mkl::dft::domain domain>
static void set_and_get_lengths(sycl::queue& sycl_queue) {
static void set_and_get_lengths() {
/* Negative Testing */
{
oneapi::mkl::dft::descriptor<precision, domain> descriptor{ default_3d_lengths };
Expand Down Expand Up @@ -70,8 +70,6 @@ static void set_and_get_lengths(sycl::queue& sycl_queue) {
descriptor.get_value(oneapi::mkl::dft::config_param::DIMENSION, &dimensions_after_set);
EXPECT_EQ(new_lengths, lengths_value);
EXPECT_EQ(dimensions, dimensions_after_set);

commit_descriptor(descriptor, sycl_queue);
}

/* >= 2D */
Expand Down Expand Up @@ -349,7 +347,7 @@ static void set_and_get_values() {
}

template <oneapi::mkl::dft::precision precision, oneapi::mkl::dft::domain domain>
static void get_readonly_values(sycl::queue& sycl_queue) {
static void get_readonly_values() {
oneapi::mkl::dft::descriptor<precision, domain> descriptor{ default_1d_lengths };

oneapi::mkl::dft::domain domain_value;
Expand All @@ -371,14 +369,10 @@ static void get_readonly_values(sycl::queue& sycl_queue) {
oneapi::mkl::dft::config_value commit_status;
descriptor.get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status);
EXPECT_EQ(commit_status, oneapi::mkl::dft::config_value::UNCOMMITTED);

commit_descriptor(descriptor, sycl_queue);
descriptor.get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status);
EXPECT_EQ(commit_status, oneapi::mkl::dft::config_value::COMMITTED);
}

template <oneapi::mkl::dft::precision precision, oneapi::mkl::dft::domain domain>
static void set_readonly_values(sycl::queue& sycl_queue) {
static void set_readonly_values() {
oneapi::mkl::dft::descriptor<precision, domain> descriptor{ default_1d_lengths };

EXPECT_THROW(descriptor.set_value(oneapi::mkl::dft::config_param::FORWARD_DOMAIN,
Expand All @@ -405,8 +399,16 @@ static void set_readonly_values(sycl::queue& sycl_queue) {
EXPECT_THROW(descriptor.set_value(oneapi::mkl::dft::config_param::COMMIT_STATUS,
oneapi::mkl::dft::config_value::UNCOMMITTED),
oneapi::mkl::invalid_argument);
}

template <oneapi::mkl::dft::precision precision, oneapi::mkl::dft::domain domain>
static void get_commited(sycl::queue& sycl_queue) {
oneapi::mkl::dft::descriptor<precision, domain> descriptor{ default_1d_lengths };
commit_descriptor(descriptor, sycl_queue);

oneapi::mkl::dft::config_value commit_status;
descriptor.get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status);
EXPECT_EQ(commit_status, oneapi::mkl::dft::config_value::COMMITTED);
}

template <oneapi::mkl::dft::precision precision, oneapi::mkl::dft::domain domain>
Expand Down Expand Up @@ -542,11 +544,22 @@ inline void swap_out_dead_queue(sycl::queue& sycl_queue) {
}

template <oneapi::mkl::dft::precision precision, oneapi::mkl::dft::domain domain>
int test(sycl::device* dev) {
static int test_getter_setter() {
set_and_get_lengths<precision, domain>();
set_and_get_strides<precision, domain>();
set_and_get_values<precision, domain>();
get_readonly_values<precision, domain>();
set_readonly_values<precision, domain>();

return !::testing::Test::HasFailure();
}

template <oneapi::mkl::dft::precision precision, oneapi::mkl::dft::domain domain>
int test_commit(sycl::device* dev) {
sycl::queue sycl_queue(*dev, exception_handler);

if constexpr (precision == oneapi::mkl::dft::precision::DOUBLE) {
if (!sycl_queue.get_device().has(sycl::aspect::fp64)) {
if (!dev->has(sycl::aspect::fp64)) {
std::cout << "Device does not support double precision." << std::endl;
return test_skipped;
}
Expand All @@ -559,41 +572,61 @@ int test(sycl::device* dev) {
return test_skipped;
}

set_and_get_lengths<precision, domain>(sycl_queue);
set_and_get_strides<precision, domain>();
set_and_get_values<precision, domain>();
get_readonly_values<precision, domain>(sycl_queue);
set_readonly_values<precision, domain>(sycl_queue);
get_commited<precision, domain>(sycl_queue);
recommit_values<precision, domain>(sycl_queue);
change_queue_causes_wait<precision, domain>(sycl_queue);
swap_out_dead_queue<precision, domain>(sycl_queue);

return !::testing::Test::HasFailure();
}

class DescriptorTests : public ::testing::TestWithParam<sycl::device*> {};
TEST(DescriptorTests, DescriptorTestsRealSingle) {
EXPECT_TRUE((
test_getter_setter<oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::REAL>()));
}

TEST(DescriptorTests, DescriptorTestsRealDouble) {
EXPECT_TRUE((
test_getter_setter<oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::REAL>()));
}

TEST(DescriptorTests, DescriptorTestsComplexSingle) {
EXPECT_TRUE((test_getter_setter<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::COMPLEX>()));
}

TEST(DescriptorTests, DescriptorTestsComplexDouble) {
EXPECT_TRUE((test_getter_setter<oneapi::mkl::dft::precision::DOUBLE,
oneapi::mkl::dft::domain::COMPLEX>()));
}

class DescriptorCommitTests : public ::testing::TestWithParam<sycl::device*> {};

TEST_P(DescriptorTests, DescriptorTestsRealSingle) {
TEST_P(DescriptorCommitTests, DescriptorCommitTestsRealSingle) {
EXPECT_TRUEORSKIP(
(test<oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::REAL>(GetParam())));
(test_commit<oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::REAL>(
GetParam())));
}

TEST_P(DescriptorTests, DescriptorTestsRealDouble) {
TEST_P(DescriptorCommitTests, DescriptorCommitTestsRealDouble) {
EXPECT_TRUEORSKIP(
(test<oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::REAL>(GetParam())));
(test_commit<oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::REAL>(
GetParam())));
}

TEST_P(DescriptorTests, DescriptorTestsComplexSingle) {
TEST_P(DescriptorCommitTests, DescriptorCommitTestsComplexSingle) {
EXPECT_TRUEORSKIP(
(test<oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::COMPLEX>(GetParam())));
(test_commit<oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::COMPLEX>(
GetParam())));
}

TEST_P(DescriptorTests, DescriptorTestsComplexDouble) {
TEST_P(DescriptorCommitTests, DescriptorCommitTestsComplexDouble) {
EXPECT_TRUEORSKIP(
(test<oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::COMPLEX>(GetParam())));
(test_commit<oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::COMPLEX>(
GetParam())));
}

INSTANTIATE_TEST_SUITE_P(DescriptorTestSuite, DescriptorTests, testing::ValuesIn(devices),
::DeviceNamePrint());
INSTANTIATE_TEST_SUITE_P(DescriptorCommitTestSuite, DescriptorCommitTests,
testing::ValuesIn(devices), ::DeviceNamePrint());

} // anonymous namespace