Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement gesv_batch via gesv call #1877

Merged
merged 50 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
4c7c8c2
Init work
vlad-perevezentsev Jun 6, 2024
8e2cb23
First working version with transpose and C contig
vlad-perevezentsev Jun 7, 2024
67fa435
Second working version with moveaxis, transpose and F contig
vlad-perevezentsev Jun 7, 2024
4f5abec
Add more shape checks
vlad-perevezentsev Jun 11, 2024
0cb2808
Pass sycl::queue by reference for gesv/gesv_batch
vlad-perevezentsev Jun 11, 2024
bfa37d4
qwe
vlad-perevezentsev Jun 11, 2024
4a44292
Update _batched_solve implementation
vlad-perevezentsev Jun 12, 2024
df4774e
Remove old impl in _batched_solve
vlad-perevezentsev Jun 12, 2024
8dbe3c4
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jun 12, 2024
8fb2af3
Use py::gil_scoped_release before gesv call
vlad-perevezentsev Jun 12, 2024
ddcf9fe
Remove junk files
vlad-perevezentsev Jun 12, 2024
262794f
Move gesv_batch to gesv_batch.cpp
vlad-perevezentsev Jun 13, 2024
3a7b8ca
Improve gesv_batch with independent linear streams
vlad-perevezentsev Jun 13, 2024
2016a8c
Extend checks for gesv/gesv_batch
vlad-perevezentsev Jun 13, 2024
2c42290
Update comment
vlad-perevezentsev Jun 13, 2024
e030da8
junk files
vlad-perevezentsev Jun 14, 2024
3f99ae5
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jun 17, 2024
a0a683b
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 11, 2024
5a48f33
Add common_gesv_checks
vlad-perevezentsev Jul 12, 2024
924fee7
Release GIL in gesv_batch_impl
vlad-perevezentsev Jul 12, 2024
2b15e6c
Remove junk file
vlad-perevezentsev Jul 12, 2024
5a1cab6
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 12, 2024
b5c3062
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 15, 2024
afca803
Remove junk files
vlad-perevezentsev Jul 16, 2024
ed99888
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 16, 2024
0c97aff
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 19, 2024
1b275ea
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 26, 2024
e5b53a1
Remove host_task_events from gesv
vlad-perevezentsev Jul 26, 2024
d5adbd6
Use check_zeros_shape in gesv and gesv_batch
vlad-perevezentsev Jul 26, 2024
5b2780c
Add additional checks for gesv_impl
vlad-perevezentsev Jul 26, 2024
d4547d4
Move alloc_scratchpad to common_helpers.hpp
vlad-perevezentsev Jul 26, 2024
6759164
Use helper::alloc_scratchpad in gesv_batch_impl
vlad-perevezentsev Jul 26, 2024
f37ec43
Remove current_scratch_gesv check
vlad-perevezentsev Jul 26, 2024
adc17ba
Remove lda, ldb pass to gesv_batch_impl, gesv_impl
vlad-perevezentsev Jul 26, 2024
77ba0e2
Use const and constexpr in gesv/gesv_batch
vlad-perevezentsev Jul 26, 2024
9bf94b5
Applied review comments
vlad-perevezentsev Jul 29, 2024
b81893c
Use dpnp.reshape in _batched_solve
vlad-perevezentsev Jul 29, 2024
f8d68ef
Implement alloc_ipiv in common_helpers.hpp
vlad-perevezentsev Jul 29, 2024
fc6c7fa
Add gesv_common_utils.hpp
vlad-perevezentsev Jul 29, 2024
75079d2
Implement handle_lapack_exc function
vlad-perevezentsev Jul 29, 2024
6e82632
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 29, 2024
7e0f384
Use try/catch for scratchpad/ipiv allocation
vlad-perevezentsev Jul 29, 2024
f5ee368
Update alloc_scratchpad/alloc_ipiv
vlad-perevezentsev Jul 29, 2024
eb8c3a0
gesv_scratchpad_size can be 0
vlad-perevezentsev Jul 30, 2024
3c8cda6
Implement help functions alloc_ipiv/alloc_scratchpad
vlad-perevezentsev Jul 30, 2024
3f4d672
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 30, 2024
e56e07e
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Aug 2, 2024
629b97a
Reuse alloc_scratchpad/ipiv in batch versions
vlad-perevezentsev Aug 2, 2024
a9cc253
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Aug 6, 2024
3786ca2
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Aug 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dpnp/backend/extensions/lapack/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/geqrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/geqrf_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gesv_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gesvd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
Expand Down
89 changes: 88 additions & 1 deletion dpnp/backend/extensions/lapack/common_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
//*****************************************************************************

