Skip to content

Commit 7c85781

Browse files
Defined finfo_object to not derive from numpy object
Added tests to cover each property.
1 parent 749d5e9 commit 7c85781

File tree

2 files changed

+99
-25
lines changed

2 files changed

+99
-25
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,91 @@
3131
)
3232

3333

34-
class finfo_object(np.finfo):
34+
class finfo_object:
3535
"""
3636
`numpy.finfo` subclass which returns Python floating-point scalars for
3737
`eps`, `max`, `min`, and `smallest_normal` attributes.
3838
"""
3939

4040
def __init__(self, dtype):
4141
_supported_dtype([dpt.dtype(dtype)])
42-
super().__init__()
42+
self._finfo = np.finfo(dtype)
4343

44-
self.eps = float(self.eps)
45-
self.max = float(self.max)
46-
self.min = float(self.min)
44+
@property
45+
def bits(self):
46+
"""
47+
number of bits occupied by the real-valued floating-point data type.
48+
"""
49+
return int(self._finfo.bits)
4750

4851
@property
4952
def smallest_normal(self):
50-
return float(super().smallest_normal)
53+
"""
54+
smallest positive real-valued floating-point number with full
55+
precision.
56+
"""
57+
return float(self._finfo.smallest_normal)
5158

5259
@property
5360
def tiny(self):
54-
return float(super().tiny)
61+
"""an alias for `smallest_normal`"""
62+
return float(self._finfo.tiny)
63+
64+
@property
65+
def eps(self):
66+
"""
67+
difference between 1.0 and the next smallest representable real-valued
68+
floating-point number larger than 1.0 according to the IEEE-754
69+
standard.
70+
"""
71+
return float(self._finfo.eps)
72+
73+
@property
74+
def epsneg(self):
75+
"""
76+
difference between 1.0 and the next smallest representable real-valued
77+
floating-point number smaller than 1.0 according to the IEEE-754
78+
standard.
79+
"""
80+
return float(self._finfo.epsneg)
81+
82+
@property
83+
def min(self):
84+
"""smallest representable real-valued number."""
85+
return float(self._finfo.min)
86+
87+
@property
88+
def max(self):
89+
"largest representable real-valued number."
90+
return float(self._finfo.max)
91+
92+
@property
93+
def resolution(self):
94+
"the approximate decimal resolution of this type."
95+
return float(self._finfo.resolution)
96+
97+
@property
98+
def precision(self):
99+
"""
100+
the approximate number of decimal digits to which this kind of
101+
floating point type is precise.
102+
"""
103+
return float(self._finfo.precision)
104+
105+
@property
106+
def dtype(self):
107+
"""
108+
the dtype for which finfo returns information. For complex input, the
109+
returned dtype is the associated floating point dtype for its real and
110+
complex components.
111+
"""
112+
return self._finfo.dtype
113+
114+
def __str__(self):
115+
return self._finfo.__str__()
116+
117+
def __repr__(self):
118+
return self._finfo.__repr__()
55119

56120

57121
def _broadcast_strides(X_shape, X_strides, res_ndim):

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,9 @@ def test_permute_dims_2d_3d(shapes):
8484

8585

8686
def test_expand_dims_incorrect_type():
87-
X_list = list([1, 2, 3, 4, 5])
88-
X_tuple = tuple(X_list)
89-
Xnp = np.array(X_list)
90-
91-
pytest.raises(TypeError, dpt.permute_dims, X_list, 1)
92-
pytest.raises(TypeError, dpt.permute_dims, X_tuple, 1)
93-
pytest.raises(TypeError, dpt.permute_dims, Xnp, 1)
87+
X_list = [1, 2, 3, 4, 5]
88+
with pytest.raises(TypeError):
89+
dpt.permute_dims(X_list, 1)
9490

9591

9692
def test_expand_dims_0d():
@@ -143,22 +139,20 @@ def test_expand_dims_tuple(axes):
143139

144140

145141
def test_expand_dims_incorrect_tuple():
146-
147142
X = dpt.empty((3, 3, 3), dtype="i4")
148-
pytest.raises(np.AxisError, dpt.expand_dims, X, (0, -6))
149-
pytest.raises(np.AxisError, dpt.expand_dims, X, (0, 5))
143+
with pytest.raises(np.AxisError):
144+
dpt.expand_dims(X, (0, -6))
145+
with pytest.raises(np.AxisError):
146+
dpt.expand_dims(X, (0, 5))
150147

151-
pytest.raises(ValueError, dpt.expand_dims, X, (1, 1))
148+
with pytest.raises(ValueError):
149+
dpt.expand_dims(X, (1, 1))
152150

153151

154152
def test_squeeze_incorrect_type():
155-
X_list = list([1, 2, 3, 4, 5])
156-
X_tuple = tuple(X_list)
157-
Xnp = np.array(X_list)
158-
159-
pytest.raises(TypeError, dpt.permute_dims, X_list, 1)
160-
pytest.raises(TypeError, dpt.permute_dims, X_tuple, 1)
161-
pytest.raises(TypeError, dpt.permute_dims, Xnp, 1)
153+
X_list = [1, 2, 3, 4, 5]
154+
with pytest.raises(TypeError):
155+
dpt.permute_dims(X_list, 1)
162156

163157

164158
def test_squeeze_0d():
@@ -1077,3 +1071,19 @@ def test_unstack_axis2():
10771071
assert_array_equal(dpt.asnumpy(y[:, :, 0, ...]), dpt.asnumpy(res[0]))
10781072
assert_array_equal(dpt.asnumpy(y[:, :, 1, ...]), dpt.asnumpy(res[1]))
10791073
assert_array_equal(dpt.asnumpy(y[:, :, 2, ...]), dpt.asnumpy(res[2]))
1074+
1075+
1076+
def test_finfo_object():
1077+
fi = dpt.finfo(dpt.float32)
1078+
assert isinstance(fi.bits, int)
1079+
assert isinstance(fi.max, float)
1080+
assert isinstance(fi.min, float)
1081+
assert isinstance(fi.eps, float)
1082+
assert isinstance(fi.epsneg, float)
1083+
assert isinstance(fi.smallest_normal, float)
1084+
assert isinstance(fi.tiny, float)
1085+
assert isinstance(fi.precision, float)
1086+
assert isinstance(fi.resolution, float)
1087+
assert isinstance(fi.dtype, dpt.dtype)
1088+
assert isinstance(str(fi), str)
1089+
assert isinstance(repr(fi), str)

0 commit comments

Comments
 (0)