Skip to content

Commit 36014f8

Browse files
npolina4oleksandr-pavlyk
authored andcommitted
Fixed usage usm_ndarray for function iinfo, finfo, can_cast, and result_type
1 parent 101f635 commit 36014f8

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -422,19 +422,23 @@ def stack(arrays, axis=0):
422422
return res
423423

424424

425-
def can_cast(array_and_dtype_from, dtype_to, casting="safe"):
425+
def can_cast(from_, to, casting="safe"):
426426
"""
427427
can_cast(from: usm_ndarray or dtype, to: dtype) -> bool
428428
429429
Determines if one data type can be cast to another data type according \
430430
Type Promotion Rules rules.
431431
"""
432-
if not isinstance(dtype_to, dpt.dtype):
432+
if isinstance(to, dpt.usm_ndarray):
433433
raise TypeError("Expected dtype type.")
434434

435-
dtype_from = dpt.dtype(array_and_dtype_from)
435+
dtype_to = dpt.dtype(to)
436436

437-
_supported_dtype([dtype_to, dtype_from])
437+
dtype_from = (
438+
from_.dtype if isinstance(from_, dpt.usm_ndarray) else dpt.dtype(from_)
439+
)
440+
441+
_supported_dtype([dtype_from, dtype_to])
438442

439443
return np.can_cast(dtype_from, dtype_to, casting)
440444

@@ -447,7 +451,10 @@ def result_type(*arrays_and_dtypes):
447451
Returns the dtype that results from applying the Type Promotion Rules to \
448452
the arguments.
449453
"""
450-
dtypes = [dpt.dtype(X) for X in arrays_and_dtypes]
454+
dtypes = [
455+
X.dtype if isinstance(X, dpt.usm_ndarray) else dpt.dtype(X)
456+
for X in arrays_and_dtypes
457+
]
451458

452459
_supported_dtype(dtypes)
453460

@@ -460,6 +467,8 @@ def iinfo(type):
460467
461468
Returns machine limits for integer data types.
462469
"""
470+
if isinstance(type, dpt.usm_ndarray):
471+
raise TypeError("Expected dtype type, get {to}.")
463472
_supported_dtype([dpt.dtype(type)])
464473
return np.iinfo(type)
465474

@@ -470,11 +479,14 @@ def finfo(type):
470479
471480
Returns machine limits for float data types.
472481
"""
482+
if isinstance(type, dpt.usm_ndarray):
483+
raise TypeError("Expected dtype type, get {to}.")
473484
_supported_dtype([dpt.dtype(type)])
474485
return np.finfo(type)
475486

476487

477488
def _supported_dtype(dtypes):
478-
if not all(dtype.char in "?bBhHiIlLqQefdFD" for dtype in dtypes):
479-
raise ValueError("Unsupported dtype encountered.")
489+
for dtype in dtypes:
490+
if dtype.char not in "?bBhHiIlLqQefdFD":
491+
raise ValueError(f"Dpctl doesn't support dtype {dtype}.")
480492
return True

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,8 +1010,10 @@ def test_can_cast():
10101010
# incorrect input
10111011
X = dpt.ones((2, 2), dtype=dpt.int64, sycl_queue=q)
10121012
pytest.raises(TypeError, dpt.can_cast, X, 1)
1013+
pytest.raises(TypeError, dpt.can_cast, X, X)
10131014
X_np = np.ones((2, 2), dtype=np.int64)
10141015

1016+
assert dpt.can_cast(X, "float32") == np.can_cast(X_np, "float32")
10151017
assert dpt.can_cast(X, dpt.int32) == np.can_cast(X_np, np.int32)
10161018
assert dpt.can_cast(X, dpt.int64) == np.can_cast(X_np, np.int64)
10171019

@@ -1022,7 +1024,7 @@ def test_result_type():
10221024
except dpctl.SyclQueueCreationError:
10231025
pytest.skip("Queue could not be created")
10241026

1025-
X = [dpt.ones((2), dtype=dpt.int64, sycl_queue=q), dpt.int32, dpt.float16]
1026-
X_np = [np.ones((2), dtype=np.int64), dpt.int32, dpt.float16]
1027+
X = [dpt.ones((2), dtype=dpt.int64, sycl_queue=q), dpt.int32, "float16"]
1028+
X_np = [np.ones((2), dtype=np.int64), np.int32, "float16"]
10271029

10281030
assert dpt.result_type(*X) == np.result_type(*X_np)

0 commit comments

Comments
 (0)