Implement of dpnp.linalg.slogdet() #1607

Merged
merged 66 commits into from
Jan 12, 2024
Changes from 60 commits
Commits
308b8cf
Add a new impl of dpnp.linalg._lu_factor
vlad-perevezentsev Oct 23, 2023
1119341
Get dev_info_array after calling getrf
vlad-perevezentsev Oct 25, 2023
71c2e96
Add an extra dev_info return to _lu_factor
vlad-perevezentsev Oct 25, 2023
b24c8c5
qwe
vlad-perevezentsev Oct 26, 2023
500df36
Add a logic for a.ndim > 2 in _lu_factor
vlad-perevezentsev Oct 31, 2023
b35e282
Add an implementation of dpnp.linalg.slogdet
vlad-perevezentsev Oct 31, 2023
301fc2a
Add a new test_norms.py file in cupy tests
vlad-perevezentsev Oct 31, 2023
9aba014
Expand test scope in public CI
vlad-perevezentsev Oct 31, 2023
19d909c
Merge master into impl_lu_factor
vlad-perevezentsev Oct 31, 2023
2419eba
Merge master into impl_lu_factor
vlad-perevezentsev Nov 22, 2023
a8789e4
A small update _lu_factor func
vlad-perevezentsev Nov 22, 2023
59642f6
Remove w/a for dpnp.count_nonzero in slogdet
vlad-perevezentsev Nov 23, 2023
db45555
getrf returns pair of events and uses dpctl.utils.keep_args_alive
vlad-perevezentsev Nov 23, 2023
0350d86
Update dpnp.linalg.det using slogdet
vlad-perevezentsev Nov 27, 2023
6c20deb
Add new cupy tests for dpnp.linalg.det
vlad-perevezentsev Nov 27, 2023
b17802d
Merge master into impl_lu_factor
vlad-perevezentsev Nov 27, 2023
3359c05
Add ipiv_vecs and dev_info_vecs in _lu_factor for the batch case
vlad-perevezentsev Nov 27, 2023
1a91385
Skip test_det on CPU due to bug in MKL
vlad-perevezentsev Nov 27, 2023
b860790
Small update of cupy tests in test_norms.py
vlad-perevezentsev Nov 28, 2023
37d476b
Add support of complex dtype for dpnp.diagonal and update test_diagonal
vlad-perevezentsev Nov 29, 2023
2e8b4fe
lu_factor func returns the result of LU decomposition as c-contiguous…
vlad-perevezentsev Nov 29, 2023
3849d92
Add getrf_batch MKL extension
vlad-perevezentsev Nov 29, 2023
cd282a3
Update docstring for slogdet
vlad-perevezentsev Nov 30, 2023
771605f
Add more tests
vlad-perevezentsev Nov 30, 2023
35de575
Merge master into impl_lu_factor
vlad-perevezentsev Nov 30, 2023
cde627d
Remove accidentally added file
vlad-perevezentsev Nov 30, 2023
694870b
Modify sign parameter calculation
vlad-perevezentsev Nov 30, 2023
d85d00d
Remove the old backend implementation of dpnp_det
vlad-perevezentsev Nov 30, 2023
c4b9992
qwe
vlad-perevezentsev Nov 30, 2023
2ad0bc4
Merge master into impl_lu_factor
vlad-perevezentsev Dec 15, 2023
78b98e7
Keep lexicographical order
vlad-perevezentsev Dec 15, 2023
46a9965
Add dpnp_slogdet to dpnp_utils_linalg
vlad-perevezentsev Dec 15, 2023
da16383
Move _lu_factor above
vlad-perevezentsev Dec 15, 2023
4e0e183
A minor update
vlad-perevezentsev Dec 15, 2023
80d8188
A minor changes for _lu_factor
vlad-perevezentsev Dec 15, 2023
628dd90
Remove trash files
vlad-perevezentsev Dec 15, 2023
a8db460
Use getrf_batch only on CPU
vlad-perevezentsev Dec 18, 2023
e3cd5c4
Update tests for dpnp.linalg.slogdet
vlad-perevezentsev Dec 18, 2023
40c7a29
Merge master into impl_lu_factor
vlad-perevezentsev Dec 18, 2023
4679637
Merge master into impl_lu_factor
vlad-perevezentsev Dec 20, 2023
99f3618
Address remarks
vlad-perevezentsev Dec 20, 2023
1f9b6fa
Add _real_type func
vlad-perevezentsev Dec 20, 2023
7e12063
Add test_det in test_usm_type
vlad-perevezentsev Dec 20, 2023
0e82258
Add more checks in getrf and getrf_batch functions
vlad-perevezentsev Dec 21, 2023
3a6e5ce
Improve error handler in getrf_impl
vlad-perevezentsev Dec 21, 2023
6700fce
Improve error handler in getrf_batch_impl
vlad-perevezentsev Dec 21, 2023
644485b
dev_info is allocated as zeros
vlad-perevezentsev Dec 21, 2023
5015e15
Remove skipif for singular tests
vlad-perevezentsev Dec 21, 2023
68c436e
Implement _lu_factor logic with dev_info as a python list
vlad-perevezentsev Dec 21, 2023
030a083
Update getrf_rf error handler with mkl_lapack::batch_error
vlad-perevezentsev Dec 21, 2023
579b4e5
Remove passing n parameter to _getrf
vlad-perevezentsev Dec 22, 2023
acd04b7
Add a new test_slogdet_singular_matrix_3D test
vlad-perevezentsev Dec 22, 2023
0fe0bf1
Merge remote-tracking branch 'origin/master' into impl_lu_factor
vlad-perevezentsev Dec 22, 2023
f0cc4d3
Update tests for dpnp.linalg.det
vlad-perevezentsev Dec 22, 2023
8896aab
Merge master into impl_lu_factor
vlad-perevezentsev Dec 22, 2023
c9b7c3b
Use is_exception_caught flag in getrf and getrf_batch error handler
vlad-perevezentsev Jan 8, 2024
a3873cd
Update gesv error handler
vlad-perevezentsev Jan 8, 2024
9652797
Reshape results after calling getrf_batch
vlad-perevezentsev Jan 8, 2024
cef4690
Add a new dpnp.linalg.det impl and refresh dpnp_utils_linalg
vlad-perevezentsev Jan 9, 2024
67681bf
Remove Limitations from dpnp_det and dpnp_slogdet docstrings
vlad-perevezentsev Jan 9, 2024
ed71f6b
Address remarks
vlad-perevezentsev Jan 10, 2024
e8f5fbd
Merge master into impl_lu_factor
vlad-perevezentsev Jan 10, 2024
75fa23a
Remove det_dtype variable and use the abs val of diag for det
vlad-perevezentsev Jan 11, 2024
e9bdcd6
Expand cupy tests for dpnp.linalg.det()
vlad-perevezentsev Jan 11, 2024
549b5da
Update TestDet and TestSlogdet
vlad-perevezentsev Jan 11, 2024
26277ba
Merge master into impl_lu_factor
vlad-perevezentsev Jan 12, 2024
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
@@ -30,6 +30,7 @@ env:
test_umath.py
test_usm_type.py
third_party/cupy/core_tests
+third_party/cupy/linalg_tests/test_norms.py
third_party/cupy/linalg_tests/test_product.py
third_party/cupy/linalg_tests/test_solve.py
third_party/cupy/logic_tests/test_comparison.py
2 changes: 2 additions & 0 deletions dpnp/backend/extensions/lapack/CMakeLists.txt
@@ -28,6 +28,8 @@ set(python_module_name _lapack_impl)
set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
)
17 changes: 9 additions & 8 deletions dpnp/backend/extensions/lapack/gesv.cpp
@@ -84,7 +84,7 @@ static sycl::event gesv_impl(sycl::queue exec_q,

std::stringstream error_msg;
std::int64_t info = 0;
-bool sycl_exception_caught = false;
+bool is_exception_caught = false;

sycl::event gesv_event;
try {
@@ -106,12 +106,18 @@ static sycl::event gesv_impl(sycl::queue exec_q,
// routine for storing intermediate results.
scratchpad_size, depends);
} catch (mkl_lapack::exception const &e) {
+is_exception_caught = true;
info = e.info();

if (info < 0) {
error_msg << "Parameter number " << -info
<< " had an illegal value.";
}
+else if (info == scratchpad_size && e.detail() != 0) {
+error_msg
+<< "Insufficient scratchpad size. Required size is at least "
+<< e.detail();
+}
else if (info > 0) {
T host_U;
exec_q.memcpy(&host_U, &a[(info - 1) * lda + info - 1], sizeof(T))
@@ -131,23 +137,18 @@ static sycl::event gesv_impl(sycl::queue exec_q,
<< e.what() << "\ninfo: " << e.info();
}
}
-else if (info == scratchpad_size && e.detail() != 0) {
-error_msg
-<< "Insufficient scratchpad size. Required size is at least "
-<< e.detail();
-}
else {
error_msg << "Unexpected MKL exception caught during gesv() "
"call:\nreason: "
<< e.what() << "\ninfo: " << e.info();
}
} catch (sycl::exception const &e) {
+is_exception_caught = true;
error_msg << "Unexpected SYCL exception caught during gesv() call:\n"
<< e.what();
-sycl_exception_caught = true;
}

-if (info != 0 || sycl_exception_caught) // an unexpected error occurs
+if (is_exception_caught) // an unexpected error occurs
{
if (scratchpad != nullptr) {
sycl::free(scratchpad, exec_q);
245 changes: 245 additions & 0 deletions dpnp/backend/extensions/lapack/getrf.cpp
@@ -0,0 +1,245 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <pybind11/pybind11.h>

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/type_utils.hpp"

#include "getrf.hpp"
#include "types_matrix.hpp"

#include "dpnp_utils.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace lapack
{
namespace mkl_lapack = oneapi::mkl::lapack;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*getrf_impl_fn_ptr_t)(sycl::queue,
const std::int64_t,
char *,
std::int64_t,
std::int64_t *,
py::list,
std::vector<sycl::event> &,
const std::vector<sycl::event> &);

static getrf_impl_fn_ptr_t getrf_dispatch_vector[dpctl_td_ns::num_types];

template <typename T>
static sycl::event getrf_impl(sycl::queue exec_q,
const std::int64_t n,
char *in_a,
std::int64_t lda,
std::int64_t *ipiv,
py::list dev_info,
std::vector<sycl::event> &host_task_events,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

T *a = reinterpret_cast<T *>(in_a);

const std::int64_t scratchpad_size =
mkl_lapack::getrf_scratchpad_size<T>(exec_q, n, n, lda);
T *scratchpad = nullptr;

std::stringstream error_msg;
std::int64_t info = 0;
bool is_exception_caught = false;

sycl::event getrf_event;
try {
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);

getrf_event = mkl_lapack::getrf(
exec_q,
n, // The order of the square matrix A (0 ≤ n).
// It must be a non-negative integer.
n, // The number of columns in the square matrix A (0 ≤ n).
// It must be a non-negative integer.
a, // Pointer to the square matrix A (n x n).
lda, // The leading dimension of matrix A.
// It must be at least max(1, n).
ipiv, // Pointer to the output array of pivot indices.
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, depends);
} catch (mkl_lapack::exception const &e) {
is_exception_caught = true;
info = e.info();

if (info < 0) {
error_msg << "Parameter number " << -info
<< " had an illegal value.";
}
else if (info == scratchpad_size && e.detail() != 0) {
error_msg
<< "Insufficient scratchpad size. Required size is at least "
<< e.detail();
}
else if (info > 0) {
// Store the positive 'info' value in the first element of
// 'dev_info'. This indicates that the factorization has been
// completed, but the factor U (upper triangular matrix) is exactly
// singular. The 'info' value here is the index of the first zero
// element in the diagonal of U.
is_exception_caught = false;
dev_info[0] = info;
}
else {
error_msg << "Unexpected MKL exception caught during getrf() "
"call:\nreason: "
<< e.what() << "\ninfo: " << e.info();
}
} catch (sycl::exception const &e) {
is_exception_caught = true;
error_msg << "Unexpected SYCL exception caught during getrf() call:\n"
<< e.what();
}

if (is_exception_caught) // an unexpected error occurs
{
if (scratchpad != nullptr) {
sycl::free(scratchpad, exec_q);
}

throw std::runtime_error(error_msg.str());
}

sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(getrf_event);
auto ctx = exec_q.get_context();
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
});
host_task_events.push_back(clean_up_event);
return getrf_event;
}

