Skip to content

Commit 6f52954

Browse files
committed
Fixed gh-1272
1 parent 351c6a6 commit 6f52954

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

dpnp/dpnp_iface_arraycreation.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from dpnp.dpnp_utils import *
4949

5050
import dpnp.dpnp_container as dpnp_container
51+
import dpctl.tensor as dpt
5152

5253

5354
__all__ = [
@@ -530,7 +531,7 @@ def empty_like(x1,
530531
531532
Limitations
532533
-----------
533-
Parameters ``x1`` is supported only as :class:`dpnp.dpnp_array`.
534+
Parameter ``x1`` is supported as :class:`dpnp.dpnp_array` or :class:`dpctl.tensor.usm_ndarray`
534535
Parameter ``order`` is supported with values ``"C"`` or ``"F"``.
535536
Parameter ``subok`` is supported only with default value ``False``.
536537
Otherwise the function will be executed sequentially on CPU.
@@ -552,7 +553,7 @@ def empty_like(x1,
552553
553554
"""
554555

555-
if not isinstance(x1, dpnp.ndarray):
556+
if not isinstance(x1, (dpnp.ndarray, dpt.usm_ndarray)):
556557
pass
557558
elif order not in ('C', 'c', 'F', 'f', None):
558559
pass
@@ -762,7 +763,7 @@ def full_like(x1,
762763
763764
Limitations
764765
-----------
765-
Parameters ``x1`` is supported only as :class:`dpnp.dpnp_array`.
766+
Parameter ``x1`` is supported as :class:`dpnp.dpnp_array` or :class:`dpctl.tensor.usm_ndarray`
766767
Parameter ``order`` is supported only with values ``"C"`` and ``"F"``.
767768
Parameter ``subok`` is supported only with default value ``False``.
768769
Otherwise the function will be executed sequentially on CPU.
@@ -783,7 +784,7 @@ def full_like(x1,
783784
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
784785
785786
"""
786-
if not isinstance(x1, dpnp.ndarray):
787+
if not isinstance(x1, (dpnp.ndarray, dpt.usm_ndarray)):
787788
pass
788789
elif order not in ('C', 'c', 'F', 'f', None):
789790
pass
@@ -1189,7 +1190,7 @@ def ones_like(x1,
11891190
11901191
Limitations
11911192
-----------
1192-
Parameters ``x1`` is supported only as :class:`dpnp.dpnp_array`.
1193+
Parameter ``x1`` is supported as :class:`dpnp.dpnp_array` or :class:`dpctl.tensor.usm_ndarray`
11931194
Parameter ``order`` is supported with values ``"C"`` or ``"F"``.
11941195
Parameter ``subok`` is supported only with default value ``False``.
11951196
Otherwise the function will be executed sequentially on CPU.
@@ -1211,7 +1212,7 @@ def ones_like(x1,
12111212
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
12121213
12131214
"""
1214-
if not isinstance(x1, dpnp.ndarray):
1215+
if not isinstance(x1, (dpnp.ndarray, dpt.usm_ndarray)):
12151216
pass
12161217
elif order not in ('C', 'c', 'F', 'f', None):
12171218
pass
@@ -1502,7 +1503,7 @@ def zeros_like(x1,
15021503
15031504
Limitations
15041505
-----------
1505-
Parameters ``x1`` is supported only as :class:`dpnp.dpnp_array`.
1506+
Parameter ``x1`` is supported as :class:`dpnp.dpnp_array` or :class:`dpctl.tensor.usm_ndarray`
15061507
Parameter ``order`` is supported with values ``"C"`` or ``"F"``.
15071508
Parameter ``subok`` is supported only with default value ``False``.
15081509
Otherwise the function will be executed sequentially on CPU.
@@ -1523,8 +1524,8 @@ def zeros_like(x1,
15231524
>>> [i for i in np.zeros_like(x)]
15241525
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
15251526
1526-
"""
1527-
if not isinstance(x1, dpnp.ndarray):
1527+
"""
1528+
if not isinstance(x1, (dpnp.ndarray, dpt.usm_ndarray)):
15281529
pass
15291530
elif order not in ('C', 'c', 'F', 'f', None):
15301531
pass

tests/test_arraycreation.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,26 @@ def test_ones_like(array, dtype, order):
485485
a = numpy.array(array)
486486
ia = dpnp.array(array)
487487
assert_array_equal(func(numpy, a), func(dpnp, ia))
488+
489+
490+
@pytest.mark.parametrize(
491+
"func, args",
492+
[
493+
pytest.param("full_like",
494+
['x0', '4']),
495+
pytest.param("zeros_like",
496+
['x0']),
497+
pytest.param("ones_like",
498+
['x0']),
499+
pytest.param("empty_like",
500+
['x0']),
501+
])
502+
def test_dpctl_tensor_input(func, args):
503+
x0 = dpt.reshape(dpt.arange(9), (3,3))
504+
new_args = [eval(val, {'x0' : x0}) for val in args]
505+
X = getattr(dpt, func)(*new_args)
506+
Y = getattr(dpnp, func)(*new_args)
507+
if func is 'empty_like':
508+
assert X.shape == Y.shape
509+
else:
510+
assert_array_equal(X, Y)

0 commit comments

Comments
 (0)