Skip to content

Removed temporary workaround from dpnp.get_normalized_queue_device() #1605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 33 additions & 28 deletions dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,37 +366,45 @@ def get_normalized_queue_device(obj=None, device=None, sycl_queue=None):
Utility to process complementary keyword arguments 'device' and 'sycl_queue'
in subsequent calls of functions from `dpctl.tensor` module.

If both arguments 'device' and 'sycl_queue' have default value `None`
If both arguments 'device' and 'sycl_queue' have default value ``None``
and 'obj' has `sycl_queue` attribute, it assumes that Compute Follows Data
approach has to be applied and so the resulting SYCL queue will be normalized
based on the queue value from 'obj'.

Args:
obj (optional): A python object. Can be an instance of `dpnp_array`,
`dpctl.tensor.usm_ndarray`, an object representing SYCL USM allocation
and implementing `__sycl_usm_array_interface__` protocol,
an instance of `numpy.ndarray`, an object supporting Python buffer protocol,
a Python scalar, or a (possibly nested) sequence of Python scalars.
sycl_queue (:class:`dpctl.SyclQueue`, optional):
explicitly indicates where USM allocation is done
and the population code (if any) is executed.
Value `None` is interpreted as get the SYCL queue
from `obj` parameter if not None, from `device` keyword,
or use default queue.
Default: None
device (string, :class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue,
:class:`dpctl.tensor.Device`, optional):
array-API keyword indicating non-partitioned SYCL device
where array is allocated.
Parameters
----------
obj : object, optional
A python object. Can be an instance of `dpnp_array`,
`dpctl.tensor.usm_ndarray`, an object representing SYCL USM allocation
and implementing `__sycl_usm_array_interface__` protocol, an instance
of `numpy.ndarray`, an object supporting Python buffer protocol,
a Python scalar, or a (possibly nested) sequence of Python scalars.
sycl_queue : class:`dpctl.SyclQueue`, optional
A queue which explicitly indicates where USM allocation is done
and the population code (if any) is executed.
Value ``None`` is interpreted as to get the SYCL queue from either
`obj` parameter if not ``None`` or from `device` keyword,
or to use default queue.
device : {string, :class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue,
:class:`dpctl.tensor.Device`}, optional
An array-API keyword indicating non-partitioned SYCL device
where array is allocated.

Returns
:class:`dpctl.SyclQueue` object normalized by `normalize_queue_device` call
-------
sycl_queue: dpctl.SyclQueue
A :class:`dpctl.SyclQueue` object normalized by `normalize_queue_device` call
of `dpctl.tensor` module invoked with 'device' and 'sycl_queue' values.
If both incoming 'device' and 'sycl_queue' are None and 'obj' has `sycl_queue` attribute,
the normalization will be performed for 'obj.sycl_queue' value.
Raises:
TypeError: if argument is not of the expected type, or keywords
imply incompatible queues.

Raises
------
TypeError
If argument is not of the expected type, or keywords imply incompatible queues.

"""

if (
device is None
and sycl_queue is None
Expand All @@ -405,12 +413,9 @@ def get_normalized_queue_device(obj=None, device=None, sycl_queue=None):
):
sycl_queue = obj.sycl_queue

# TODO: remove check dpt._device has attribute 'normalize_queue_device'
if hasattr(dpt._device, "normalize_queue_device"):
return dpt._device.normalize_queue_device(
sycl_queue=sycl_queue, device=device
)
return sycl_queue
return dpt._device.normalize_queue_device(
sycl_queue=sycl_queue, device=device
)


def get_usm_ndarray(a):
Expand Down