Skip to content

Commit

Permalink
Implemented dpnp.can_cast function (#1600)
Browse files Browse the repository at this point in the history
* Implemented dpnp.can_cast function

* address comments

* Update tests/third_party/cupy/test_type_routines.py

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>

* Update tests/third_party/cupy/test_type_routines.py

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>

---------

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
npolina4 and antonwolfy authored Oct 19, 2023
1 parent 7bbbf1a commit f58d19d
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 1 deletion.
44 changes: 43 additions & 1 deletion dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"atleast_3d",
"broadcast_arrays",
"broadcast_to",
"can_cast",
"concatenate",
"copyto",
"expand_dims",
Expand Down Expand Up @@ -402,6 +403,47 @@ def broadcast_to(array, /, shape, subok=False):
return dpnp_array._create_from_usm_ndarray(new_array)


def can_cast(from_, to, casting="safe"):
"""
Returns ``True`` if cast between data types can occur according to the casting rule.
If `from` is a scalar or array scalar, also returns ``True`` if the scalar value can
be cast without overflow or truncation to an integer.
For full documentation refer to :obj:`numpy.can_cast`.
Parameters
----------
from : dpnp.array, dtype
Source data type.
to : dtype
Target data type.
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
Controls what kind of data casting may occur.
Returns
-------
out: bool
True if cast can occur according to the casting rule.
See Also
--------
:obj:`dpnp.result_type` : Returns the type that results from applying the NumPy
type promotion rules to the arguments.
"""

if dpnp.is_supported_array_type(to):
raise TypeError("Cannot construct a dtype from an array")

dtype_from = (
from_.dtype
if dpnp.is_supported_array_type(from_)
else dpnp.dtype(from_)
)
return dpt.can_cast(dtype_from, to, casting)


def concatenate(
arrays, /, *, axis=0, out=None, dtype=None, casting="same_kind"
):
Expand Down Expand Up @@ -519,7 +561,7 @@ def copyto(dst, src, casting="same_kind", where=True):
elif not dpnp.is_supported_array_type(src):
src = dpnp.array(src, sycl_queue=dst.sycl_queue)

if not dpt.can_cast(src.dtype, dst.dtype, casting=casting):
if not dpnp.can_cast(src.dtype, dst.dtype, casting=casting):
raise TypeError(
f"Cannot cast from {src.dtype} to {dst.dtype} "
f"according to the rule {casting}."
Expand Down
11 changes: 11 additions & 0 deletions tests/test_arraymanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,3 +928,14 @@ def test_subok_error():
with pytest.raises(NotImplementedError):
dpnp.broadcast_arrays(x, subok=True)
dpnp.broadcast_to(x, (4, 4), subok=True)


def test_can_cast():
X = dpnp.ones((2, 2), dtype=dpnp.int64)
pytest.raises(TypeError, dpnp.can_cast, X, 1)
pytest.raises(TypeError, dpnp.can_cast, X, X)

X_np = numpy.ones((2, 2), dtype=numpy.int64)
assert dpnp.can_cast(X, "float32") == numpy.can_cast(X_np, "float32")
assert dpnp.can_cast(X, dpnp.int32) == numpy.can_cast(X_np, numpy.int32)
assert dpnp.can_cast(X, dpnp.int64) == numpy.can_cast(X_np, numpy.int64)
102 changes: 102 additions & 0 deletions tests/third_party/cupy/test_type_routines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import unittest

import numpy
import pytest

import dpnp as cupy
from tests.third_party.cupy import testing


def _generate_type_routines_input(xp, dtype, obj_type):
dtype = numpy.dtype(dtype)
if obj_type == "dtype":
return dtype
if obj_type == "specifier":
return str(dtype)
if obj_type == "scalar":
return dtype.type(3)
if obj_type == "array":
return xp.zeros(3, dtype=dtype)
if obj_type == "primitive":
return type(dtype.type(3).tolist())
assert False


@testing.parameterize(
*testing.product(
{
"obj_type": ["dtype", "specifier", "scalar", "array", "primitive"],
}
)
)
class TestCanCast(unittest.TestCase):
@testing.for_all_dtypes_combination(names=("from_dtype", "to_dtype"))
@testing.numpy_cupy_equal()
def test_can_cast(self, xp, from_dtype, to_dtype):
if self.obj_type == "scalar":
pytest.skip("to be aligned with NEP-50")

from_obj = _generate_type_routines_input(xp, from_dtype, self.obj_type)

ret = xp.can_cast(from_obj, to_dtype)
assert isinstance(ret, bool)
return ret


@pytest.mark.skip("dpnp.common_type() is not implemented yet")
class TestCommonType(unittest.TestCase):
@testing.numpy_cupy_equal()
def test_common_type_empty(self, xp):
ret = xp.common_type()
assert type(ret) == type
return ret

@testing.for_all_dtypes(no_bool=True)
@testing.numpy_cupy_equal()
def test_common_type_single_argument(self, xp, dtype):
array = _generate_type_routines_input(xp, dtype, "array")
ret = xp.common_type(array)
assert type(ret) == type
return ret

@testing.for_all_dtypes_combination(
names=("dtype1", "dtype2"), no_bool=True
)
@testing.numpy_cupy_equal()
def test_common_type_two_arguments(self, xp, dtype1, dtype2):
array1 = _generate_type_routines_input(xp, dtype1, "array")
array2 = _generate_type_routines_input(xp, dtype2, "array")
ret = xp.common_type(array1, array2)
assert type(ret) == type
return ret

@testing.for_all_dtypes()
def test_common_type_bool(self, dtype):
for xp in (numpy, cupy):
array1 = _generate_type_routines_input(xp, dtype, "array")
array2 = _generate_type_routines_input(xp, "bool_", "array")
with pytest.raises(TypeError):
xp.common_type(array1, array2)


@testing.parameterize(
*testing.product(
{
"obj_type1": ["dtype", "specifier", "scalar", "array", "primitive"],
"obj_type2": ["dtype", "specifier", "scalar", "array", "primitive"],
}
)
)
class TestResultType(unittest.TestCase):
@testing.for_all_dtypes_combination(names=("dtype1", "dtype2"))
@testing.numpy_cupy_equal()
def test_result_type(self, xp, dtype1, dtype2):
if "scalar" in {self.obj_type1, self.obj_type2}:
pytest.skip("to be aligned with NEP-50")

input1 = _generate_type_routines_input(xp, dtype1, self.obj_type1)

input2 = _generate_type_routines_input(xp, dtype2, self.obj_type2)
ret = xp.result_type(input1, input2)
assert isinstance(ret, numpy.dtype)
return ret

0 comments on commit f58d19d

Please sign in to comment.