Skip to content

Commit ee8f15a

Browse files
authored
Merge pull request #1355 from IntelPython/fix_asarray_sequences
Allow asarray to work on sequences of dpnp_array
2 parents c1c40e3 + c7ae46e commit ee8f15a

File tree

2 files changed

+35
-11
lines changed

2 files changed

+35
-11
lines changed

dpnp/dpnp_container.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,22 +86,34 @@ def asarray(x1,
8686
usm_type=None,
8787
sycl_queue=None):
8888
"""Converts `x1` to `dpnp_array`."""
89-
if isinstance(x1, dpnp_array):
90-
x1_obj = x1.get_array()
91-
else:
92-
x1_obj = x1
89+
dpu.validate_usm_type(usm_type, allow_none=True)
9390

94-
sycl_queue_normalized = dpnp.get_normalized_queue_device(x1_obj, device=device, sycl_queue=sycl_queue)
9591
if order is None:
9692
order = 'C'
9793

9894
"""Converts incoming 'x1' object to 'dpnp_array'."""
99-
array_obj = dpt.asarray(x1_obj,
100-
dtype=dtype,
101-
copy=copy,
102-
order=order,
103-
usm_type=usm_type,
104-
sycl_queue=sycl_queue_normalized)
95+
if isinstance(x1, (list, tuple, range)):
96+
array_obj = dpt.asarray(x1,
97+
dtype=dtype,
98+
copy=copy,
99+
order=order,
100+
device=device,
101+
usm_type=usm_type,
102+
sycl_queue=sycl_queue)
103+
else:
104+
if isinstance(x1, dpnp_array):
105+
x1_obj = x1.get_array()
106+
else:
107+
x1_obj = x1
108+
109+
sycl_queue_normalized = dpnp.get_normalized_queue_device(x1_obj, device=device, sycl_queue=sycl_queue)
110+
111+
array_obj = dpt.asarray(x1_obj,
112+
dtype=dtype,
113+
copy=copy,
114+
order=order,
115+
usm_type=usm_type,
116+
sycl_queue=sycl_queue_normalized)
105117
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
106118

107119

tests/test_sycl_queue.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,3 +945,15 @@ def test_broadcast_to(device):
945945
x = dpnp.arange(5, device=device)
946946
y = dpnp.broadcast_to(x, (3, 5))
947947
assert_sycl_queue_equal(x.sycl_queue, y.sycl_queue)
948+
949+
950+
@pytest.mark.parametrize("device_x",
951+
valid_devices,
952+
ids=[device.filter_string for device in valid_devices])
953+
@pytest.mark.parametrize("device_y",
954+
valid_devices,
955+
ids=[device.filter_string for device in valid_devices])
956+
def test_asarray(device_x, device_y):
957+
x = dpnp.array([1, 2, 3], device=device_x)
958+
y = dpnp.asarray([x, x, x], device=device_y)
959+
assert_sycl_queue_equal(y.sycl_queue, x.to_device(device_y).sycl_queue)

0 commit comments

Comments
 (0)