Skip to content

Commit 6640e9e

Browse files
Merge pull request #1296 from vlad-perevezentsev/dlpack_support
Add dlpack support
2 parents a516f1c + 103eb88 commit 6640e9e

File tree

4 files changed

+88
-5
lines changed

4 files changed

+88
-5
lines changed

dpnp/dpnp_array.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def __bool__(self):
140140
return self._array_obj.__bool__()
141141

142142
# '__class__',
143-
143+
144144
def __complex__(self):
145145
return self._array_obj.__complex__()
146146

@@ -153,6 +153,12 @@ def __complex__(self):
153153
# '__divmod__',
154154
# '__doc__',
155155

156+
def __dlpack__(self, stream=None):
157+
return self._array_obj.__dlpack__(stream=stream)
158+
159+
def __dlpack_device__(self):
160+
return self._array_obj.__dlpack_device__()
161+
156162
def __eq__(self, other):
157163
return dpnp.equal(self, other)
158164

@@ -190,7 +196,7 @@ def __gt__(self, other):
190196
# '__imatmul__',
191197
# '__imod__',
192198
# '__imul__',
193-
199+
194200
def __index__(self):
195201
return self._array_obj.__index__()
196202

@@ -315,6 +321,16 @@ def __truediv__(self, other):
315321

316322
# '__xor__',
317323

324+
@staticmethod
325+
def _create_from_usm_ndarray(usm_ary : dpt.usm_ndarray):
326+
if not isinstance(usm_ary, dpt.usm_ndarray):
327+
raise TypeError(
328+
f"Expected dpctl.tensor.usm_ndarray, got {type(usm_ary)}"
329+
)
330+
res = dpnp_array.__new__(dpnp_array)
331+
res._array_obj = usm_ary
332+
return res
333+
318334
def all(self, axis=None, out=None, keepdims=False):
319335
"""
320336
Returns True if all elements evaluate to True.

dpnp/dpnp_iface.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
"default_float_type",
6565
"dpnp_queue_initialize",
6666
"dpnp_queue_is_cpu",
67+
"from_dlpack",
6768
"get_dpnp_descriptor",
6869
"get_include",
6970
"get_normalized_queue_device"
@@ -222,6 +223,31 @@ def default_float_type(device=None, sycl_queue=None):
222223
return map_dtype_to_device(float64, _sycl_queue.sycl_device)
223224

224225

226+
def from_dlpack(obj, /):
227+
"""
228+
Create a dpnp array from a Python object implementing the ``__dlpack__``
229+
protocol.
230+
231+
See https://dmlc.github.io/dlpack/latest/ for more details.
232+
233+
Parameters
234+
----------
235+
obj : object
236+
A Python object representing an array that implements the ``__dlpack__``
237+
and ``__dlpack_device__`` methods.
238+
239+
Returns
240+
-------
241+
out : dpnp_array
242+
Returns a new dpnp array containing the data from another array
243+
(obj) with the ``__dlpack__`` method on the same device as object.
244+
245+
"""
246+
247+
usm_ary = dpt.from_dlpack(obj)
248+
return dpnp_array._create_from_usm_ndarray(usm_ary)
249+
250+
225251
def get_dpnp_descriptor(ext_obj,
226252
copy_when_strides=True,
227253
copy_when_nondefault_queue=True,

tests/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def get_all_dtypes(no_bool=False,
3232
dtypes.append(dpnp.complex64)
3333
if dev.has_aspect_fp64:
3434
dtypes.append(dpnp.complex128)
35-
35+
3636
# add None value to validate a default dtype
3737
if not no_none:
3838
dtypes.append(None)

tests/test_sycl_queue.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import pytest
2+
from .helper import get_all_dtypes
23

34
import dpnp
45
import dpctl
56
import numpy
67

8+
from numpy.testing import (
9+
assert_array_equal
10+
)
11+
712

813
list_of_backend_str = [
914
"host",
@@ -155,7 +160,7 @@ def test_array_creation_like(func, kwargs, device_x, device_y):
155160

156161
dpnp_kwargs = dict(kwargs)
157162
dpnp_kwargs['device'] = device_y
158-
163+
159164
y = getattr(dpnp, func)(x, **dpnp_kwargs)
160165
numpy.testing.assert_array_equal(y_orig, y)
161166
assert_sycl_queue_equal(y.sycl_queue, x.to_device(device_y).sycl_queue)
@@ -647,7 +652,7 @@ def test_eig(device):
647652
dpnp_val_queue = dpnp_val.get_array().sycl_queue
648653
dpnp_vec_queue = dpnp_vec.get_array().sycl_queue
649654

650-
# compare queue and device
655+
# compare queue and device
651656
assert_sycl_queue_equal(dpnp_val_queue, expected_queue)
652657
assert_sycl_queue_equal(dpnp_vec_queue, expected_queue)
653658

@@ -816,3 +821,39 @@ def test_array_copy(device, func, device_param, queue_param):
816821
result = dpnp.array(dpnp_data, **kwargs)
817822

818823
assert_sycl_queue_equal(result.sycl_queue, dpnp_data.sycl_queue)
824+
825+
826+
@pytest.mark.parametrize("device",
827+
valid_devices,
828+
ids=[device.filter_string for device in valid_devices])
829+
#TODO need to delete no_bool=True when use dlpack > 0.7 version
830+
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True, no_bool=True))
831+
@pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)])
832+
def test_from_dlpack(arr_dtype, shape, device):
833+
X = dpnp.empty(shape=shape, dtype=arr_dtype, device=device)
834+
Y = dpnp.from_dlpack(X)
835+
assert_array_equal(X, Y)
836+
assert X.__dlpack_device__() == Y.__dlpack_device__()
837+
assert X.sycl_device == Y.sycl_device
838+
assert X.sycl_context == Y.sycl_context
839+
assert X.usm_type == Y.usm_type
840+
if Y.ndim:
841+
V = Y[::-1]
842+
W = dpnp.from_dlpack(V)
843+
assert V.strides == W.strides
844+
845+
846+
@pytest.mark.parametrize("device",
847+
valid_devices,
848+
ids=[device.filter_string for device in valid_devices])
849+
#TODO need to delete no_bool=True when use dlpack > 0.7 version
850+
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True, no_bool=True))
851+
def test_from_dlpack_with_dpt(arr_dtype, device):
852+
X = dpctl.tensor.empty((64,), dtype=arr_dtype, device=device)
853+
Y = dpnp.from_dlpack(X)
854+
assert_array_equal(X, Y)
855+
assert isinstance(Y, dpnp.dpnp_array.dpnp_array)
856+
assert X.__dlpack_device__() == Y.__dlpack_device__()
857+
assert X.sycl_device == Y.sycl_device
858+
assert X.sycl_context == Y.sycl_context
859+
assert X.usm_type == Y.usm_type

0 commit comments

Comments
 (0)