|
31 | 31 | )
|
32 | 32 |
|
33 | 33 |
|
34 |
| -class finfo_object(np.finfo): |
| 34 | +class finfo_object: |
35 | 35 | """
|
36 | 36 | `numpy.finfo` subclass which returns Python floating-point scalars for
|
37 | 37 | `eps`, `max`, `min`, and `smallest_normal` attributes.
|
38 | 38 | """
|
39 | 39 |
|
40 | 40 | def __init__(self, dtype):
|
41 | 41 | _supported_dtype([dpt.dtype(dtype)])
|
42 |
| - super().__init__() |
| 42 | + self._finfo = np.finfo(dtype) |
43 | 43 |
|
44 |
| - self.eps = float(self.eps) |
45 |
| - self.max = float(self.max) |
46 |
| - self.min = float(self.min) |
| 44 | + @property |
| 45 | + def bits(self): |
| 46 | + """ |
| 47 | + number of bits occupied by the real-valued floating-point data type. |
| 48 | + """ |
| 49 | + return int(self._finfo.bits) |
47 | 50 |
|
48 | 51 | @property
|
49 | 52 | def smallest_normal(self):
|
50 |
| - return float(super().smallest_normal) |
| 53 | + """ |
| 54 | + smallest positive real-valued floating-point number with full |
| 55 | + precision. |
| 56 | + """ |
| 57 | + return float(self._finfo.smallest_normal) |
51 | 58 |
|
52 | 59 | @property
|
53 | 60 | def tiny(self):
|
54 |
| - return float(super().tiny) |
| 61 | + """an alias for `smallest_normal`""" |
| 62 | + return float(self._finfo.tiny) |
| 63 | + |
| 64 | + @property |
| 65 | + def eps(self): |
| 66 | + """ |
| 67 | + difference between 1.0 and the next smallest representable real-valued |
| 68 | + floating-point number larger than 1.0 according to the IEEE-754 |
| 69 | + standard. |
| 70 | + """ |
| 71 | + return float(self._finfo.eps) |
| 72 | + |
| 73 | + @property |
| 74 | + def epsneg(self): |
| 75 | + """ |
| 76 | + difference between 1.0 and the next smallest representable real-valued |
| 77 | + floating-point number smaller than 1.0 according to the IEEE-754 |
| 78 | + standard. |
| 79 | + """ |
| 80 | + return float(self._finfo.epsneg) |
| 81 | + |
| 82 | + @property |
| 83 | + def min(self): |
| 84 | + """smallest representable real-valued number.""" |
| 85 | + return float(self._finfo.min) |
| 86 | + |
| 87 | + @property |
| 88 | + def max(self): |
| 89 | + "largest representable real-valued number." |
| 90 | + return float(self._finfo.max) |
| 91 | + |
| 92 | + @property |
| 93 | + def resolution(self): |
| 94 | + "the approximate decimal resolution of this type." |
| 95 | + return float(self._finfo.resolution) |
| 96 | + |
| 97 | + @property |
| 98 | + def precision(self): |
| 99 | + """ |
| 100 | + the approximate number of decimal digits to which this kind of |
| 101 | + floating point type is precise. |
| 102 | + """ |
| 103 | + return float(self._finfo.precision) |
| 104 | + |
| 105 | + @property |
| 106 | + def dtype(self): |
| 107 | + """ |
| 108 | + the dtype for which finfo returns information. For complex input, the |
| 109 | + returned dtype is the associated floating point dtype for its real and |
| 110 | + complex components. |
| 111 | + """ |
| 112 | + return self._finfo.dtype |
| 113 | + |
| 114 | + def __str__(self): |
| 115 | + return self._finfo.__str__() |
| 116 | + |
| 117 | + def __repr__(self): |
| 118 | + return self._finfo.__repr__() |
55 | 119 |
|
56 | 120 |
|
57 | 121 | def _broadcast_strides(X_shape, X_strides, res_ndim):
|
@@ -137,14 +201,12 @@ def permute_dims(X, axes):
|
137 | 201 | """
|
138 | 202 | if not isinstance(X, dpt.usm_ndarray):
|
139 | 203 | raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
|
140 |
| - if not isinstance(axes, (tuple, list)): |
141 |
| - axes = (axes,) |
| 204 | + axes = normalize_axis_tuple(axes, X.ndim, "axes") |
142 | 205 | if not X.ndim == len(axes):
|
143 | 206 | raise ValueError(
|
144 | 207 | "The length of the passed axes does not match "
|
145 | 208 | "to the number of usm_ndarray dimensions."
|
146 | 209 | )
|
147 |
| - axes = normalize_axis_tuple(axes, X.ndim, "axes") |
148 | 210 | newstrides = tuple(X.strides[i] for i in axes)
|
149 | 211 | newshape = tuple(X.shape[i] for i in axes)
|
150 | 212 | return dpt.usm_ndarray(
|
@@ -187,7 +249,8 @@ def expand_dims(X, axis):
|
187 | 249 | """
|
188 | 250 | if not isinstance(X, dpt.usm_ndarray):
|
189 | 251 | raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
|
190 |
| - if not isinstance(axis, (tuple, list)): |
| 252 | + |
| 253 | + if type(axis) not in (tuple, list): |
191 | 254 | axis = (axis,)
|
192 | 255 |
|
193 | 256 | out_ndim = len(axis) + X.ndim
|
@@ -224,8 +287,6 @@ def squeeze(X, axis=None):
|
224 | 287 | raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
|
225 | 288 | X_shape = X.shape
|
226 | 289 | if axis is not None:
|
227 |
| - if not isinstance(axis, (tuple, list)): |
228 |
| - axis = (axis,) |
229 | 290 | axis = normalize_axis_tuple(axis, X.ndim if X.ndim != 0 else X.ndim + 1)
|
230 | 291 | new_shape = []
|
231 | 292 | for i, x in enumerate(X_shape):
|
@@ -819,12 +880,6 @@ def moveaxis(X, source, destination):
|
819 | 880 | if not isinstance(X, dpt.usm_ndarray):
|
820 | 881 | raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
|
821 | 882 |
|
822 |
| - if not isinstance(source, (tuple, list)): |
823 |
| - source = (source,) |
824 |
| - |
825 |
| - if not isinstance(destination, (tuple, list)): |
826 |
| - destination = (destination,) |
827 |
| - |
828 | 883 | source = normalize_axis_tuple(source, X.ndim, "source")
|
829 | 884 | destination = normalize_axis_tuple(destination, X.ndim, "destination")
|
830 | 885 |
|
|
0 commit comments