Skip to content

Commit 3e3ab03

Browse files
Reuse get parent device ordinal id routine (#1672)
* Reused get_parent_device_ordinal_id routine * test_legacy_dlpack_capsule uses 4 kinds of dtype Added test to use non-default copy keyword, and non-default device keyword argument.
1 parent 1205b0b commit 3e3ab03

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,6 @@ cdef int get_array_dlpack_device_id(
211211
cdef c_dpctl.SyclQueue ary_sycl_queue
212212
cdef c_dpctl.SyclDevice ary_sycl_device
213213
cdef DPCTLSyclDeviceRef pDRef = NULL
214-
cdef DPCTLSyclDeviceRef tDRef = NULL
215214
cdef int device_id = -1
216215

217216
ary_sycl_queue = usm_ary.get_sycl_queue()
@@ -228,26 +227,18 @@ cdef int get_array_dlpack_device_id(
228227
"on non-partitioned SYCL devices on platforms where "
229228
"default_context oneAPI extension is not supported."
230229
)
230+
device_id = ary_sycl_device.get_overall_ordinal()
231231
else:
232232
if not usm_ary.sycl_context == default_context:
233233
raise DLPackCreationError(
234234
"to_dlpack_capsule: DLPack can only export arrays based on USM "
235235
"allocations bound to a default platform SYCL context"
236236
)
237-
# Find the unpartitioned parent of the allocation device
238-
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
239-
if pDRef is not NULL:
240-
tDRef = DPCTLDevice_GetParentDevice(pDRef)
241-
while tDRef is not NULL:
242-
DPCTLDevice_Delete(pDRef)
243-
pDRef = tDRef
244-
tDRef = DPCTLDevice_GetParentDevice(pDRef)
245-
ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
237+
device_id = get_parent_device_ordinal_id(ary_sycl_device)
246238

247-
device_id = ary_sycl_device.get_overall_ordinal()
248239
if device_id < 0:
249240
raise DLPackCreationError(
250-
"to_dlpack_capsule: failed to determine device_id"
241+
"get_array_dlpack_device_id: failed to determine device_id"
251242
)
252243

253244
return device_id
@@ -295,11 +286,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
295286

296287
device_id = get_array_dlpack_device_id(usm_ary)
297288

298-
if device_id < 0:
299-
raise DLPackCreationError(
300-
"to_dlpack_capsule: failed to determine device_id"
301-
)
302-
303289
dlm_tensor = <DLManagedTensor *> stdlib.malloc(
304290
sizeof(DLManagedTensor))
305291
if dlm_tensor is NULL:

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,19 +261,22 @@ def test_legacy_dlpack_capsule():
261261
del cap
262262
assert x._pointer == y._pointer
263263

264+
x = dpt.arange(100, dtype="u4")
264265
x2 = dpt.reshape(x, (10, 10)).mT
265266
cap = x2.__dlpack__(max_version=legacy_ver)
266267
y = _dlp.from_dlpack_capsule(cap)
267268
del cap
268269
assert x2._pointer == y._pointer
269270
del x2
270271

272+
x = dpt.arange(100, dtype="f4")
271273
x2 = dpt.asarray(dpt.reshape(x, (10, 10)), order="F")
272274
cap = x2.__dlpack__(max_version=legacy_ver)
273275
y = _dlp.from_dlpack_capsule(cap)
274276
del cap
275277
assert x2._pointer == y._pointer
276278

279+
x = dpt.arange(100, dtype="c8")
277280
x3 = x[::-2]
278281
cap = x3.__dlpack__(max_version=legacy_ver)
279282
y = _dlp.from_dlpack_capsule(cap)
@@ -321,3 +324,16 @@ def test_versioned_dlpack_capsule():
321324
y = _dlp.from_dlpack_versioned_capsule(cap)
322325
assert x._pointer != y._pointer
323326
assert not y.flags.writable
327+
328+
329+
def test_from_dlpack_kwargs():
330+
try:
331+
x = dpt.arange(100, dtype="i4")
332+
except dpctl.SyclDeviceCreationError:
333+
pytest.skip("No default device available")
334+
335+
y = dpt.from_dlpack(x, copy=True)
336+
assert x._pointer != y._pointer
337+
338+
z = dpt.from_dlpack(x, device=x.sycl_device)
339+
assert z._pointer == x._pointer

0 commit comments

Comments
 (0)