diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index a3ebbc6124c..3c2e18a5d4c 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -56,6 +56,7 @@ "atleast_3d", "broadcast_arrays", "broadcast_to", + "can_cast", "concatenate", "copyto", "expand_dims", @@ -402,6 +403,47 @@ def broadcast_to(array, /, shape, subok=False): return dpnp_array._create_from_usm_ndarray(new_array) +def can_cast(from_, to, casting="safe"): + """ + Returns ``True`` if cast between data types can occur according to the casting rule. + + If `from` is a scalar or array scalar, also returns ``True`` if the scalar value can + be cast without overflow or truncation to an integer. + + For full documentation refer to :obj:`numpy.can_cast`. + + Parameters + ---------- + from : dpnp.array, dtype + Source data type. + to : dtype + Target data type. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. + + Returns + ------- + out: bool + True if cast can occur according to the casting rule. + + See Also + -------- + :obj:`dpnp.result_type` : Returns the type that results from applying the NumPy + type promotion rules to the arguments. + + """ + + if dpnp.is_supported_array_type(to): + raise TypeError("Cannot construct a dtype from an array") + + dtype_from = ( + from_.dtype + if dpnp.is_supported_array_type(from_) + else dpnp.dtype(from_) + ) + return dpt.can_cast(dtype_from, to, casting) + + def concatenate( arrays, /, *, axis=0, out=None, dtype=None, casting="same_kind" ): @@ -519,7 +561,7 @@ def copyto(dst, src, casting="same_kind", where=True): elif not dpnp.is_supported_array_type(src): src = dpnp.array(src, sycl_queue=dst.sycl_queue) - if not dpt.can_cast(src.dtype, dst.dtype, casting=casting): + if not dpnp.can_cast(src.dtype, dst.dtype, casting=casting): raise TypeError( f"Cannot cast from {src.dtype} to {dst.dtype} " f"according to the rule {casting}." diff --git a/tests/test_arraymanipulation.py b/tests/test_arraymanipulation.py index 34e673cb82f..bf61634e52c 100644 --- a/tests/test_arraymanipulation.py +++ b/tests/test_arraymanipulation.py @@ -928,3 +928,14 @@ def test_subok_error(): with pytest.raises(NotImplementedError): dpnp.broadcast_arrays(x, subok=True) dpnp.broadcast_to(x, (4, 4), subok=True) + + +def test_can_cast(): + X = dpnp.ones((2, 2), dtype=dpnp.int64) + pytest.raises(TypeError, dpnp.can_cast, X, 1) + pytest.raises(TypeError, dpnp.can_cast, X, X) + + X_np = numpy.ones((2, 2), dtype=numpy.int64) + assert dpnp.can_cast(X, "float32") == numpy.can_cast(X_np, "float32") + assert dpnp.can_cast(X, dpnp.int32) == numpy.can_cast(X_np, numpy.int32) + assert dpnp.can_cast(X, dpnp.int64) == numpy.can_cast(X_np, numpy.int64) diff --git a/tests/third_party/cupy/test_type_routines.py b/tests/third_party/cupy/test_type_routines.py new file mode 100644 index 00000000000..6a274158bcd --- /dev/null +++ b/tests/third_party/cupy/test_type_routines.py @@ -0,0 +1,102 @@ +import unittest + +import numpy +import pytest + +import dpnp as cupy +from tests.third_party.cupy import testing + + +def _generate_type_routines_input(xp, dtype, obj_type): + dtype = numpy.dtype(dtype) + if obj_type == "dtype": + return dtype + if obj_type == "specifier": + return str(dtype) + if obj_type == "scalar": + return dtype.type(3) + if obj_type == "array": + return xp.zeros(3, dtype=dtype) + if obj_type == "primitive": + return type(dtype.type(3).tolist()) + assert False + + +@testing.parameterize( + *testing.product( + { + "obj_type": ["dtype", "specifier", "scalar", "array", "primitive"], + } + ) +) +class TestCanCast(unittest.TestCase): + @testing.for_all_dtypes_combination(names=("from_dtype", "to_dtype")) + @testing.numpy_cupy_equal() + def test_can_cast(self, xp, from_dtype, to_dtype): + if self.obj_type == "scalar": + pytest.skip("to be aligned with NEP-50") + + from_obj = _generate_type_routines_input(xp, from_dtype, self.obj_type) + + ret = xp.can_cast(from_obj, to_dtype) + assert isinstance(ret, bool) + return ret + + +@pytest.mark.skip("dpnp.common_type() is not implemented yet") +class TestCommonType(unittest.TestCase): + @testing.numpy_cupy_equal() + def test_common_type_empty(self, xp): + ret = xp.common_type() + assert type(ret) == type + return ret + + @testing.for_all_dtypes(no_bool=True) + @testing.numpy_cupy_equal() + def test_common_type_single_argument(self, xp, dtype): + array = _generate_type_routines_input(xp, dtype, "array") + ret = xp.common_type(array) + assert type(ret) == type + return ret + + @testing.for_all_dtypes_combination( + names=("dtype1", "dtype2"), no_bool=True + ) + @testing.numpy_cupy_equal() + def test_common_type_two_arguments(self, xp, dtype1, dtype2): + array1 = _generate_type_routines_input(xp, dtype1, "array") + array2 = _generate_type_routines_input(xp, dtype2, "array") + ret = xp.common_type(array1, array2) + assert type(ret) == type + return ret + + @testing.for_all_dtypes() + def test_common_type_bool(self, dtype): + for xp in (numpy, cupy): + array1 = _generate_type_routines_input(xp, dtype, "array") + array2 = _generate_type_routines_input(xp, "bool_", "array") + with pytest.raises(TypeError): + xp.common_type(array1, array2) + + +@testing.parameterize( + *testing.product( + { + "obj_type1": ["dtype", "specifier", "scalar", "array", "primitive"], + "obj_type2": ["dtype", "specifier", "scalar", "array", "primitive"], + } + ) +) +class TestResultType(unittest.TestCase): + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_equal() + def test_result_type(self, xp, dtype1, dtype2): + if "scalar" in {self.obj_type1, self.obj_type2}: + pytest.skip("to be aligned with NEP-50") + + input1 = _generate_type_routines_input(xp, dtype1, self.obj_type1) + + input2 = _generate_type_routines_input(xp, dtype2, self.obj_type2) + ret = xp.result_type(input1, input2) + assert isinstance(ret, numpy.dtype) + return ret