Skip to content

Commit 08ab7e0

Browse files
committed
address comments
1 parent a9941bb commit 08ab7e0

File tree

3 files changed

+72
-55
lines changed

3 files changed

+72
-55
lines changed

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 48 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,8 @@ template <typename T>
212212
struct CeilOutputType
213213
{
214214
using value_type = typename std::disjunction<
215-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
216-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
215+
dpctl_td_ns::TypeMapResultEntry<T, double>,
216+
dpctl_td_ns::TypeMapResultEntry<T, float>,
217217
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
218218
};
219219

@@ -227,10 +227,8 @@ template <typename T>
227227
struct ConjOutputType
228228
{
229229
using value_type = typename std::disjunction<
230-
dpctl_td_ns::
231-
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
232-
dpctl_td_ns::
233-
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
230+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
231+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
234232
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
235233
};
236234

@@ -246,8 +244,8 @@ struct CosOutputType
246244
using value_type = typename std::disjunction<
247245
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
248246
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
249-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
250-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
247+
dpctl_td_ns::TypeMapResultEntry<T, double>,
248+
dpctl_td_ns::TypeMapResultEntry<T, float>,
251249
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
252250
};
253251

@@ -263,8 +261,8 @@ struct CoshOutputType
263261
using value_type = typename std::disjunction<
264262
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
265263
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
266-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
267-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
264+
dpctl_td_ns::TypeMapResultEntry<T, double>,
265+
dpctl_td_ns::TypeMapResultEntry<T, float>,
268266
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
269267
};
270268

@@ -303,8 +301,10 @@ template <typename T>
303301
struct ExpOutputType
304302
{
305303
using value_type = typename std::disjunction<
306-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
307-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
304+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
305+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
306+
dpctl_td_ns::TypeMapResultEntry<T, double>,
307+
dpctl_td_ns::TypeMapResultEntry<T, float>,
308308
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
309309
};
310310

@@ -318,8 +318,8 @@ template <typename T>
318318
struct Expm1OutputType
319319
{
320320
using value_type = typename std::disjunction<
321-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
322-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
321+
dpctl_td_ns::TypeMapResultEntry<T, double>,
322+
dpctl_td_ns::TypeMapResultEntry<T, float>,
323323
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
324324
};
325325

@@ -333,8 +333,8 @@ template <typename T>
333333
struct FloorOutputType
334334
{
335335
using value_type = typename std::disjunction<
336-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
337-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
336+
dpctl_td_ns::TypeMapResultEntry<T, double>,
337+
dpctl_td_ns::TypeMapResultEntry<T, float>,
338338
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
339339
};
340340

@@ -363,61 +363,57 @@ template <typename T>
363363
struct LnOutputType
364364
{
365365
using value_type = typename std::disjunction<
366-
dpctl_td_ns::
367-
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
368-
dpctl_td_ns::
369-
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
370-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
371-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
366+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
367+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
368+
dpctl_td_ns::TypeMapResultEntry<T, double>,
369+
dpctl_td_ns::TypeMapResultEntry<T, float>,
372370
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
373371
};
374372

375373
/**
376374
* @brief A factory to define pairs of supported types for which
377-
* MKL VM library provides support in oneapi::mkl::vm::log1p<T> function.
375+
* MKL VM library provides support in oneapi::mkl::vm::log10<T> function.
378376
*
379377
* @tparam T Type of input vector `a` and of result vector `y`.
380378
*/
381379
template <typename T>
382-
struct Log1pOutputType
380+
struct Log10OutputType
383381
{
384382
using value_type = typename std::disjunction<
385-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
386-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
383+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
384+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
385+
dpctl_td_ns::TypeMapResultEntry<T, double>,
386+
dpctl_td_ns::TypeMapResultEntry<T, float>,
387387
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
388388
};
389389

390390
/**
391391
* @brief A factory to define pairs of supported types for which
392-
* MKL VM library provides support in oneapi::mkl::vm::log2<T> function.
392+
* MKL VM library provides support in oneapi::mkl::vm::log1p<T> function.
393393
*
394394
* @tparam T Type of input vector `a` and of result vector `y`.
395395
*/
396396
template <typename T>
397-
struct Log2OutputType
397+
struct Log1pOutputType
398398
{
399399
using value_type = typename std::disjunction<
400-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
401-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
400+
dpctl_td_ns::TypeMapResultEntry<T, double>,
401+
dpctl_td_ns::TypeMapResultEntry<T, float>,
402402
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
403403
};
404404

405405
/**
406406
* @brief A factory to define pairs of supported types for which
407-
* MKL VM library provides support in oneapi::mkl::vm::log10<T> function.
407+
* MKL VM library provides support in oneapi::mkl::vm::log2<T> function.
408408
*
409409
* @tparam T Type of input vector `a` and of result vector `y`.
410410
*/
411411
template <typename T>
412-
struct Log10OutputType
412+
struct Log2OutputType
413413
{
414414
using value_type = typename std::disjunction<
415-
dpctl_td_ns::
416-
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
417-
dpctl_td_ns::
418-
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
419-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
420-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
415+
dpctl_td_ns::TypeMapResultEntry<T, double>,
416+
dpctl_td_ns::TypeMapResultEntry<T, float>,
421417
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
422418
};
423419

