@@ -1028,22 +1028,22 @@ def from_dlpack(x, /, *, device=None, copy=None):
10281028 f" The argument of type {type(x)} does not implement "
10291029 " `__dlpack__` and `__dlpack_device__` methods."
10301030 )
1031- try :
1032- # device is converted to a dlpack_device if necessary
1033- dl_device = None
1034- if device:
1035- if isinstance (device, tuple ):
1036- dl_device = device
1037- if len (dl_device) != 2 :
1038- raise ValueError (
1039- " Argument `device` specified as a tuple must have length 2"
1040- )
1031+ # device is converted to a dlpack_device if necessary
1032+ dl_device = None
1033+ if device:
1034+ if isinstance (device, tuple ):
1035+ dl_device = device
1036+ if len (dl_device) != 2 :
1037+ raise ValueError (
1038+ " Argument `device` specified as a tuple must have length 2"
1039+ )
1040+ else :
1041+ if not isinstance (device, dpctl.SyclDevice):
1042+ d = Device.create_device(device).sycl_device
10411043 else :
1042- if not isinstance (device, dpctl.SyclDevice):
1043- d = Device.create_device(device).sycl_device
1044- else :
1045- d = device
1046- dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> d))
1044+ d = device
1045+ dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> d))
1046+ try :
10471047 dlpack_capsule = dlpack_attr(max_version = get_build_dlpack_version(), dl_device = dl_device, copy = copy)
10481048 return from_dlpack_capsule(dlpack_capsule)
10491049 except TypeError :
@@ -1058,7 +1058,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
10581058 " Importing data via DLPack requires copying, but copy=False was provided"
10591059 )
10601060 if x_dldev == (device_CPU, 0 ) and dl_device[0 ] == device_OneAPI:
1061- host_blob = x
1061+ dlpack_capsule = dlpack_attr()
1062+ host_blob = from_dlpack_capsule(dlpack_capsule)
10621063 else :
10631064 raise BufferError(f" Can not import to requested device {dl_device}" )
10641065 return _to_usm_ary_from_host_blob(host_blob, dl_device[1 ])
@@ -1074,7 +1075,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
10741075 raise BufferError(f" Can not import to requested device {dl_device}" )
10751076 x_dldev = dlpack_dev_attr()
10761077 if x_dldev == (device_CPU, 0 ):
1077- host_blob = x
1078+ dlpack_capsule = dlpack_attr()
1079+ host_blob = from_dlpack_capsule(dlpack_capsule)
10781080 else :
10791081 dlpack_capsule = dlpack_attr(max_version = (1 , 0 ), dl_device = (device_CPU, 0 ), copy = copy)
10801082 host_blob = from_dlpack_capsule(dlpack_capsule)
0 commit comments