Skip to content

Commit 17cb2eb

Browse files
committed
Make all_almost_equal work for object arrays
1 parent dc22345 commit 17cb2eb

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

odl/util/testutils.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,18 @@ def all_equal(iter1, iter2):
150150

151151

152152
def all_almost_equal_array(v1, v2, ndigits):
153-
return np.allclose(v1, v2,
154-
rtol=10 ** -ndigits, atol=10 ** -ndigits,
155-
equal_nan=True)
153+
if v1.dtype == object and v2.dtype == object:
154+
return all(
155+
all_almost_equal_array(v1_i, v2_i, ndigits)
156+
for v1_i, v2_i in zip(v1, v2)
157+
)
158+
159+
try:
160+
return np.allclose(
161+
v1, v2, rtol=10 ** -ndigits, atol=10 ** -ndigits, equal_nan=True
162+
)
163+
except TypeError:
164+
return False
156165

157166

158167
def all_almost_equal(iter1, iter2, ndigits=None):
@@ -170,7 +179,7 @@ def all_almost_equal(iter1, iter2, ndigits=None):
170179

171180
if hasattr(iter1, '__array__') and hasattr(iter2, '__array__'):
172181
# Only get default ndigits if comparing arrays, need to keep `None`
173-
# otherwise for recursive calls.
182+
# otherwise for recursive calls
174183
if ndigits is None:
175184
ndigits = _ndigits(iter1, iter2, None)
176185
return all_almost_equal_array(iter1, iter2, ndigits)

0 commit comments

Comments
 (0)