Skip to content

Commit 4a2a79f

Browse files
Merge pull request #1182 from IntelPython/fix-gh-1178
Use normalize_axis_tuple to normalize axis for all functions but expand_dims
2 parents 37497b2 + 80c6046 commit 4a2a79f

File tree

2 files changed

+163
-208
lines changed

2 files changed

+163
-208
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,91 @@
3131
)
3232

3333

34-
class finfo_object(np.finfo):
34+
class finfo_object:
3535
"""
3636
`numpy.finfo` subclass which returns Python floating-point scalars for
3737
`eps`, `max`, `min`, and `smallest_normal` attributes.
3838
"""
3939

4040
def __init__(self, dtype):
4141
_supported_dtype([dpt.dtype(dtype)])
42-
super().__init__()
42+
self._finfo = np.finfo(dtype)
4343

44-
self.eps = float(self.eps)
45-
self.max = float(self.max)
46-
self.min = float(self.min)
44+
@property
45+
def bits(self):
46+
"""
47+
number of bits occupied by the real-valued floating-point data type.
48+
"""
49+
return int(self._finfo.bits)
4750

4851
@property
4952
def smallest_normal(self):
50-
return float(super().smallest_normal)
53+
"""
54+
smallest positive real-valued floating-point number with full
55+
precision.
56+
"""
57+
return float(self._finfo.smallest_normal)
5158

5259
@property
5360
def tiny(self):
54-
return float(super().tiny)
61+
"""an alias for `smallest_normal`"""
62+
return float(self._finfo.tiny)
63+
64+
@property
65+
def eps(self):
66+
"""
67+
difference between 1.0 and the next smallest representable real-valued
68+
floating-point number larger than 1.0 according to the IEEE-754
69+
standard.
70+
"""
71+
return float(self._finfo.eps)
72+
73+
@property
74+
def epsneg(self):
75+
"""
76+
difference between 1.0 and the next smallest representable real-valued
77+
floating-point number smaller than 1.0 according to the IEEE-754
78+
standard.
79+
"""
80+
return float(self._finfo.epsneg)
81+
82+
@property
83+
def min(self):
84+
"""smallest representable real-valued number."""
85+
return float(self._finfo.min)
86+
87+
@property
88+
def max(self):
89+
"largest representable real-valued number."
90+
return float(self._finfo.max)
91+
92+
@property
93+
def resolution(self):
94+
"the approximate decimal resolution of this type."
95+
return float(self._finfo.resolution)
96+
97+
@property
98+
def precision(self):
99+
"""
100+
the approximate number of decimal digits to which this kind of
101+
floating point type is precise.
102+
"""
103+
return float(self._finfo.precision)
104+
105+
@property
106+
def dtype(self):
107+
"""
108+
the dtype for which finfo returns information. For complex input, the
109+
returned dtype is the associated floating point dtype for its real and
110+
complex components.
111+
"""
112+
return self._finfo.dtype
113+
114+
def __str__(self):
115+
return self._finfo.__str__()
116+
117+
def __repr__(self):
118+
return self._finfo.__repr__()
55119

56120

57121
def _broadcast_strides(X_shape, X_strides, res_ndim):
@@ -137,14 +201,12 @@ def permute_dims(X, axes):
137201
"""
138202
if not isinstance(X, dpt.usm_ndarray):
139203
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
140-
if not isinstance(axes, (tuple, list)):
141-
axes = (axes,)
204+
axes = normalize_axis_tuple(axes, X.ndim, "axes")
142205
if not X.ndim == len(axes):
143206
raise ValueError(
144207
"The length of the passed axes does not match "
145208
"to the number of usm_ndarray dimensions."
146209
)
147-
axes = normalize_axis_tuple(axes, X.ndim, "axes")
148210
newstrides = tuple(X.strides[i] for i in axes)
149211
newshape = tuple(X.shape[i] for i in axes)
150212
return dpt.usm_ndarray(
@@ -187,7 +249,8 @@ def expand_dims(X, axis):
187249
"""
188250
if not isinstance(X, dpt.usm_ndarray):
189251
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
190-
if not isinstance(axis, (tuple, list)):
252+
253+
if type(axis) not in (tuple, list):
191254
axis = (axis,)
192255

193256
out_ndim = len(axis) + X.ndim
@@ -224,8 +287,6 @@ def squeeze(X, axis=None):
224287
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
225288
X_shape = X.shape
226289
if axis is not None:
227-
if not isinstance(axis, (tuple, list)):
228-
axis = (axis,)
229290
axis = normalize_axis_tuple(axis, X.ndim if X.ndim != 0 else X.ndim + 1)
230291
new_shape = []
231292
for i, x in enumerate(X_shape):
@@ -819,12 +880,6 @@ def moveaxis(X, source, destination):
819880
if not isinstance(X, dpt.usm_ndarray):
820881
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
821882

822-
if not isinstance(source, (tuple, list)):
823-
source = (source,)
824-
825-
if not isinstance(destination, (tuple, list)):
826-
destination = (destination,)
827-
828883
source = normalize_axis_tuple(source, X.ndim, "source")
829884
destination = normalize_axis_tuple(destination, X.ndim, "destination")
830885

0 commit comments

Comments
 (0)