Skip to content

Commit 28376c5

Browse files
antonwolfynpolina4
andcommitted
Add dpnp.result_type() support (#1435)
* Add dpnp.result_type() support * Update dpnp/dpnp_iface_manipulation.py Co-authored-by: Natalia Polina <natalia.polina@intel.com> --------- Co-authored-by: Natalia Polina <natalia.polina@intel.com>
1 parent 0c7f196 commit 28376c5

File tree

3 files changed

+64
-1
lines changed

3 files changed

+64
-1
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _call_divide_inplace(lhs, rhs, sycl_queue, depends=[]):
133133
"""In place workaround until dpctl.tensor provides the functionality."""
134134

135135
# allocate temporary memory for out array
136-
out = dpt.empty_like(lhs, dtype=numpy.result_type((lhs.dtype, rhs.dtype)))
136+
out = dpt.empty_like(lhs, dtype=dpnp.result_type(lhs.dtype, rhs.dtype))
137137

138138
# call a general callback
139139
div_ht_, div_ev_ = _call_divide(lhs, rhs, out, sycl_queue, depends)

dpnp/dpnp_iface_manipulation.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
"ravel",
6868
"repeat",
6969
"reshape",
70+
"result_type",
7071
"rollaxis",
7172
"shape",
7273
"squeeze",
@@ -579,6 +580,49 @@ def reshape(x, /, newshape, order='C', copy=None):
579580
return dpnp_array._create_from_usm_ndarray(usm_arr)
580581

581582

583+
def result_type(*arrays_and_dtypes):
584+
"""
585+
Returns the type that results from applying the NumPy
586+
type promotion rules to the arguments.
587+
588+
For full documentation refer to :obj:`numpy.result_type`.
589+
590+
Parameters
591+
----------
592+
arrays_and_dtypes : list of arrays and dtypes
593+
An arbitrary length sequence of arrays or dtypes.
594+
595+
Returns
596+
-------
597+
out : dtype
598+
The result type.
599+
600+
Limitations
601+
-----------
602+
An array in the input list is supported as either :class:`dpnp.ndarray`
603+
or :class:`dpctl.tensor.usm_ndarray`.
604+
605+
Examples
606+
--------
607+
>>> import dpnp as dp
608+
>>> dp.result_type(dp.arange(3, dtype=dp.int64), dp.arange(7, dtype=dp.int32))
609+
dtype('int64')
610+
611+
>>> dp.result_type(dp.int64, dp.complex128)
612+
dtype('complex128')
613+
614+
>>> dp.result_type(dp.ones(10, dtype=dp.float32), dp.float64)
615+
dtype('float64')
616+
617+
"""
618+
619+
usm_arrays_and_dtypes = [
620+
X.dtype if isinstance(X, (dpnp_array, dpt.usm_ndarray)) else X
621+
for X in arrays_and_dtypes
622+
]
623+
return dpt.result_type(*usm_arrays_and_dtypes)
624+
625+
582626
def rollaxis(x1, axis, start=0):
583627
"""
584628
Roll the specified axis backwards, until it lies in a given position.

tests/test_manipulation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,25 @@ def test_repeat(arr):
4040
assert_array_equal(expected, result)
4141

4242

43+
def test_result_type():
44+
X = [dpnp.ones((2), dtype=dpnp.int64), dpnp.int32, "float16"]
45+
X_np = [numpy.ones((2), dtype=numpy.int64), numpy.int32, "float16"]
46+
47+
assert dpnp.result_type(*X) == numpy.result_type(*X_np)
48+
49+
def test_result_type_only_dtypes():
50+
X = [dpnp.int64, dpnp.int32, dpnp.bool, dpnp.float32]
51+
X_np = [numpy.int64, numpy.int32, numpy.bool_, numpy.float32]
52+
53+
assert dpnp.result_type(*X) == numpy.result_type(*X_np)
54+
55+
def test_result_type_only_arrays():
56+
X = [dpnp.ones((2), dtype=dpnp.int64), dpnp.ones((7, 4), dtype=dpnp.int32)]
57+
X_np = [numpy.ones((2), dtype=numpy.int64), numpy.ones((7, 4), dtype=numpy.int32)]
58+
59+
assert dpnp.result_type(*X) == numpy.result_type(*X_np)
60+
61+
4362
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
4463
@pytest.mark.parametrize("array",
4564
[[1, 2, 3],

0 commit comments

Comments
 (0)