Skip to content

Commit

Permalink
Merge pull request #1315 from helmholtz-analytics/bugs/1232-_Bug_Ensu…
Browse files Browse the repository at this point in the history
…re_NumPy-compatibility_of_test_statistics_py

Fix some NumPy deprecations in the core tests.
  • Loading branch information
mrfh92 authored Jan 22, 2024
2 parents f8fd26a + 6a5115c commit 983e12d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
5 changes: 3 additions & 2 deletions heat/core/tests/test_arithmetics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
import math

import heat as ht
import numpy as np
Expand Down Expand Up @@ -775,11 +776,11 @@ def test_nanprod(self):
self.assertEqual(shape_noaxis_split_nanprod.dtype, ht.float32)
self.assertEqual(shape_noaxis_split_nanprod.larray.dtype, torch.float32)
self.assertEqual(shape_noaxis_split_nanprod.split, None)
self.assertEqual(shape_noaxis_split_nanprod, np.math.factorial(10))
self.assertEqual(shape_noaxis_split_nanprod, math.factorial(10))

out_noaxis = ht.array(1, dtype=shape_noaxis_split.dtype)
ht.nanprod(shape_noaxis_split, out=out_noaxis)
self.assertEqual(out_noaxis.larray, np.math.factorial(10))
self.assertEqual(out_noaxis.larray, math.factorial(10))

def test_nansum(self):
array_len = 11
Expand Down
2 changes: 1 addition & 1 deletion heat/core/tests/test_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def test_asarray(self):
arr = np.array([1, 2, 3, 4])
asarr = ht.asarray(arr)

self.assertTrue(np.alltrue(np.equal(asarr.numpy(), arr)))
self.assertTrue(np.all(np.equal(asarr.numpy(), arr)))

asarr[0] = 0
if asarr.device == ht.cpu:
Expand Down
10 changes: 5 additions & 5 deletions heat/core/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def test_bincount(self):
res = ht.bincount(a, weights=w)
self.assertEqual(res.size, 4)
self.assertEqual(res.dtype, ht.float64)
self.assertTrue(ht.equal(res, ht.arange((4,), dtype=ht.float64)))
self.assertTrue(ht.equal(res, ht.arange(4, dtype=ht.float64)))

with self.assertRaises(ValueError):
ht.bincount(ht.array([0, 1, 2, 3], split=0), weights=ht.array([1, 2, 3, 4]))
Expand Down Expand Up @@ -1181,25 +1181,25 @@ def test_percentile(self):
# test list q and writing to output buffer
q = [0.1, 2.3, 15.9, 50.0, 84.1, 97.7, 99.9]
axis = 2
p_np = np.percentile(x_np, q, axis=axis, interpolation="lower", keepdims=True)
p_np = np.percentile(x_np, q, axis=axis, method="lower", keepdims=True)
p_ht = ht.percentile(x_ht, q, axis=axis, interpolation="lower", keepdims=True)
out = ht.empty(p_np.shape, dtype=ht.float64, split=None, device=x_ht.device)
ht.percentile(x_ht, q, axis=axis, out=out, interpolation="lower", keepdims=True)
self.assertEqual(p_ht.numpy()[5].all(), p_np[5].all())
self.assertEqual(out.numpy()[2].all(), p_np[2].all())
self.assertTrue(p_ht.shape == p_np.shape)
axis = None
p_np = np.percentile(x_np, q, axis=axis, interpolation="higher")
p_np = np.percentile(x_np, q, axis=axis, method="higher")
p_ht = ht.percentile(x_ht, q, axis=axis, interpolation="higher")
self.assertEqual(p_ht.numpy()[6], p_np[6])
self.assertTrue(p_ht.shape == p_np.shape)
p_np = np.percentile(x_np, q, axis=axis, interpolation="nearest")
p_np = np.percentile(x_np, q, axis=axis, method="nearest")
p_ht = ht.percentile(x_ht, q, axis=axis, interpolation="nearest")
self.assertEqual(p_ht.numpy()[2], p_np[2])

# test split q
q_ht = ht.array(q, split=0, comm=x_ht.comm)
p_np = np.percentile(x_np, q, axis=axis, interpolation="midpoint")
p_np = np.percentile(x_np, q, axis=axis, method="midpoint")
p_ht = ht.percentile(x_ht, q_ht, axis=axis, interpolation="midpoint")
self.assertEqual(p_ht.numpy()[4], p_np[4])

Expand Down

0 comments on commit 983e12d

Please sign in to comment.