Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Oct 10, 2023
1 parent a9941bb commit 08ab7e0
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 55 deletions.
102 changes: 48 additions & 54 deletions dpnp/backend/extensions/vm/types_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ template <typename T>
struct CeilOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand All @@ -227,10 +227,8 @@ template <typename T>
struct ConjOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand All @@ -246,8 +244,8 @@ struct CosOutputType
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand All @@ -263,8 +261,8 @@ struct CoshOutputType
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand Down Expand Up @@ -303,8 +301,10 @@ template <typename T>
struct ExpOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand All @@ -318,8 +318,8 @@ template <typename T>
struct Expm1OutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand All @@ -333,8 +333,8 @@ template <typename T>
struct FloorOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand Down Expand Up @@ -363,61 +363,57 @@ template <typename T>
struct LnOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::log1p<T> function.
* MKL VM library provides support in oneapi::mkl::vm::log10<T> function.
*
* @tparam T Type of input vector `a` and of result vector `y`.
*/
template <typename T>
struct Log1pOutputType
struct Log10OutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::log2<T> function.
* MKL VM library provides support in oneapi::mkl::vm::log1p<T> function.
*
* @tparam T Type of input vector `a` and of result vector `y`.
*/
template <typename T>
struct Log2OutputType
struct Log1pOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::log10<T> function.
* MKL VM library provides support in oneapi::mkl::vm::log2<T> function.
*
* @tparam T Type of input vector `a` and of result vector `y`.
*/
template <typename T>
struct Log10OutputType
struct Log2OutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand Down Expand Up @@ -481,8 +477,8 @@ template <typename T>
struct RoundOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand All @@ -498,8 +494,8 @@ struct SinOutputType
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand All @@ -515,8 +511,8 @@ struct SinhOutputType
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand All @@ -530,8 +526,8 @@ template <typename T>
struct SqrOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand All @@ -545,12 +541,10 @@ template <typename T>
struct SqrtOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand Down Expand Up @@ -623,8 +617,8 @@ template <typename T>
struct TruncOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

Expand Down
2 changes: 1 addition & 1 deletion dpnp/backend/extensions/vm/vm_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ static unary_impl_fn_ptr_t expm1_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types];
static binary_impl_fn_ptr_t hypot_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t log10_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t log1p_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t log2_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t log10_dispatch_vector[dpctl_td_ns::num_types];
static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
static binary_impl_fn_ptr_t pow_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t round_dispatch_vector[dpctl_td_ns::num_types];
Expand Down
23 changes: 23 additions & 0 deletions tests/test_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,29 @@ def test_exp(self, dtype):
tol = numpy.finfo(dtype=result.dtype).resolution
assert_allclose(expected, result.asnumpy(), rtol=tol)

@pytest.mark.parametrize("dtype", get_complex_dtypes())
def test_exp_complex(self, dtype):
x1 = numpy.linspace(0, 8, num=10)
x2 = numpy.linspace(0, 6, num=10)
Xnp = x1 + 1j * x2
np_array = numpy.asarray(Xnp, dtype=dtype)
np_out = numpy.empty(10, dtype=numpy.complex128)

# DPNP
dp_out_dtype = dpnp.complex64
if has_support_aspect64() and dtype != dpnp.complex64:
dp_out_dtype = dpnp.complex128

dp_array = dpnp.array(np_array, dtype=dp_out_dtype)
dp_out = dpnp.array(np_out, dtype=dp_out_dtype)
result = dpnp.exp(dp_array, out=dp_out)

# original
expected = numpy.exp(np_array, out=np_out)

tol = numpy.finfo(dtype=result.dtype).resolution
assert_allclose(expected, result.asnumpy(), rtol=tol)

@pytest.mark.parametrize(
"dtype", get_all_dtypes(no_complex=True, no_none=True)[:-1]
)
Expand Down

0 comments on commit 08ab7e0

Please sign in to comment.