Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 203 additions & 65 deletions tests/unit_tests/dft/include/compute_inplace.hpp

Large diffs are not rendered by default.

42 changes: 32 additions & 10 deletions tests/unit_tests/dft/include/compute_inplace_real_real.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,23 @@ int DFT_Test<precision, domain>::test_in_place_real_real_USM() {
}

try {
descriptor_t descriptor{ size };
descriptor_t descriptor{ sizes };

descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::INPLACE);
descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE,
oneapi::mkl::dft::config_value::REAL_REAL);
descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE,
static_cast<std::int64_t>(forward_elements));
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE,
static_cast<std::int64_t>(forward_elements));
commit_descriptor(descriptor, sycl_queue);

auto ua_input = usm_allocator_t<PrecisionType>(cxt, *dev);

std::vector<PrecisionType, decltype(ua_input)> inout_re(size, ua_input);
std::vector<PrecisionType, decltype(ua_input)> inout_im(size, ua_input);
std::vector<PrecisionType, decltype(ua_input)> inout_re(size_total, ua_input);
std::vector<PrecisionType, decltype(ua_input)> inout_im(size_total, ua_input);
std::copy(input_re.begin(), input_re.end(), inout_re.begin());
std::copy(input_im.begin(), input_im.end(), inout_im.begin());

Expand All @@ -51,13 +56,19 @@ int DFT_Test<precision, domain>::test_in_place_real_real_USM() {
descriptor, inout_re.data(), inout_im.data(), dependencies);
done.wait();

descriptor_t descriptor_back{ size };
descriptor_t descriptor_back{ sizes };

descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::INPLACE);
descriptor_back.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE,
oneapi::mkl::dft::config_value::REAL_REAL);
descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size));
descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE,
(1.0 / forward_elements));
descriptor_back.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor_back.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE,
static_cast<std::int64_t>(forward_elements));
descriptor_back.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE,
static_cast<std::int64_t>(forward_elements));
commit_descriptor(descriptor_back, sycl_queue);

done =
Expand Down Expand Up @@ -86,27 +97,38 @@ int DFT_Test<precision, domain>::test_in_place_real_real_buffer() {
}

try {
descriptor_t descriptor{ size };
descriptor_t descriptor{ sizes };

descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::INPLACE);
descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE,
oneapi::mkl::dft::config_value::REAL_REAL);
descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE,
static_cast<std::int64_t>(forward_elements));
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE,
static_cast<std::int64_t>(forward_elements));
commit_descriptor(descriptor, sycl_queue);

sycl::buffer<PrecisionType, 1> inout_re_buf{ input_re.data(), sycl::range<1>(size) };
sycl::buffer<PrecisionType, 1> inout_im_buf{ input_im.data(), sycl::range<1>(size) };
sycl::buffer<PrecisionType, 1> inout_re_buf{ input_re.data(), sycl::range<1>(size_total) };
sycl::buffer<PrecisionType, 1> inout_im_buf{ input_im.data(), sycl::range<1>(size_total) };

oneapi::mkl::dft::compute_forward<descriptor_t, PrecisionType>(descriptor, inout_re_buf,
inout_im_buf);

descriptor_t descriptor_back{ size };
descriptor_t descriptor_back{ sizes };

descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::INPLACE);
descriptor_back.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE,
oneapi::mkl::dft::config_value::REAL_REAL);
descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size));
descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE,
(1.0 / forward_elements));
descriptor_back.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor_back.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE,
static_cast<std::int64_t>(forward_elements));
descriptor_back.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE,
static_cast<std::int64_t>(forward_elements));
commit_descriptor(descriptor_back, sycl_queue);

oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor_back)>,
Expand Down
101 changes: 84 additions & 17 deletions tests/unit_tests/dft/include/compute_out_of_place.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,60 @@
#define ONEMKL_COMPUTE_OUT_OF_PLACE_HPP

#include "compute_tester.hpp"
#include <numeric>

template <oneapi::mkl::dft::domain domain>
std::int64_t get_backward_row_size(const std::vector<std::int64_t> &sizes) noexcept {
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
return sizes.back() / 2 + 1;
}
else {
return sizes.back();
}
}

/* Note: There is no implementation for Domain Real */
template <oneapi::mkl::dft::precision precision, oneapi::mkl::dft::domain domain>
int DFT_Test<precision, domain>::test_out_of_place_buffer() {
if (!init(MemoryAccessModel::buffer)) {
return test_skipped;
}

const size_t bwd_size = domain == oneapi::mkl::dft::domain::REAL ? (size / 2) + 1 : size;
const auto backward_distance = std::accumulate(
sizes.begin(), sizes.end() - 1, get_backward_row_size<domain>(sizes), std::multiplies<>());

descriptor_t descriptor{ size };
descriptor_t descriptor{ sizes };
descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance);
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
const auto complex_strides = get_conjugate_even_complex_strides(sizes);
descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES,
complex_strides.data());
}
commit_descriptor(descriptor, sycl_queue);