@@ -481,8 +477,8 @@ template <typename T>
481477
struct RoundOutputType
482478
{
483479
using value_type = typename std::disjunction<
484-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
485-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
480+
dpctl_td_ns::TypeMapResultEntry<T, double>,
481+
dpctl_td_ns::TypeMapResultEntry<T, float>,
486482
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
487483
};
488484

@@ -498,8 +494,8 @@ struct SinOutputType
498494
using value_type = typename std::disjunction<
499495
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
500496
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
501-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
502-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
497+
dpctl_td_ns::TypeMapResultEntry<T, double>,
498+
dpctl_td_ns::TypeMapResultEntry<T, float>,
503499
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
504500
};
505501

@@ -515,8 +511,8 @@ struct SinhOutputType
515511
using value_type = typename std::disjunction<
516512
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
517513
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
518-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
519-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
514+
dpctl_td_ns::TypeMapResultEntry<T, double>,
515+
dpctl_td_ns::TypeMapResultEntry<T, float>,
520516
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
521517
};
522518

@@ -530,8 +526,8 @@ template <typename T>
530526
struct SqrOutputType
531527
{
532528
using value_type = typename std::disjunction<
533-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
534-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
529+
dpctl_td_ns::TypeMapResultEntry<T, double>,
530+
dpctl_td_ns::TypeMapResultEntry<T, float>,
535531
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
536532
};
537533

@@ -545,12 +541,10 @@ template <typename T>
545541
struct SqrtOutputType
546542
{
547543
using value_type = typename std::disjunction<
548-
dpctl_td_ns::
549-
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
550-
dpctl_td_ns::
551-
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
552-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
553-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
544+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
545+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
546+
dpctl_td_ns::TypeMapResultEntry<T, double>,
547+
dpctl_td_ns::TypeMapResultEntry<T, float>,
554548
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
555549
};
556550

@@ -623,8 +617,8 @@ template <typename T>
623617
struct TruncOutputType
624618
{
625619
using value_type = typename std::disjunction<
626-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
627-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
620+
dpctl_td_ns::TypeMapResultEntry<T, double>,
621+
dpctl_td_ns::TypeMapResultEntry<T, float>,
628622
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
629623
};
630624

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ static unary_impl_fn_ptr_t expm1_dispatch_vector[dpctl_td_ns::num_types];
9191
static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types];
9292
static binary_impl_fn_ptr_t hypot_dispatch_vector[dpctl_td_ns::num_types];
9393
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
94+
static unary_impl_fn_ptr_t log10_dispatch_vector[dpctl_td_ns::num_types];
9495
static unary_impl_fn_ptr_t log1p_dispatch_vector[dpctl_td_ns::num_types];
9596
static unary_impl_fn_ptr_t log2_dispatch_vector[dpctl_td_ns::num_types];
96-
static unary_impl_fn_ptr_t log10_dispatch_vector[dpctl_td_ns::num_types];
9797
static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
9898
static binary_impl_fn_ptr_t pow_dispatch_vector[dpctl_td_ns::num_types];
9999
static unary_impl_fn_ptr_t round_dispatch_vector[dpctl_td_ns::num_types];

tests/test_umath.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,29 @@ def test_exp(self, dtype):
531531
tol = numpy.finfo(dtype=result.dtype).resolution
532532
assert_allclose(expected, result.asnumpy(), rtol=tol)
533533

534+
@pytest.mark.parametrize("dtype", get_complex_dtypes())
535+
def test_exp_complex(self, dtype):
536+
x1 = numpy.linspace(0, 8, num=10)
537+
x2 = numpy.linspace(0, 6, num=10)
538+
Xnp = x1 + 1j * x2
539+
np_array = numpy.asarray(Xnp, dtype=dtype)
540+
np_out = numpy.empty(10, dtype=numpy.complex128)
541+
542+
# DPNP
543+
dp_out_dtype = dpnp.complex64
544+
if has_support_aspect64() and dtype != dpnp.complex64:
545+
dp_out_dtype = dpnp.complex128
546+
547+
dp_array = dpnp.array(np_array, dtype=dp_out_dtype)
548+
dp_out = dpnp.array(np_out, dtype=dp_out_dtype)
549+
result = dpnp.exp(dp_array, out=dp_out)
550+
551+
# original
552+
expected = numpy.exp(np_array, out=np_out)
553+
554+
tol = numpy.finfo(dtype=result.dtype).resolution
555+
assert_allclose(expected, result.asnumpy(), rtol=tol)
556+
534557
@pytest.mark.parametrize(
535558
"dtype", get_all_dtypes(no_complex=True, no_none=True)[:-1]
536559
)

0 commit comments

Comments
 (0)