Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DFT][rocFFt] Address rocFFT failing tests #563

Merged
merged 7 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/dft/backends/rocfft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ find_package(HIP REQUIRED)
# Require the minimum rocFFT version matching with ROCm 5.4.3.
find_package(rocfft 1.0.21 REQUIRED)

if (${rocfft_VERSION_MAJOR} EQUAL "1" AND ${rocfft_VERSION_MINOR} EQUAL "0"
AND ((${rocfft_VERSION_PATCH} GREATER "22")
AND (${rocfft_VERSION_PATCH} LESS "31") ))
message(WARNING "Due to a bug in rocFFT some tests fail with the version in\
use. If possible use a version greater of 1.0.30 or less of 1.0.23.
Current rocFFT version ${rocfft_VERSION}")
endif()
Rbiessy marked this conversation as resolved.
Show resolved Hide resolved

target_link_libraries(${LIB_OBJ} PRIVATE hip::host roc::rocfft)

# Allow to compile for different ROCm versions. See the README for the supported
Expand All @@ -62,6 +70,7 @@ find_path(
NO_DEFAULT_PATH
REQUIRED
)

Rbiessy marked this conversation as resolved.
Show resolved Hide resolved
target_include_directories(${LIB_OBJ} PRIVATE ${rocfft_EXTRA_INCLUDE_DIR})

target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL)
Expand Down
48 changes: 38 additions & 10 deletions src/dft/backends/rocfft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "rocfft_handle.hpp"

#include <rocfft.h>
#include <rocfft-version.h>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This header is now needed to check rocFFT version if necessary

#include <hip/hip_runtime_api.h>

namespace oneapi::mkl::dft::rocfft {
Expand Down Expand Up @@ -259,12 +260,37 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
std::reverse(stride_vecs.vec_b.begin(), stride_vecs.vec_b.end());
stride_vecs.vec_b.pop_back(); // Offset is not included.

rocfft_plan_description plan_desc;
if (rocfft_plan_description_create(&plan_desc) != rocfft_status_success) {
// This workaround is needed due to a confirmed issue in rocFFT from version
// 1.0.23 to 1.0.30. Those rocFFT version correspond to rocm version from
// 5.6.0 to 6.3.0.
Rbiessy marked this conversation as resolved.
Show resolved Hide resolved
// Link to rocFFT issue: https://github.com/ROCm/rocFFT/issues/507
if constexpr (rocfft_version_major == 1 && rocfft_version_minor == 0 &&
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This if-statement could be converted to #if directive. I preferred to use this style instead, but I am open to change it

(rocfft_version_patch > 22 && rocfft_version_patch < 31)) {
if (dom == dft::domain::COMPLEX && dimensions > 2) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change should allow in_place routine to run

Suggested change
if (dom == dft::domain::COMPLEX && dimensions > 2) {
if (dom == dft::domain::COMPLEX && config_values.placement == dft::config_value::NOT_INPLACE && dimensions > 2) {

auto stride_checker = [&](const auto& a, const auto& b) {
for (ulong i = 0; i < dimensions; ++i) {
if (a[i] != b[i])
return false;
}
return true;
};
std::printf("hello\n");
if (!stride_checker(stride_vecs.vec_a, stride_vecs.vec_b))
Rbiessy marked this conversation as resolved.
Show resolved Hide resolved
throw oneapi::mkl::unimplemented(
"DFT", func,
"due to a bug in rocfft version in use, it requires fwd and bwd stride to be the same for COMPLEX out_of_place computations");
Rbiessy marked this conversation as resolved.
Show resolved Hide resolved
}
}

rocfft_plan_description plan_desc_fwd, plan_desc_bwd; // Can't reuse with ROCm 6 due to bug.
if (rocfft_plan_description_create(&plan_desc_fwd) != rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
"Failed to create plan description.");
}
if (rocfft_plan_description_create(&plan_desc_bwd) != rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
"Failed to create plan description.");
}

// plan_description can be destroyed afted plan_create
auto description_destroy = [](rocfft_plan_description p) {
if (rocfft_plan_description_destroy(p) != rocfft_status_success) {
Expand All @@ -273,7 +299,9 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
}
};
std::unique_ptr<rocfft_plan_description_t, decltype(description_destroy)>
description_destroyer(plan_desc, description_destroy);
description_destroyer_fwd(plan_desc_fwd, description_destroy);
std::unique_ptr<rocfft_plan_description_t, decltype(description_destroy)>
description_destroyer_bwd(plan_desc_bwd, description_destroy);

std::array<std::size_t, 3> stride_a_indices{ 0, 1, 2 };
std::sort(&stride_a_indices[0], &stride_a_indices[dimensions],
Expand Down Expand Up @@ -324,7 +352,7 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {

if (valid_forward) {
auto res =
rocfft_plan_description_set_data_layout(plan_desc, fwd_array_ty, bwd_array_ty,
rocfft_plan_description_set_data_layout(plan_desc_fwd, fwd_array_ty, bwd_array_ty,
nullptr, // in offsets
nullptr, // out offsets
dimensions,
Expand All @@ -339,15 +367,15 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
"Failed to set forward data layout.");
}

if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.fwd_scale) !=
if (rocfft_plan_description_set_scale_factor(plan_desc_fwd, config_values.fwd_scale) !=
rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
"Failed to set forward scale factor.");
}

rocfft_plan fwd_plan;
res = rocfft_plan_create(&fwd_plan, placement, fwd_type, precision, dimensions,
lengths.data(), number_of_transforms, plan_desc);
lengths.data(), number_of_transforms, plan_desc_fwd);

if (res != rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
Expand Down Expand Up @@ -380,7 +408,7 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {

if (valid_backward) {
auto res =
rocfft_plan_description_set_data_layout(plan_desc, bwd_array_ty, fwd_array_ty,
rocfft_plan_description_set_data_layout(plan_desc_bwd, bwd_array_ty, fwd_array_ty,
nullptr, // in offsets
nullptr, // out offsets
dimensions,
Expand All @@ -395,15 +423,15 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
"Failed to set backward data layout.");
}

if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.bwd_scale) !=
if (rocfft_plan_description_set_scale_factor(plan_desc_bwd, config_values.bwd_scale) !=
rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
"Failed to set backward scale factor.");
}

rocfft_plan bwd_plan;
res = rocfft_plan_create(&bwd_plan, placement, bwd_type, precision, dimensions,
lengths.data(), number_of_transforms, plan_desc);
lengths.data(), number_of_transforms, plan_desc_bwd);
if (res != rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
"Failed to create backward rocFFT plan.");
Expand Down
8 changes: 2 additions & 6 deletions src/dft/backends/rocfft/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@

#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp"

namespace oneapi {
namespace mkl {
namespace dft {
namespace oneapi::mkl::dft::detail {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the namespace in this file is not consistent in all backends.
With this change rocFFT, oneMKL CPU, oneMKL GPU have the oneapi::mkl::dft::detail namespace but cuFFT and portFFT have oneapi::mkl::dft namespace.
I know this is unrelated to the original issue, but since you changed one of them, could you please make them consistent across all backends?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your review @lhuot
No problem, I updated them in aed1f63


template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::rocfft> selector) {
Expand All @@ -46,6 +44,4 @@ template void descriptor<precision::DOUBLE, domain::COMPLEX>::commit(
template void descriptor<precision::DOUBLE, domain::REAL>::commit(
backend_selector<backend::rocfft>);

} //namespace dft
} //namespace mkl
} //namespace oneapi
} //namespace oneapi::mkl::dft::detail
Loading