Skip to content

Commit a0e32c0

Browse files
committed
Adopt dpnp to DLPack v1.0
1 parent 4a23239 commit a0e32c0

File tree

2 files changed

+74
-11
lines changed

2 files changed

+74
-11
lines changed

dpnp/dpnp_array.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -184,27 +184,61 @@ def __copy__(self):
184184
# '__divmod__',
185185
# '__doc__',
186186

187-
def __dlpack__(self, stream=None):
187+
def __dlpack__(
188+
self, *, stream=None, max_version=None, dl_device=None, copy=None
189+
):
188190
"""
189191
Produces DLPack capsule.
190192
191193
Parameters
192194
----------
193195
stream : {:class:`dpctl.SyclQueue`, None}, optional
194-
Execution queue to synchronize with. If ``None``,
195-
synchronization is not performed.
196+
Execution queue to synchronize with. If ``None``, synchronization
197+
is not performed.
198+
Default: ``None``.
199+
max_version {tuple of ints, None}, optional
200+
The maximum DLPack version the consumer (caller of ``__dlpack__``)
201+
supports. As ``__dlpack__`` may not always return a DLPack capsule
202+
with version `max_version`, the consumer must verify the version
203+
even if this argument is passed.
204+
Default: ``None``.
205+
dl_device {tuple, None}, optional:
206+
The device the returned DLPack capsule will be placed on. The
207+
device must be a 2-tuple matching the format of
208+
``__dlpack_device__`` method, an integer enumerator representing
209+
the device type followed by an integer representing the index of
210+
the device.
211+
Default: ``None``.
212+
copy {bool, None}, optional:
213+
Boolean indicating whether or not to copy the input.
214+
215+
* If `copy` is ``True``, the input will always be copied.
216+
* If ``False``, a ``BufferError`` will be raised if a copy is
217+
deemed necessary.
218+
* If ``None``, a copy will be made only if deemed necessary,
219+
otherwise, the existing memory buffer will be reused.
220+
221+
Default: ``None``.
196222
197223
Raises
198224
------
199-
MemoryError
225+
MemoryError:
200226
when host memory can not be allocated.
201-
DLPackCreationError
202-
when array is allocated on a partitioned
203-
SYCL device, or with a non-default context.
227+
DLPackCreationError:
228+
when array is allocated on a partitioned SYCL device, or with
229+
a non-default context.
230+
BufferError:
231+
when a copy is deemed necessary but `copy` is ``False`` or when
232+
the provided `dl_device` cannot be handled.
204233
205234
"""
206235

207-
return self._array_obj.__dlpack__(stream=stream)
236+
return self._array_obj.__dlpack__(
237+
stream=stream,
238+
max_version=max_version,
239+
dl_device=dl_device,
240+
copy=copy,
241+
)
208242

209243
def __dlpack_device__(self):
210244
"""

dpnp/dpnp_iface.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def default_float_type(device=None, sycl_queue=None):
464464
return map_dtype_to_device(float64, _sycl_queue.sycl_device)
465465

466466

467-
def from_dlpack(obj, /):
467+
def from_dlpack(obj, /, *, device=None, copy=None):
468468
"""
469469
Create a dpnp array from a Python object implementing the ``__dlpack__``
470470
protocol.
@@ -476,17 +476,46 @@ def from_dlpack(obj, /):
476476
obj : object
477477
A Python object representing an array that implements the ``__dlpack__``
478478
and ``__dlpack_device__`` methods.
479+
device : {:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`,
480+
:class:`dpctl.tensor.Device`, tuple, None}, optional
481+
Array API concept of a device where the output array is to be placed.
482+
``device`` can be ``None``, an oneAPI filter selector string,
483+
an instance of :class:`dpctl.SyclDevice` corresponding to
484+
a non-partitioned SYCL device, an instance of :class:`dpctl.SyclQueue`,
485+
a :class:`dpctl.tensor.Device` object returned by
486+
:attr:`dpctl.tensor.usm_ndarray.device`, or a 2-tuple matching
487+
the format of the output of the ``__dlpack_device__`` method,
488+
an integer enumerator representing the device type followed by
489+
an integer representing the index of the device.
490+
Default: ``None``.
491+
copy {bool, None}, optional
492+
Boolean indicating whether or not to copy the input.
493+
494+
* If `copy``is ``True``, the input will always be copied.
495+
* If ``False``, a ``BufferError`` will be raised if a copy is deemed
496+
necessary.
497+
* If ``None``, a copy will be made only if deemed necessary, otherwise,
498+
the existing memory buffer will be reused.
499+
500+
Default: ``None``.
479501
480502
Returns
481503
-------
482504
out : dpnp_array
483505
Returns a new dpnp array containing the data from another array
484506
(obj) with the ``__dlpack__`` method on the same device as object.
485507
508+
Raises
509+
------
510+
TypeError:
511+
if `obj` does not implement ``__dlpack__`` method
512+
ValueError:
513+
if the input array resides on an unsupported device
514+
486515
"""
487516

488-
usm_ary = dpt.from_dlpack(obj)
489-
return dpnp_array._create_from_usm_ndarray(usm_ary)
517+
usm_res = dpt.from_dlpack(obj, device=device, copy=copy)
518+
return dpnp_array._create_from_usm_ndarray(usm_res)
490519

491520

492521
def get_dpnp_descriptor(

0 commit comments

Comments
 (0)