descriptor_t descriptor_back{ size };
descriptor_t descriptor_back{ sizes };
descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size));
descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE,
(1.0 / forward_elements));
descriptor_back.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor_back.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements);
descriptor_back.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance);
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
const auto complex_strides = get_conjugate_even_complex_strides(sizes);
descriptor_back.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES,
complex_strides.data());
}
commit_descriptor(descriptor_back, sycl_queue);

std::vector<FwdInputType> fwd_data(input);
std::vector<FwdOutputType> bwd_data(bwd_size, 0);

{
sycl::buffer<FwdInputType, 1> fwd_buf{ fwd_data };
sycl::buffer<FwdOutputType, 1> bwd_buf{ bwd_data };
sycl::buffer<FwdOutputType, 1> bwd_buf{ sycl::range<1>(backward_distance * batches) };

try {
oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType, FwdOutputType>(
Expand All @@ -60,9 +87,18 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {

{
auto acc_bwd = bwd_buf.template get_host_access();
EXPECT_TRUE(check_equal_vector(acc_bwd.get_pointer(), out_host_ref.data(),
bwd_data.size(), abs_error_margin, rel_error_margin,
std::cout));
auto bwd_ptr = acc_bwd.get_pointer();
auto ref_iter = out_host_ref.begin();
const auto ref_row_stride = sizes.back();
const auto backward_row_stride = get_backward_row_size<domain>(sizes);
const auto backward_row_elements = get_backward_row_size<domain>(sizes);

while (ref_iter < out_host_ref.end()) {
EXPECT_TRUE(check_equal_vector(bwd_ptr, ref_iter, backward_row_elements,
abs_error_margin, rel_error_margin, std::cout));
bwd_ptr += backward_row_stride;
ref_iter += ref_row_stride;
}
}

try {
Expand All @@ -88,24 +124,42 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
}
const std::vector<sycl::event> no_dependencies;

const size_t bwd_size = domain == oneapi::mkl::dft::domain::REAL ? (size / 2) + 1 : size;
const auto backward_distance = std::accumulate(
sizes.begin(), sizes.end() - 1, get_backward_row_size<domain>(sizes), std::multiplies<>());

descriptor_t descriptor{ size };
descriptor_t descriptor{ sizes };
descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance);
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
const auto complex_strides = get_conjugate_even_complex_strides(sizes);
descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES,
complex_strides.data());
}
commit_descriptor(descriptor, sycl_queue);

descriptor_t descriptor_back{ size };
descriptor_t descriptor_back{ sizes };
descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size));
descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE,
(1.0 / forward_elements));
descriptor_back.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor_back.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements);
descriptor_back.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance);
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
const auto complex_strides = get_conjugate_even_complex_strides(sizes);
descriptor_back.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES,
complex_strides.data());
}
commit_descriptor(descriptor_back, sycl_queue);

auto ua_input = usm_allocator_t<FwdInputType>(cxt, *dev);
auto ua_output = usm_allocator_t<FwdOutputType>(cxt, *dev);

std::vector<FwdInputType, decltype(ua_input)> fwd(input.begin(), input.end(), ua_input);
std::vector<FwdOutputType, decltype(ua_output)> bwd(bwd_size, ua_output);
std::vector<FwdOutputType, decltype(ua_output)> bwd(backward_distance * batches, ua_output);

try {
oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType, FwdOutputType>(
Expand All @@ -117,8 +171,21 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
return test_skipped;
}

EXPECT_TRUE(check_equal_vector(bwd.data(), out_host_ref.data(), bwd.size(), abs_error_margin,
rel_error_margin, std::cout));
{
auto bwd_iter = bwd.begin();
auto ref_iter = out_host_ref.begin();

const auto ref_row_stride = sizes.back();
const auto backward_row_stride = get_backward_row_size<domain>(sizes);
const auto backward_row_elements = get_backward_row_size<domain>(sizes);

while (ref_iter < out_host_ref.end()) {
EXPECT_TRUE(check_equal_vector(bwd_iter, ref_iter, backward_row_elements,
abs_error_margin, rel_error_margin, std::cout));
bwd_iter += backward_row_stride;
ref_iter += ref_row_stride;
}
}

try {
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor_back)>,
Expand Down
Loading