std::pair<sycl::event, sycl::event>
getrf(sycl::queue exec_q,
dpctl::tensor::usm_ndarray a_array,
dpctl::tensor::usm_ndarray ipiv_array,
py::list dev_info,
const std::vector<sycl::event> &depends)
{
const int a_array_nd = a_array.get_ndim();
const int ipiv_array_nd = ipiv_array.get_ndim();

if (a_array_nd != 2) {
throw py::value_error(
"The input array has ndim=" + std::to_string(a_array_nd) +
", but a 2-dimensional array is expected.");
}

if (ipiv_array_nd != 1) {
throw py::value_error("The array of pivot indices has ndim=" +
std::to_string(ipiv_array_nd) +
", but a 1-dimensional array is expected.");
}

// check compatibility of execution queue and allocation queue
if (!dpctl::utils::queues_are_compatible(exec_q, {a_array, ipiv_array})) {
throw py::value_error(
"Execution queue is not compatible with allocation queues");
}

bool is_a_array_c_contig = a_array.is_c_contiguous();
if (!is_a_array_c_contig) {
throw py::value_error("The input array "
"must be C-contiguous");
}

auto array_types = dpctl_td_ns::usm_ndarray_types();
int a_array_type_id =
array_types.typenum_to_lookup_id(a_array.get_typenum());

getrf_impl_fn_ptr_t getrf_fn = getrf_dispatch_vector[a_array_type_id];
if (getrf_fn == nullptr) {
throw py::value_error(
"No getrf implementation defined for the provided type "
"of the input matrix.");
}

auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
int ipiv_array_type_id =
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());