#pragma once
#include <pybind11/pybind11.h>
#include <sycl/sycl.hpp>

#include <complex>
#include <cstring>
#include <pybind11/pybind11.h>
#include <stdexcept>

namespace dpnp::extensions::lapack::helper
Expand Down Expand Up @@ -63,4 +65,89 @@ inline bool check_zeros_shape(int ndim, const py::ssize_t *shape)
}
return src_nelems == 0;
}

/**
 * @brief Allocate device memory for the pivot indices.
 *
 * @param n      Number of std::int64_t elements to allocate.
 * @param exec_q SYCL queue passed to sycl::malloc_device for the allocation.
 *
 * @return Non-null pointer to the allocated memory. This function does not
 *         free it; the caller is expected to release it with sycl::free.
 *
 * @throws std::runtime_error if sycl::malloc_device returns a null pointer
 *         or if a sycl::exception is raised during the allocation.
 */
inline std::int64_t *alloc_ipiv(const std::int64_t n, sycl::queue &exec_q)
{
    std::int64_t *ipiv = nullptr;

    try {
        ipiv = sycl::malloc_device<std::int64_t>(n, exec_q);
    } catch (sycl::exception const &e) {
        // Nothing to release here: if malloc_device throws, the assignment
        // above never completes and ipiv is still nullptr.
        throw std::runtime_error(
            std::string(
                "Unexpected SYCL exception caught during ipiv allocation: ") +
            e.what());
    }

    // The null check lives outside the try block because the handler above
    // only deals with sycl::exception and would never see this error anyway.
    if (!ipiv) {
        throw std::runtime_error("Device allocation for ipiv failed");
    }

    return ipiv;
}

/**
 * @brief Allocate one pivot-index buffer shared by all linear streams of a
 *        batch implementation.
 *
 * The requested element count (n_linear_streams * n) is rounded up to a
 * multiple of the padding before the allocation is delegated to alloc_ipiv.
 *
 * @tparam T Type whose size determines the padding (256 / sizeof(T)).
 * @param n                Number of pivot indices per linear stream.
 * @param n_linear_streams Number of linear streams.
 * @param exec_q           SYCL queue forwarded to alloc_ipiv.
 *
 * @return Pointer returned by alloc_ipiv.
 */
template <typename T>
inline std::int64_t *alloc_ipiv_batch(const std::int64_t n,
                                      std::int64_t n_linear_streams,
                                      sycl::queue &exec_q)
{
    // Number of T elements that fit in 256 bytes; used as the rounding unit.
    // NOTE(review): the buffer holds std::int64_t values while the padding
    // is derived from sizeof(T) - confirm this is intended.
    const std::int64_t padding = 256 / sizeof(T);

    // Element count for all streams, before and after rounding up.
    const std::int64_t requested_elems = n_linear_streams * n;
    size_t padded_elems = round_up_mult(requested_elems, padding);

    return alloc_ipiv(padded_elems, exec_q);
}

