Skip to content

Commit 9afae01

Browse files
committed
Adds check that max_version is a 2-tuple to __dlpack__
1 parent c54a90a commit 9afae01

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1127,8 +1127,14 @@ cdef class usm_ndarray:
11271127
stream.submit_barrier(dependent_events=[ev])
11281128
return _caps
11291129
else:
1130+
if not isinstance(max_version, tuple) or len(max_version) != 2:
1131+
raise TypeError(
1132+
"`__dlpack__` expects `max_version` to be a "
1133+
"2-tuple of integers `(major, minor)`, instead "
1134+
f"got {type(max_version)}"
1135+
)
11301136
dpctl_dlpack_version = get_build_dlpack_version()
1131-
if max_version >= dpctl_dlpack_version or max_version[0] == dpctl_dlpack_version[0]:
1137+
if max_version[0] >= dpctl_dlpack_version[0]:
11321138
# DLManagedTensorVersioned path
11331139
# TODO: add logic for targeting a device
11341140
if dl_device is not None:

0 commit comments

Comments
 (0)