Skip to content

Commit cecfdaa

Browse files
authored
Merge pull request #1110 from IntelPython/array-api-cleanup
Improvements to array API conformity
2 parents 5bfc097 + 4012039 commit cecfdaa

File tree

4 files changed

+95
-58
lines changed

4 files changed

+95
-58
lines changed

dpctl/tensor/_ctors.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,12 @@ def asarray(
472472

473473

474474
def empty(
475-
sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None
475+
shape,
476+
dtype=None,
477+
order="C",
478+
device=None,
479+
usm_type="device",
480+
sycl_queue=None,
476481
):
477482
"""
478483
Creates `usm_ndarray` from uninitialized USM allocation.
@@ -509,7 +514,7 @@ def empty(
509514
dtype = _get_dtype(dtype, sycl_queue)
510515
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
511516
res = dpt.usm_ndarray(
512-
sh,
517+
shape,
513518
dtype=dtype,
514519
buffer=usm_type,
515520
order=order,
@@ -650,7 +655,12 @@ def arange(
650655

651656

652657
def zeros(
653-
sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None
658+
shape,
659+
dtype=None,
660+
order="C",
661+
device=None,
662+
usm_type="device",
663+
sycl_queue=None,
654664
):
655665
"""
656666
Creates `usm_ndarray` with zero elements.
@@ -687,7 +697,7 @@ def zeros(
687697
dtype = _get_dtype(dtype, sycl_queue)
688698
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
689699
res = dpt.usm_ndarray(
690-
sh,
700+
shape,
691701
dtype=dtype,
692702
buffer=usm_type,
693703
order=order,
@@ -698,7 +708,12 @@ def zeros(
698708

699709

700710
def ones(
701-
sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None
711+
shape,
712+
dtype=None,
713+
order="C",
714+
device=None,
715+
usm_type="device",
716+
sycl_queue=None,
702717
):
703718
"""
704719
Creates `usm_ndarray` with elements of one.
@@ -734,7 +749,7 @@ def ones(
734749
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
735750
dtype = _get_dtype(dtype, sycl_queue)
736751
res = dpt.usm_ndarray(
737-
sh,
752+
shape,
738753
dtype=dtype,
739754
buffer=usm_type,
740755
order=order,
@@ -746,7 +761,7 @@ def ones(
746761

747762

748763
def full(
749-
sh,
764+
shape,
750765
fill_value,
751766
dtype=None,
752767
order="C",
@@ -805,14 +820,14 @@ def full(
805820
usm_type=usm_type,
806821
sycl_queue=sycl_queue,
807822
)
808-
return dpt.copy(dpt.broadcast_to(X, sh), order=order)
823+
return dpt.copy(dpt.broadcast_to(X, shape), order=order)
809824

810825
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
811826
usm_type = usm_type if usm_type is not None else "device"
812827
fill_value_type = type(fill_value)
813828
dtype = _get_dtype(dtype, sycl_queue, ref_type=fill_value_type)
814829
res = dpt.usm_ndarray(
815-
sh,
830+
shape,
816831
dtype=dtype,
817832
buffer=usm_type,
818833
order=order,
@@ -872,11 +887,11 @@ def empty_like(
872887
if device is None and sycl_queue is None:
873888
device = x.device
874889
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
875-
sh = x.shape
890+
shape = x.shape
876891
dtype = dpt.dtype(dtype)
877892
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
878893
res = dpt.usm_ndarray(
879-
sh,
894+
shape,
880895
dtype=dtype,
881896
buffer=usm_type,
882897
order=order,

dpctl/tensor/_manipulation_functions.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,29 @@
3131
)
3232

3333

34+
class finfo_object(np.finfo):
35+
"""
36+
numpy.finfo subclass which returns Python floating-point scalars for
37+
eps, max, min, and smallest_normal.
38+
"""
39+
40+
def __init__(self, dtype):
41+
_supported_dtype([dpt.dtype(dtype)])
42+
super().__init__()
43+
44+
self.eps = float(self.eps)
45+
self.max = float(self.max)
46+
self.min = float(self.min)
47+
48+
@property
49+
def smallest_normal(self):
50+
return float(super().smallest_normal)
51+
52+
@property
53+
def tiny(self):
54+
return float(super().tiny)
55+
56+
3457
def _broadcast_strides(X_shape, X_strides, res_ndim):
3558
"""
3659
Broadcasts strides to match the given dimensions;
@@ -122,46 +145,46 @@ def permute_dims(X, axes):
122145
)
123146

124147

125-
def expand_dims(X, axes):
148+
def expand_dims(X, axis):
126149
"""
127-
expand_dims(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
150+
expand_dims(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray
128151
129152
Expands the shape of an array by inserting a new axis (dimension)
130-
of size one at the position specified by axes; returns a view, if possible,
153+
of size one at the position specified by axis; returns a view, if possible,
131154
a copy otherwise with the number of dimensions increased.
132155
"""
133156
if not isinstance(X, dpt.usm_ndarray):
134157
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
135-
if not isinstance(axes, (tuple, list)):
136-
axes = (axes,)
158+
if not isinstance(axis, (tuple, list)):
159+
axis = (axis,)
137160

138-
out_ndim = len(axes) + X.ndim
139-
axes = normalize_axis_tuple(axes, out_ndim)
161+
out_ndim = len(axis) + X.ndim
162+
axis = normalize_axis_tuple(axis, out_ndim)
140163

141164
shape_it = iter(X.shape)
142-
shape = tuple(1 if ax in axes else next(shape_it) for ax in range(out_ndim))
165+
shape = tuple(1 if ax in axis else next(shape_it) for ax in range(out_ndim))
143166

144167
return dpt.reshape(X, shape)
145168

146169

147-
def squeeze(X, axes=None):
170+
def squeeze(X, axis=None):
148171
"""
149-
squeeze(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
172+
squeeze(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray
150173
151-
Removes singleton dimensions (axes) from X; returns a view, if possible,
174+
Removes singleton dimensions (axis) from X; returns a view, if possible,
152175
a copy otherwise, but with all or a subset of the dimensions
153176
of length 1 removed.
154177
"""
155178
if not isinstance(X, dpt.usm_ndarray):
156179
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
157180
X_shape = X.shape
158-
if axes is not None:
159-
if not isinstance(axes, (tuple, list)):
160-
axes = (axes,)
161-
axes = normalize_axis_tuple(axes, X.ndim if X.ndim != 0 else X.ndim + 1)
181+
if axis is not None:
182+
if not isinstance(axis, (tuple, list)):
183+
axis = (axis,)
184+
axis = normalize_axis_tuple(axis, X.ndim if X.ndim != 0 else X.ndim + 1)
162185
new_shape = []
163186
for i, x in enumerate(X_shape):
164-
if i not in axes:
187+
if i not in axis:
165188
new_shape.append(x)
166189
else:
167190
if x != 1:
@@ -222,9 +245,9 @@ def broadcast_arrays(*args):
222245
return [broadcast_to(X, shape) for X in args]
223246

224247

225-
def flip(X, axes=None):
248+
def flip(X, axis=None):
226249
"""
227-
flip(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
250+
flip(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray
228251
229252
Reverses the order of elements in an array along the given axis.
230253
The shape of the array is preserved, but the elements are reordered;
@@ -233,20 +256,20 @@ def flip(X, axes=None):
233256
if not isinstance(X, dpt.usm_ndarray):
234257
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
235258
X_ndim = X.ndim
236-
if axes is None:
259+
if axis is None:
237260
indexer = (np.s_[::-1],) * X_ndim
238261
else:
239-
axes = normalize_axis_tuple(axes, X_ndim)
262+
axis = normalize_axis_tuple(axis, X_ndim)
240263
indexer = tuple(
241-
np.s_[::-1] if i in axes else np.s_[:] for i in range(X.ndim)
264+
np.s_[::-1] if i in axis else np.s_[:] for i in range(X.ndim)
242265
)
243266
return X[indexer]
244267

245268

246-
def roll(X, shift, axes=None):
269+
def roll(X, shift, axis=None):
247270
"""
248271
roll(X: usm_ndarray, shift: int or tuple or list,\
249-
axes: int or tuple or list) -> usm_ndarray
272+
axis: int or tuple or list) -> usm_ndarray
250273
251274
Rolls array elements along a specified axis.
252275
Array elements that roll beyond the last position are re-introduced
@@ -257,7 +280,7 @@ def roll(X, shift, axes=None):
257280
"""
258281
if not isinstance(X, dpt.usm_ndarray):
259282
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
260-
if axes is None:
283+
if axis is None:
261284
res = dpt.empty(
262285
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
263286
)
@@ -266,8 +289,8 @@ def roll(X, shift, axes=None):
266289
)
267290
hev.wait()
268291
return res
269-
axes = normalize_axis_tuple(axes, X.ndim, allow_duplicate=True)
270-
broadcasted = np.broadcast(shift, axes)
292+
axis = normalize_axis_tuple(axis, X.ndim, allow_duplicate=True)
293+
broadcasted = np.broadcast(shift, axis)
271294
if broadcasted.ndim > 1:
272295
raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
273296
shifts = {ax: 0 for ax in range(X.ndim)}
@@ -495,8 +518,7 @@ def finfo(dtype):
495518
"""
496519
if isinstance(dtype, dpt.usm_ndarray):
497520
raise TypeError("Expected dtype type, got {to}.")
498-
_supported_dtype([dpt.dtype(dtype)])
499-
return np.finfo(dtype)
521+
return finfo_object(dtype)
500522

501523

502524
def _supported_dtype(dtypes):

dpctl/tensor/_reshape.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,17 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
7575
return new_sts if valid else None
7676

7777

78-
def reshape(X, newshape, order="C", copy=None):
78+
def reshape(X, shape, order="C", copy=None):
7979
"""
80-
reshape(X: usm_ndarray, newshape: tuple, order="C") -> usm_ndarray
80+
reshape(X: usm_ndarray, shape: tuple, order="C") -> usm_ndarray
8181
8282
Reshapes given usm_ndarray into new shape. Returns a view, if possible,
8383
a copy otherwise. Memory layout of the copy is controlled by order keyword.
8484
"""
8585
if not isinstance(X, dpt.usm_ndarray):
8686
raise TypeError
87-
if not isinstance(newshape, (list, tuple)):
88-
newshape = (newshape,)
87+
if not isinstance(shape, (list, tuple)):
88+
shape = (shape,)
8989
if order in "cfCF":
9090
order = order.upper()
9191
else:
@@ -97,9 +97,9 @@ def reshape(X, newshape, order="C", copy=None):
9797
f"Keyword 'copy' not recognized. Expecting True, False, "
9898
f"or None, got {copy}"
9999
)
100-
newshape = [operator.index(d) for d in newshape]
100+
shape = [operator.index(d) for d in shape]
101101
negative_ones_count = 0
102-
for nshi in newshape:
102+
for nshi in shape:
103103
if nshi == -1:
104104
negative_ones_count = negative_ones_count + 1
105105
if (nshi < -1) or negative_ones_count > 1:
@@ -108,14 +108,14 @@ def reshape(X, newshape, order="C", copy=None):
108108
"value which can only be -1"
109109
)
110110
if negative_ones_count:
111-
v = X.size // (-np.prod(newshape))
112-
newshape = [v if d == -1 else d for d in newshape]
113-
if X.size != np.prod(newshape):
114-
raise ValueError(f"Can not reshape into {newshape}")
111+
v = X.size // (-np.prod(shape))
112+
shape = [v if d == -1 else d for d in shape]
113+
if X.size != np.prod(shape):
114+
raise ValueError(f"Can not reshape into {shape}")
115115
if X.size:
116-
newsts = reshaped_strides(X.shape, X.strides, newshape, order=order)
116+
newsts = reshaped_strides(X.shape, X.strides, shape, order=order)
117117
else:
118-
newsts = (1,) * len(newshape)
118+
newsts = (1,) * len(shape)
119119
copy_required = newsts is None
120120
if copy_required and (copy is False):
121121
raise ValueError(
@@ -141,11 +141,11 @@ def reshape(X, newshape, order="C", copy=None):
141141
flat_res[i], X[np.unravel_index(i, X.shape, order=order)]
142142
)
143143
return dpt.usm_ndarray(
144-
tuple(newshape), dtype=X.dtype, buffer=flat_res, order=order
144+
tuple(shape), dtype=X.dtype, buffer=flat_res, order=order
145145
)
146146
# can form a view
147147
return dpt.usm_ndarray(
148-
newshape,
148+
shape,
149149
dtype=X.dtype,
150150
buffer=X,
151151
strides=tuple(newsts),

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def test_incompatible_shapes_raise_valueerror(shapes):
483483
assert_broadcast_arrays_raise(input_shapes[::-1])
484484

485485

486-
def test_flip_axes_incorrect():
486+
def test_flip_axis_incorrect():
487487
try:
488488
q = dpctl.SyclQueue()
489489
except dpctl.SyclQueueCreationError:
@@ -492,10 +492,10 @@ def test_flip_axes_incorrect():
492492
X_np = np.ones((4, 4))
493493
X = dpt.asarray(X_np, sycl_queue=q)
494494

495-
pytest.raises(np.AxisError, dpt.flip, dpt.asarray(np.ones(4)), axes=1)
496-
pytest.raises(np.AxisError, dpt.flip, X, axes=2)
497-
pytest.raises(np.AxisError, dpt.flip, X, axes=-3)
498-
pytest.raises(np.AxisError, dpt.flip, X, axes=(0, 3))
495+
pytest.raises(np.AxisError, dpt.flip, dpt.asarray(np.ones(4)), axis=1)
496+
pytest.raises(np.AxisError, dpt.flip, X, axis=2)
497+
pytest.raises(np.AxisError, dpt.flip, X, axis=-3)
498+
pytest.raises(np.AxisError, dpt.flip, X, axis=(0, 3))
499499

500500

501501
def test_flip_0d():

0 commit comments

Comments
 (0)