Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 11 additions & 36 deletions tests/unit_tests/dft/include/compute_inplace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,7 @@ int DFT_Test<precision, domain>::test_in_place_buffer() {
{
sycl::buffer<FwdInputType, 1> inout_buf{ inout_host };

try {
oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType>(descriptor, inout_buf);
}
catch (oneapi::mkl::unimplemented& e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType>(descriptor, inout_buf);

{
auto acc_host = inout_buf.template get_host_access();
Expand Down Expand Up @@ -158,14 +152,8 @@ int DFT_Test<precision, domain>::test_in_place_buffer() {
commit_descriptor(descriptor, sycl_queue);
}

try {
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
FwdInputType>(descriptor, inout_buf);
}
catch (oneapi::mkl::unimplemented& e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
FwdInputType>(descriptor, inout_buf);
}

if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
Expand Down Expand Up @@ -231,16 +219,10 @@ int DFT_Test<precision, domain>::test_in_place_USM() {
std::copy(input.begin(), input.end(), inout.begin());
}

try {
std::vector<sycl::event> dependencies;
oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType>(descriptor, inout.data(),
dependencies)
.wait();
}
catch (oneapi::mkl::unimplemented& e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
std::vector<sycl::event> no_dependencies;
oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType>(descriptor, inout.data(),
no_dependencies)
.wait_and_throw();

if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
std::vector<FwdInputType> conjugate_even_ref =
Expand All @@ -262,17 +244,10 @@ int DFT_Test<precision, domain>::test_in_place_USM() {
commit_descriptor(descriptor, sycl_queue);
}

try {
std::vector<sycl::event> dependencies;
sycl::event done =
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
FwdInputType>(descriptor, inout.data());
done.wait();
}
catch (oneapi::mkl::unimplemented& e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
sycl::event done =
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
FwdInputType>(descriptor, inout.data(), no_dependencies);
done.wait_and_throw();

if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
for (std::size_t j = 0; j < real_first_dims; j++) {
Expand Down
39 changes: 10 additions & 29 deletions tests/unit_tests/dft/include/compute_inplace_real_real.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,10 @@ int DFT_Test<precision, domain>::test_in_place_real_real_USM() {
std::copy(input_re.begin(), input_re.end(), inout_re.begin());
std::copy(input_im.begin(), input_im.end(), inout_im.begin());

std::vector<sycl::event> dependencies;
try {
oneapi::mkl::dft::compute_forward<descriptor_t, PrecisionType>(
descriptor, inout_re.data(), inout_im.data(), dependencies)
.wait();
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
std::vector<sycl::event> no_dependencies;
oneapi::mkl::dft::compute_forward<descriptor_t, PrecisionType>(
descriptor, inout_re.data(), inout_im.data(), no_dependencies)
.wait_and_throw();

std::vector<FwdOutputType> output_data(size_total);
for (std::size_t i = 0; i < output_data.size(); ++i) {
Expand All @@ -73,8 +67,8 @@ int DFT_Test<precision, domain>::test_in_place_real_real_USM() {

oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
PrecisionType>(descriptor, inout_re.data(),
inout_im.data(), dependencies)
.wait();
inout_im.data(), no_dependencies)
.wait_and_throw();

for (std::size_t i = 0; i < output_data.size(); ++i) {
output_data[i] = { inout_re[i], inout_im[i] };
Expand Down Expand Up @@ -123,14 +117,8 @@ int DFT_Test<precision, domain>::test_in_place_real_real_buffer() {
sycl::buffer<PrecisionType, 1> inout_im_buf{ host_inout_im.data(),
sycl::range<1>(size_total) };

try {
oneapi::mkl::dft::compute_forward<descriptor_t, PrecisionType>(descriptor, inout_re_buf,
inout_im_buf);
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_forward<descriptor_t, PrecisionType>(descriptor, inout_re_buf,
inout_im_buf);

{
auto acc_inout_re = inout_re_buf.template get_host_access();
Expand All @@ -144,15 +132,8 @@ int DFT_Test<precision, domain>::test_in_place_real_real_buffer() {
std::cout));
}

try {
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
PrecisionType>(descriptor, inout_re_buf,
inout_im_buf);
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
PrecisionType>(descriptor, inout_re_buf, inout_im_buf);

{
auto acc_inout_re = inout_re_buf.template get_host_access();
Expand Down
48 changes: 12 additions & 36 deletions tests/unit_tests/dft/include/compute_out_of_place.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,8 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
sycl::buffer<FwdOutputType, 1> bwd_buf{ sycl::range<1>(
cast_unsigned(backward_distance * batches)) };

try {
oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType, FwdOutputType>(
descriptor, fwd_buf, bwd_buf);
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType, FwdOutputType>(
descriptor, fwd_buf, bwd_buf);

{
auto acc_bwd = bwd_buf.template get_host_access();
Expand Down Expand Up @@ -99,15 +93,9 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
commit_descriptor(descriptor, sycl_queue);
}

try {
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
FwdOutputType, FwdInputType>(descriptor, bwd_buf,
fwd_buf);
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
FwdOutputType, FwdInputType>(descriptor, bwd_buf,
fwd_buf);
}

EXPECT_TRUE(check_equal_vector(fwd_data.data(), input.data(), input.size(), abs_error_margin,
Expand Down Expand Up @@ -147,15 +135,9 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
std::vector<FwdOutputType, decltype(ua_output)> bwd(cast_unsigned(backward_distance * batches),
ua_output);

try {
oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType, FwdOutputType>(
descriptor, fwd.data(), bwd.data(), no_dependencies)
.wait();
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType, FwdOutputType>(
descriptor, fwd.data(), bwd.data(), no_dependencies)
.wait_and_throw();

{
auto bwd_iter = bwd.begin();
Expand All @@ -181,16 +163,10 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
commit_descriptor(descriptor, sycl_queue);
}

try {
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
FwdOutputType, FwdInputType>(descriptor, bwd.data(),
fwd.data(), no_dependencies)
.wait();
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>, FwdOutputType,
FwdInputType>(descriptor, bwd.data(), fwd.data(),
no_dependencies)
.wait_and_throw();

EXPECT_TRUE(check_equal_vector(fwd.data(), input.data(), input.size(), abs_error_margin,
rel_error_margin, std::cout));
Expand Down
51 changes: 14 additions & 37 deletions tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,34 +61,23 @@ int DFT_Test<precision, domain>::test_out_of_place_real_real_USM() {
std::copy(input_re.begin(), input_re.end(), in_re.begin());
std::copy(input_im.begin(), input_im.end(), in_im.begin());

std::vector<sycl::event> dependencies;
std::vector<sycl::event> no_dependencies;

try {
oneapi::mkl::dft::compute_forward<descriptor_t, PrecisionType, PrecisionType>(
descriptor, in_re.data(), in_im.data(), out_re.data(), out_im.data(), dependencies)
.wait();
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_forward<descriptor_t, PrecisionType, PrecisionType>(
descriptor, in_re.data(), in_im.data(), out_re.data(), out_im.data(), no_dependencies)
.wait_and_throw();
std::vector<FwdOutputType> output_data(size_total);
for (std::size_t i = 0; i < output_data.size(); ++i) {
output_data[i] = { out_re[i], out_im[i] };
}
EXPECT_TRUE(check_equal_vector(output_data.data(), out_host_ref.data(), output_data.size(),
abs_error_margin, rel_error_margin, std::cout));

try {
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
PrecisionType, PrecisionType>(
descriptor, out_re.data(), out_im.data(), out_back_re.data(), out_back_im.data())
.wait();
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
PrecisionType, PrecisionType>(
descriptor, out_re.data(), out_im.data(), out_back_re.data(), out_back_im.data(),
no_dependencies)
.wait_and_throw();

for (std::size_t i = 0; i < output_data.size(); ++i) {
output_data[i] = { out_back_re[i], out_back_im[i] };
Expand Down Expand Up @@ -134,14 +123,8 @@ int DFT_Test<precision, domain>::test_out_of_place_real_real_buffer() {
sycl::buffer<PrecisionType, 1> out_back_dev_re{ sycl::range<1>(size_total) };
sycl::buffer<PrecisionType, 1> out_back_dev_im{ sycl::range<1>(size_total) };

try {
oneapi::mkl::dft::compute_forward<descriptor_t, PrecisionType, PrecisionType>(
descriptor, in_dev_re, in_dev_im, out_dev_re, out_dev_im);
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_forward<descriptor_t, PrecisionType, PrecisionType>(
descriptor, in_dev_re, in_dev_im, out_dev_re, out_dev_im);

{
auto acc_out_re = out_dev_re.template get_host_access();
Expand All @@ -155,15 +138,9 @@ int DFT_Test<precision, domain>::test_out_of_place_real_real_buffer() {
std::cout));
}

try {
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
PrecisionType, PrecisionType>(
descriptor, out_dev_re, out_dev_im, out_back_dev_re, out_back_dev_im);
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
PrecisionType, PrecisionType>(
descriptor, out_dev_re, out_dev_im, out_back_dev_re, out_back_dev_im);

{
auto acc_back_out_re = out_back_dev_re.template get_host_access();
Expand Down
18 changes: 12 additions & 6 deletions tests/unit_tests/dft/source/compute_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,18 @@ class ComputeTests_real_real_out_of_place

#define INSTANTIATE_TEST(PRECISION, DOMAIN, PLACE, LAYOUT, STORAGE) \
TEST_P(ComputeTests##_##LAYOUT##PLACE, DOMAIN##_##PRECISION##_##PLACE##_##LAYOUT##STORAGE) { \
auto test = \
DFT_Test<oneapi::mkl::dft::precision::PRECISION, oneapi::mkl::dft::domain::DOMAIN>{ \
std::get<0>(GetParam()), std::get<1>(GetParam()).sizes, \
std::get<1>(GetParam()).batches \
}; \
EXPECT_TRUEORSKIP(test.test_##PLACE##_##LAYOUT##STORAGE()); \
try { \
auto test = \
DFT_Test<oneapi::mkl::dft::precision::PRECISION, \
oneapi::mkl::dft::domain::DOMAIN>{ std::get<0>(GetParam()), \
std::get<1>(GetParam()).sizes, \
std::get<1>(GetParam()).batches }; \
EXPECT_TRUEORSKIP(test.test_##PLACE##_##LAYOUT##STORAGE()); \
} \
catch (oneapi::mkl::unimplemented & e) { \
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; \
GTEST_SKIP(); \
} \
}

#define INSTANTIATE_TEST_DIMENSIONS_PRECISION_DOMAIN(PLACE, LAYOUT, STORAGE) \
Expand Down