/**
 * @brief Allocate device memory for a scratchpad.
 *
 * @tparam T Element type of the scratchpad.
 * @param scratchpad_size Number of T elements to allocate. When it is not
 *                        positive, nothing is allocated.
 * @param exec_q          SYCL queue passed to sycl::malloc_device.
 *
 * @return Pointer to the allocated memory, or nullptr when scratchpad_size
 *         is not positive. This function does not free the memory; the
 *         caller is expected to release it with sycl::free.
 *
 * @throws std::runtime_error if sycl::malloc_device returns a null pointer
 *         for a positive size or if a sycl::exception is raised during the
 *         allocation.
 */
template <typename T>
inline T *alloc_scratchpad(std::int64_t scratchpad_size, sycl::queue &exec_q)
{
    // A non-positive size requires no device memory at all.
    if (scratchpad_size <= 0) {
        return nullptr;
    }

    T *scratchpad = nullptr;

    try {
        scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
    } catch (sycl::exception const &e) {
        // Nothing to release here: if malloc_device throws, the assignment
        // above never completes and scratchpad is still nullptr.
        throw std::runtime_error(std::string("Unexpected SYCL exception caught "
                                             "during scratchpad allocation: ") +
                                 e.what());
    }

    // The null check lives outside the try block because the handler above
    // only deals with sycl::exception and would never see this error anyway.
    if (!scratchpad) {
        throw std::runtime_error("Device allocation for scratchpad failed");
    }

    return scratchpad;
}

// Allocate the total scratchpad memory with proper alignment for batch
// implementations
template <typename T>
inline T *alloc_scratchpad_batch(std::int64_t scratchpad_size,
std::int64_t n_linear_streams,
sycl::queue &exec_q)
{
// Get padding size to ensure memory allocations are aligned to 256 bytes
// for better performance
const std::int64_t padding = 256 / sizeof(T);

// Calculate the total scratchpad memory size needed for all linear
// streams with proper alignment
const size_t alloc_scratch_size =
round_up_mult(n_linear_streams * scratchpad_size, padding);

return alloc_scratchpad<T>(alloc_scratch_size, exec_q);
}
} // namespace dpnp::extensions::lapack::helper
30 changes: 0 additions & 30 deletions dpnp/backend/extensions/lapack/evd_batch_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,34 +119,4 @@ std::pair<sycl::event, sycl::event>

return std::make_pair(ht_ev, evd_batch_ev);
}

/**
 * @brief Allocate one scratchpad buffer covering all linear streams.
 *
 * @tparam T Element type of the scratchpad; also determines the padding
 *           (256 / sizeof(T) elements, i.e. 256 bytes worth of T).
 * @param scratchpad_size  Scratchpad size, in elements, per linear stream;
 *                         must be greater than zero.
 * @param n_linear_streams Number of linear streams.
 * @param exec_q           SYCL queue passed to sycl::malloc_device.
 *
 * @return Non-null pointer to the allocated memory.
 *
 * @throws std::runtime_error if scratchpad_size is not positive or if the
 *         allocation returns a null pointer.
 */
template <typename T>
inline T *alloc_scratchpad(std::int64_t scratchpad_size,
                           std::int64_t n_linear_streams,
                           sycl::queue &exec_q)
{
    // Reject invalid sizes before doing any other work.
    if (scratchpad_size <= 0) {
        throw std::runtime_error(
            "Invalid scratchpad size: must be greater than zero."
            " Calculated scratchpad size: " +
            std::to_string(scratchpad_size));
    }

    // Number of T elements that fit in 256 bytes; used as the rounding unit.
    const std::int64_t padding = 256 / sizeof(T);

    // Total element count for every linear stream, rounded up to a multiple
    // of the padding.
    const size_t total_elems =
        helper::round_up_mult(n_linear_streams * scratchpad_size, padding);

    T *buffer = sycl::malloc_device<T>(total_elems, exec_q);
    if (buffer) {
        return buffer;
    }

    throw std::runtime_error("Device allocation for scratchpad failed");
}
} // namespace dpnp::extensions::lapack::evd
Loading
Loading