Skip to content

Commit

Permalink
Merge pull request cupy#7651 from asi1024/deprecate-find_common_types
Browse files Browse the repository at this point in the history
Avoid using `numpy.find_common_type`
  • Loading branch information
kmaehashi authored and chainer-ci committed Aug 17, 2023
1 parent 8d2595c commit f3e28cc
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 14 deletions.
12 changes: 3 additions & 9 deletions cupy/_indexing/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy

import cupy
from cupy import _core
from cupy._creation import from_data
from cupy._manipulation import join

Expand Down Expand Up @@ -40,16 +39,14 @@ def __getitem__(self, key):
trans1d = self.trans1d
ndmin = self.ndmin
objs = []
arrays = []
scalars = []
arraytypes = []
scalartypes = []
if isinstance(key, str):
raise NotImplementedError
if not isinstance(key, tuple):
key = (key,)

for i, k in enumerate(key):
scalar = False
if isinstance(k, slice):
raise NotImplementedError
elif isinstance(k, str):
Expand All @@ -60,20 +57,17 @@ def __getitem__(self, key):
elif type(k) in numpy.ScalarType:
newobj = from_data.array(k, ndmin=ndmin)
scalars.append(i)
scalar = True
scalartypes.append(newobj.dtype)
else:
newobj = from_data.array(k, copy=False, ndmin=ndmin)
if ndmin > 1:
ndim = from_data.array(k, copy=False).ndim
if trans1d != -1 and ndim < ndmin:
newobj = self._output_obj(newobj, ndim, ndmin, trans1d)
arrays.append(newobj)

objs.append(newobj)
if not scalar and isinstance(newobj, _core.ndarray):
arraytypes.append(newobj.dtype)

final_dtype = numpy.find_common_type(arraytypes, scalartypes)
final_dtype = numpy.result_type(*arrays, *scalars)
if final_dtype is not None:
for k in scalars:
objs[k] = objs[k].astype(final_dtype)
Expand Down
3 changes: 1 addition & 2 deletions cupyx/scipy/linalg/_special_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,7 @@ def block_diag(*arrs):

shapes = tuple(a.shape for a in arrs)
shape = tuple(sum(x) for x in zip(*shapes))
dtype = cupy.find_common_type([a.dtype for a in arrs], [])
out = cupy.zeros(shape, dtype=dtype)
out = cupy.zeros(shape, dtype=cupy.result_type(*arrs))
r, c = 0, 0
for arr in arrs:
rr, cc = arr.shape
Expand Down
2 changes: 1 addition & 1 deletion cupyx/scipy/sparse/_sputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def upcast(*args):
if t is not None:
return t

upcast = cupy.find_common_type(args, [])
upcast = numpy.result_type(*args)

for t in supported_dtypes:
if cupy.can_cast(upcast, t):
Expand Down
2 changes: 1 addition & 1 deletion cupyx/scipy/sparse/linalg/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def _get_dtype(operators, dtypes=None):
for obj in operators:
if obj is not None and hasattr(obj, 'dtype'):
dtypes.append(obj.dtype)
return cupy.find_common_type(dtypes, [])
return cupy.result_type(*dtypes)


class _SumLinearOperator(LinearOperator):
Expand Down
3 changes: 2 additions & 1 deletion tests/cupy_tests/indexing_tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def test_r_8(self, xp, dtype):
return xp.r_[a, b, c]

@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
@testing.numpy_cupy_array_equal(
type_check=(numpy.lib.NumpyVersion(numpy.__version__) >= "1.25.0"))
def test_r_2(self, xp, dtype):
a = xp.array([1, 2, 3], dtype)
return xp.r_[a, 0, 0, a]
Expand Down

0 comments on commit f3e28cc

Please sign in to comment.