if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
throw py::value_error("The type of 'ipiv_array' must be int64.");
}

const std::int64_t n = a_array.get_shape_raw()[0];

char *a_array_data = a_array.get_data();
const std::int64_t lda = std::max<size_t>(1UL, n);

char *ipiv_array_data = ipiv_array.get_data();
std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);

std::vector<sycl::event> host_task_events;
sycl::event getrf_ev = getrf_fn(exec_q, n, a_array_data, lda, d_ipiv,
dev_info, host_task_events, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {a_array, ipiv_array}, host_task_events);

return std::make_pair(args_ev, getrf_ev);
}

template <typename fnT, typename T>
struct GetrfContigFactory
{
fnT get()
{
if constexpr (types::GetrfTypePairSupportFactory<T>::is_defined) {
return getrf_impl<T>;
}
else {
return nullptr;
}
}
};

void init_getrf_dispatch_vector(void)
{
dpctl_td_ns::DispatchVectorBuilder<getrf_impl_fn_ptr_t, GetrfContigFactory,
dpctl_td_ns::num_types>
contig;
contig.populate_dispatch_vector(getrf_dispatch_vector);
}
} // namespace lapack
} // namespace ext
} // namespace backend
} // namespace dpnp
64 changes: 64 additions & 0 deletions dpnp/backend/extensions/lapack/getrf.hpp
@@ -0,0 +1,64 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <CL/sycl.hpp>
#include <oneapi/mkl.hpp>

#include <dpctl4pybind11.hpp>

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace lapack
{
extern std::pair<sycl::event, sycl::event>
getrf(sycl::queue exec_q,
dpctl::tensor::usm_ndarray a_array,
dpctl::tensor::usm_ndarray ipiv_array,
py::list dev_info,
const std::vector<sycl::event> &depends = {});

extern std::pair<sycl::event, sycl::event>
getrf_batch(sycl::queue exec_q,
dpctl::tensor::usm_ndarray a_array,
dpctl::tensor::usm_ndarray ipiv_array,
py::list dev_info,
std::int64_t n,
std::int64_t stride_a,
std::int64_t stride_ipiv,
std::int64_t batch_size,
const std::vector<sycl::event> &depends = {});

extern void init_getrf_dispatch_vector(void);
extern void init_getrf_batch_dispatch_vector(void);
} // namespace lapack
} // namespace ext
} // namespace backend
} // namespace dpnp