diff --git a/dpnp/backend/extensions/vm/types_matrix.hpp b/dpnp/backend/extensions/vm/types_matrix.hpp index 0e5c03dbe54..af790a8fe43 100644 --- a/dpnp/backend/extensions/vm/types_matrix.hpp +++ b/dpnp/backend/extensions/vm/types_matrix.hpp @@ -212,8 +212,8 @@ template struct CeilOutputType { using value_type = typename std::disjunction< - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -227,10 +227,8 @@ template struct ConjOutputType { using value_type = typename std::disjunction< - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, + dpctl_td_ns::TypeMapResultEntry>, + dpctl_td_ns::TypeMapResultEntry>, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -246,8 +244,8 @@ struct CosOutputType using value_type = typename std::disjunction< dpctl_td_ns::TypeMapResultEntry>, dpctl_td_ns::TypeMapResultEntry>, - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -263,8 +261,8 @@ struct CoshOutputType using value_type = typename std::disjunction< dpctl_td_ns::TypeMapResultEntry>, dpctl_td_ns::TypeMapResultEntry>, - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -303,8 +301,10 @@ template struct ExpOutputType { using value_type = typename std::disjunction< - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry>, + dpctl_td_ns::TypeMapResultEntry>, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -318,8 +318,8 @@ template struct Expm1OutputType { using value_type = typename std::disjunction< - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -333,8 +333,8 @@ template struct FloorOutputType { using value_type = typename std::disjunction< - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -363,61 +363,57 @@ template struct LnOutputType { using value_type = typename std::disjunction< - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry>, + dpctl_td_ns::TypeMapResultEntry>, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; /** * @brief A factory to define pairs of supported types for which - * MKL VM library provides support in oneapi::mkl::vm::log1p function. + * MKL VM library provides support in oneapi::mkl::vm::log10 function. * * @tparam T Type of input vector `a` and of result vector `y`. */ template -struct Log1pOutputType +struct Log10OutputType { using value_type = typename std::disjunction< - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry>, + dpctl_td_ns::TypeMapResultEntry>, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; /** * @brief A factory to define pairs of supported types for which - * MKL VM library provides support in oneapi::mkl::vm::log2 function. + * MKL VM library provides support in oneapi::mkl::vm::log1p function. * * @tparam T Type of input vector `a` and of result vector `y`. */ template -struct Log2OutputType +struct Log1pOutputType { using value_type = typename std::disjunction< - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; /** * @brief A factory to define pairs of supported types for which - * MKL VM library provides support in oneapi::mkl::vm::log10 function. + * MKL VM library provides support in oneapi::mkl::vm::log2 function. * * @tparam T Type of input vector `a` and of result vector `y`. */ template -struct Log10OutputType +struct Log2OutputType { using value_type = typename std::disjunction< - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -481,8 +477,8 @@ template struct RoundOutputType { using value_type = typename std::disjunction< - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -498,8 +494,8 @@ struct SinOutputType using value_type = typename std::disjunction< dpctl_td_ns::TypeMapResultEntry>, dpctl_td_ns::TypeMapResultEntry>, - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -515,8 +511,8 @@ struct SinhOutputType using value_type = typename std::disjunction< dpctl_td_ns::TypeMapResultEntry>, dpctl_td_ns::TypeMapResultEntry>, - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -530,8 +526,8 @@ template struct SqrOutputType { using value_type = typename std::disjunction< - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -545,12 +541,10 @@ template struct SqrtOutputType { using value_type = typename std::disjunction< - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry>, + dpctl_td_ns::TypeMapResultEntry>, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -623,8 +617,8 @@ template struct TruncOutputType { using value_type = typename std::disjunction< - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; diff --git a/dpnp/backend/extensions/vm/vm_py.cpp b/dpnp/backend/extensions/vm/vm_py.cpp index 94030d8423a..a7dfce88a7a 100644 --- a/dpnp/backend/extensions/vm/vm_py.cpp +++ b/dpnp/backend/extensions/vm/vm_py.cpp @@ -91,9 +91,9 @@ static unary_impl_fn_ptr_t expm1_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types]; static binary_impl_fn_ptr_t hypot_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types]; +static unary_impl_fn_ptr_t log10_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t log1p_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t log2_dispatch_vector[dpctl_td_ns::num_types]; -static unary_impl_fn_ptr_t log10_dispatch_vector[dpctl_td_ns::num_types]; static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types]; static binary_impl_fn_ptr_t pow_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t round_dispatch_vector[dpctl_td_ns::num_types]; diff --git a/tests/test_umath.py b/tests/test_umath.py index bb3d383c323..2b0db66ec0d 100644 --- a/tests/test_umath.py +++ b/tests/test_umath.py @@ -531,6 +531,29 @@ def test_exp(self, dtype): tol = numpy.finfo(dtype=result.dtype).resolution assert_allclose(expected, result.asnumpy(), rtol=tol) + @pytest.mark.parametrize("dtype", get_complex_dtypes()) + def test_exp_complex(self, dtype): + x1 = numpy.linspace(0, 8, num=10) + x2 = numpy.linspace(0, 6, num=10) + Xnp = x1 + 1j * x2 + np_array = numpy.asarray(Xnp, dtype=dtype) + np_out = numpy.empty(10, dtype=numpy.complex128) + + # DPNP + dp_out_dtype = dpnp.complex64 + if has_support_aspect64() and dtype != dpnp.complex64: + dp_out_dtype = dpnp.complex128 + + dp_array = dpnp.array(np_array, dtype=dp_out_dtype) + dp_out = dpnp.array(np_out, dtype=dp_out_dtype) + result = dpnp.exp(dp_array, out=dp_out) + + # original + expected = numpy.exp(np_array, out=np_out) + + tol = numpy.finfo(dtype=result.dtype).resolution + assert_allclose(expected, result.asnumpy(), rtol=tol) + @pytest.mark.parametrize( "dtype", get_all_dtypes(no_complex=True, no_none=True)[:-1] )