Skip to content

Improve array API conformity #1110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,12 @@ def asarray(


def empty(
sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None
shape,
dtype=None,
order="C",
device=None,
usm_type="device",
sycl_queue=None,
):
"""
Creates `usm_ndarray` from uninitialized USM allocation.
Expand Down Expand Up @@ -509,7 +514,7 @@ def empty(
dtype = _get_dtype(dtype, sycl_queue)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
shape,
dtype=dtype,
buffer=usm_type,
order=order,
Expand Down Expand Up @@ -650,7 +655,12 @@ def arange(


def zeros(
sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None
shape,
dtype=None,
order="C",
device=None,
usm_type="device",
sycl_queue=None,
):
"""
Creates `usm_ndarray` with zero elements.
Expand Down Expand Up @@ -687,7 +697,7 @@ def zeros(
dtype = _get_dtype(dtype, sycl_queue)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
shape,
dtype=dtype,
buffer=usm_type,
order=order,
Expand All @@ -698,7 +708,12 @@ def zeros(


def ones(
sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None
shape,
dtype=None,
order="C",
device=None,
usm_type="device",
sycl_queue=None,
):
"""
Creates `usm_ndarray` with elements of one.
Expand Down Expand Up @@ -734,7 +749,7 @@ def ones(
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
dtype = _get_dtype(dtype, sycl_queue)
res = dpt.usm_ndarray(
sh,
shape,
dtype=dtype,
buffer=usm_type,
order=order,
Expand All @@ -746,7 +761,7 @@ def ones(


def full(
sh,
shape,
fill_value,
dtype=None,
order="C",
Expand Down Expand Up @@ -805,14 +820,14 @@ def full(
usm_type=usm_type,
sycl_queue=sycl_queue,
)
return dpt.copy(dpt.broadcast_to(X, sh), order=order)
return dpt.copy(dpt.broadcast_to(X, shape), order=order)

sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
usm_type = usm_type if usm_type is not None else "device"
fill_value_type = type(fill_value)
dtype = _get_dtype(dtype, sycl_queue, ref_type=fill_value_type)
res = dpt.usm_ndarray(
sh,
shape,
dtype=dtype,
buffer=usm_type,
order=order,
Expand Down Expand Up @@ -872,11 +887,11 @@ def empty_like(
if device is None and sycl_queue is None:
device = x.device
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
sh = x.shape
shape = x.shape
dtype = dpt.dtype(dtype)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
shape,
dtype=dtype,
buffer=usm_type,
order=order,
Expand Down
78 changes: 50 additions & 28 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,29 @@
)


class finfo_object(np.finfo):
"""
numpy.finfo subclass which returns Python floating-point scalars for
eps, max, min, and smallest_normal.
"""

def __init__(self, dtype):
_supported_dtype([dpt.dtype(dtype)])
super().__init__()

self.eps = float(self.eps)
self.max = float(self.max)
self.min = float(self.min)

@property
def smallest_normal(self):
return float(super().smallest_normal)

@property
def tiny(self):
return float(super().tiny)


def _broadcast_strides(X_shape, X_strides, res_ndim):
"""
Broadcasts strides to match the given dimensions;
Expand Down Expand Up @@ -122,46 +145,46 @@ def permute_dims(X, axes):
)


def expand_dims(X, axes):
def expand_dims(X, axis):
"""
expand_dims(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
expand_dims(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray

Expands the shape of an array by inserting a new axis (dimension)
of size one at the position specified by axes; returns a view, if possible,
of size one at the position specified by axis; returns a view, if possible,
a copy otherwise with the number of dimensions increased.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
if not isinstance(axes, (tuple, list)):
axes = (axes,)
if not isinstance(axis, (tuple, list)):
axis = (axis,)

out_ndim = len(axes) + X.ndim
axes = normalize_axis_tuple(axes, out_ndim)
out_ndim = len(axis) + X.ndim
axis = normalize_axis_tuple(axis, out_ndim)

shape_it = iter(X.shape)
shape = tuple(1 if ax in axes else next(shape_it) for ax in range(out_ndim))
shape = tuple(1 if ax in axis else next(shape_it) for ax in range(out_ndim))

return dpt.reshape(X, shape)


def squeeze(X, axes=None):
def squeeze(X, axis=None):
"""
squeeze(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
squeeze(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray

Removes singleton dimensions (axes) from X; returns a view, if possible,
Removes singleton dimensions (axis) from X; returns a view, if possible,
a copy otherwise, but with all or a subset of the dimensions
of length 1 removed.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
X_shape = X.shape
if axes is not None:
if not isinstance(axes, (tuple, list)):
axes = (axes,)
axes = normalize_axis_tuple(axes, X.ndim if X.ndim != 0 else X.ndim + 1)
if axis is not None:
if not isinstance(axis, (tuple, list)):
axis = (axis,)
axis = normalize_axis_tuple(axis, X.ndim if X.ndim != 0 else X.ndim + 1)
new_shape = []
for i, x in enumerate(X_shape):
if i not in axes:
if i not in axis:
new_shape.append(x)
else:
if x != 1:
Expand Down Expand Up @@ -222,9 +245,9 @@ def broadcast_arrays(*args):
return [broadcast_to(X, shape) for X in args]


def flip(X, axes=None):
def flip(X, axis=None):
"""
flip(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
flip(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray

Reverses the order of elements in an array along the given axis.
The shape of the array is preserved, but the elements are reordered;
Expand All @@ -233,20 +256,20 @@ def flip(X, axes=None):
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
X_ndim = X.ndim
if axes is None:
if axis is None:
indexer = (np.s_[::-1],) * X_ndim
else:
axes = normalize_axis_tuple(axes, X_ndim)
axis = normalize_axis_tuple(axis, X_ndim)
indexer = tuple(
np.s_[::-1] if i in axes else np.s_[:] for i in range(X.ndim)
np.s_[::-1] if i in axis else np.s_[:] for i in range(X.ndim)
)
return X[indexer]


def roll(X, shift, axes=None):
def roll(X, shift, axis=None):
"""
roll(X: usm_ndarray, shift: int or tuple or list,\
axes: int or tuple or list) -> usm_ndarray
axis: int or tuple or list) -> usm_ndarray

Rolls array elements along a specified axis.
Array elements that roll beyond the last position are re-introduced
Expand All @@ -257,7 +280,7 @@ def roll(X, shift, axes=None):
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
if axes is None:
if axis is None:
res = dpt.empty(
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
)
Expand All @@ -266,8 +289,8 @@ def roll(X, shift, axes=None):
)
hev.wait()
return res
axes = normalize_axis_tuple(axes, X.ndim, allow_duplicate=True)
broadcasted = np.broadcast(shift, axes)
axis = normalize_axis_tuple(axis, X.ndim, allow_duplicate=True)
broadcasted = np.broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
shifts = {ax: 0 for ax in range(X.ndim)}
Expand Down Expand Up @@ -495,8 +518,7 @@ def finfo(dtype):
"""
if isinstance(dtype, dpt.usm_ndarray):
raise TypeError("Expected dtype type, got {to}.")
_supported_dtype([dpt.dtype(dtype)])
return np.finfo(dtype)
return finfo_object(dtype)


def _supported_dtype(dtypes):
Expand Down
28 changes: 14 additions & 14 deletions dpctl/tensor/_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,17 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
return new_sts if valid else None


def reshape(X, newshape, order="C", copy=None):
def reshape(X, shape, order="C", copy=None):
"""
reshape(X: usm_ndarray, newshape: tuple, order="C") -> usm_ndarray
reshape(X: usm_ndarray, shape: tuple, order="C") -> usm_ndarray

Reshapes given usm_ndarray into new shape. Returns a view, if possible,
a copy otherwise. Memory layout of the copy is controlled by order keyword.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError
if not isinstance(newshape, (list, tuple)):
newshape = (newshape,)
if not isinstance(shape, (list, tuple)):
shape = (shape,)
if order in "cfCF":
order = order.upper()
else:
Expand All @@ -97,9 +97,9 @@ def reshape(X, newshape, order="C", copy=None):
f"Keyword 'copy' not recognized. Expecting True, False, "
f"or None, got {copy}"
)
newshape = [operator.index(d) for d in newshape]
shape = [operator.index(d) for d in shape]
negative_ones_count = 0
for nshi in newshape:
for nshi in shape:
if nshi == -1:
negative_ones_count = negative_ones_count + 1
if (nshi < -1) or negative_ones_count > 1:
Expand All @@ -108,14 +108,14 @@ def reshape(X, newshape, order="C", copy=None):
"value which can only be -1"
)
if negative_ones_count:
v = X.size // (-np.prod(newshape))
newshape = [v if d == -1 else d for d in newshape]
if X.size != np.prod(newshape):
raise ValueError(f"Can not reshape into {newshape}")
v = X.size // (-np.prod(shape))
shape = [v if d == -1 else d for d in shape]
if X.size != np.prod(shape):
raise ValueError(f"Can not reshape into {shape}")
if X.size:
newsts = reshaped_strides(X.shape, X.strides, newshape, order=order)
newsts = reshaped_strides(X.shape, X.strides, shape, order=order)
else:
newsts = (1,) * len(newshape)
newsts = (1,) * len(shape)
copy_required = newsts is None
if copy_required and (copy is False):
raise ValueError(
Expand All @@ -141,11 +141,11 @@ def reshape(X, newshape, order="C", copy=None):
flat_res[i], X[np.unravel_index(i, X.shape, order=order)]
)
return dpt.usm_ndarray(
tuple(newshape), dtype=X.dtype, buffer=flat_res, order=order
tuple(shape), dtype=X.dtype, buffer=flat_res, order=order
)
# can form a view
return dpt.usm_ndarray(
newshape,
shape,
dtype=X.dtype,
buffer=X,
strides=tuple(newsts),
Expand Down
10 changes: 5 additions & 5 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def test_incompatible_shapes_raise_valueerror(shapes):
assert_broadcast_arrays_raise(input_shapes[::-1])


def test_flip_axes_incorrect():
def test_flip_axis_incorrect():
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
Expand All @@ -492,10 +492,10 @@ def test_flip_axes_incorrect():
X_np = np.ones((4, 4))
X = dpt.asarray(X_np, sycl_queue=q)

pytest.raises(np.AxisError, dpt.flip, dpt.asarray(np.ones(4)), axes=1)
pytest.raises(np.AxisError, dpt.flip, X, axes=2)
pytest.raises(np.AxisError, dpt.flip, X, axes=-3)
pytest.raises(np.AxisError, dpt.flip, X, axes=(0, 3))
pytest.raises(np.AxisError, dpt.flip, dpt.asarray(np.ones(4)), axis=1)
pytest.raises(np.AxisError, dpt.flip, X, axis=2)
pytest.raises(np.AxisError, dpt.flip, X, axis=-3)
pytest.raises(np.AxisError, dpt.flip, X, axis=(0, 3))


def test_flip_0d():
Expand Down