Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update test_linalg.py to run on Iris Xe #1474

Merged
Merged
32 changes: 28 additions & 4 deletions dpnp/backend/kernels/dpnp_krnl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1152,9 +1152,21 @@ void func_map_init_linalg(func_map_t &fmap)
eft_DBL, (void *)dpnp_eig_default_c<double, double>};

fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_eig_ext_c<int32_t, double>};
get_default_floating_type<>(),
(void *)dpnp_eig_ext_c<
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_eig_ext_c<
int32_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_eig_ext_c<int64_t, double>};
get_default_floating_type<>(),
(void *)dpnp_eig_ext_c<
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_eig_ext_c<
int64_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_eig_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_DBL][eft_DBL] = {
Expand All @@ -1170,9 +1182,21 @@ void func_map_init_linalg(func_map_t &fmap)
eft_DBL, (void *)dpnp_eigvals_default_c<double, double>};

fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_eigvals_ext_c<int32_t, double>};
get_default_floating_type<>(),
(void *)dpnp_eigvals_ext_c<
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_eigvals_ext_c<
int32_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_eigvals_ext_c<int64_t, double>};
get_default_floating_type<>(),
(void *)dpnp_eigvals_ext_c<
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_eigvals_ext_c<
int64_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_eigvals_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_DBL][eft_DBL] = {
Expand Down
60 changes: 52 additions & 8 deletions dpnp/backend/kernels/dpnp_krnl_linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -874,16 +874,28 @@ void func_map_init_linalg_func(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_INV][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_inv_default_c<int64_t, double>};
fmap[DPNPFuncName::DPNP_FN_INV][eft_FLT][eft_FLT] = {
eft_DBL, (void *)dpnp_inv_default_c<float, double>};
eft_DBL, (void *)dpnp_inv_default_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_INV][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_inv_default_c<double, double>};

fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_inv_ext_c<int32_t, double>};
get_default_floating_type<>(),
(void *)dpnp_inv_ext_c<
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_inv_ext_c<
int32_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_inv_ext_c<int64_t, double>};
get_default_floating_type<>(),
(void *)dpnp_inv_ext_c<
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_inv_ext_c<
int64_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_FLT][eft_FLT] = {
eft_DBL, (void *)dpnp_inv_ext_c<float, double>};
eft_FLT, (void *)dpnp_inv_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_inv_ext_c<double, double>};

Expand Down Expand Up @@ -1039,9 +1051,21 @@ void func_map_init_linalg_func(func_map_t &fmap)
// eft_C128, (void*)dpnp_qr_c<std::complex<double>, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_qr_ext_c<int32_t, double>};
get_default_floating_type<>(),
(void *)dpnp_qr_ext_c<
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_qr_ext_c<
int32_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_qr_ext_c<int64_t, double>};
get_default_floating_type<>(),
(void *)dpnp_qr_ext_c<
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_qr_ext_c<
int64_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_qr_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_DBL][eft_DBL] = {
Expand All @@ -1062,9 +1086,29 @@ void func_map_init_linalg_func(func_map_t &fmap)
std::complex<double>, double>};

fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_svd_ext_c<int32_t, double, double>};
get_default_floating_type<>(),
(void *)dpnp_svd_ext_c<
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>,
func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)
dpnp_svd_ext_c<int32_t,
func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>,
func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_svd_ext_c<int64_t, double, double>};
get_default_floating_type<>(),
(void *)dpnp_svd_ext_c<
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>,
func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)
dpnp_svd_ext_c<int64_t,
func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>,
func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_svd_ext_c<float, float, float>};
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_DBL][eft_DBL] = {
Expand Down
11 changes: 11 additions & 0 deletions dpnp/backend/src/dpnp_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,17 @@ class dpnp_less_comp
}
};

/**
* A template function that determines the default floating-point type
* based on the value of the template parameter has_fp64.
*/
template <typename has_fp64 = std::true_type>
static constexpr DPNPFuncType get_default_floating_type()
{
return has_fp64::value ? DPNPFuncType::DPNP_FT_DOUBLE
: DPNPFuncType::DPNP_FT_FLOAT;
}

/**
* FPTR interface initialization functions
*/
Expand Down
2 changes: 2 additions & 0 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ cdef extern from "dpnp_iface_fptr.hpp":
struct DPNPFuncData:
DPNPFuncType return_type
void * ptr
DPNPFuncType return_type_no_fp64
void *ptr_no_fp64

DPNPFuncData get_dpnp_function_ptr(DPNPFuncName name, DPNPFuncType first_type, DPNPFuncType second_type) except +

Expand Down
Loading