Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/32 diff #388

Merged
merged 12 commits into from
Sep 25, 2019
Prev Previous commit
Next Next commit
n>=2 not working for first element in rank>1
  • Loading branch information
coquelin77 committed Sep 19, 2019
commit 44ec6414c143327f91ba52dfd4e3ad3fde8d12dc
13 changes: 9 additions & 4 deletions heat/core/arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,23 @@ def diff(a, n=1, axis=-1):
raise TypeError('\'a\' must be a DNDarray')
rank = a.comm.rank
size = a.comm.size
coquelin77 marked this conversation as resolved.
Show resolved Hide resolved
axis_slice = [slice(None)] * len(a.shape)
axis_slice[axis] = slice(1, None, None)
axis_slice_end = [slice(None)] * len(a.shape)
axis_slice_end[axis] = slice(None, -1, None)
if not a.is_distributed():
ret = a.copy()
for _ in range(n):
axis_slice = [slice(None)] * len(ret.shape)
axis_slice[axis] = slice(1, None, None)
axis_slice_end = [slice(None)] * len(ret.shape)
axis_slice_end[axis] = slice(None, -1, None)
ret = ret[axis_slice] - ret[axis_slice_end]
return ret
coquelin77 marked this conversation as resolved.
Show resolved Hide resolved
else:
ret = a.copy()
for _ in range(n): # work loop, runs n times. using the result at the end of the loop as the starting values for each loop
axis_slice = [slice(None)] * len(ret.shape)
axis_slice[axis] = slice(1, None, None)
axis_slice_end = [slice(None)] * len(ret.shape)
axis_slice_end[axis] = slice(None, -1, None)

arb_slice = [slice(None)] * len(a.shape)
arb_slice[axis] = 0 # build the slice for the first element on the specified axis
if rank > 0:
Expand Down
42 changes: 42 additions & 0 deletions heat/core/tests/test_arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import unittest

import heat as ht
import numpy as np


class TestArithmetics(unittest.TestCase):
Expand Down Expand Up @@ -48,6 +49,47 @@ def test_add(self):
with self.assertRaises(TypeError):
ht.add('T', 's')

def test_diff(self):
# tests to run:
# correctness, 1d, 2d, 3d
# axis: 0, 1, 2
# split: 0, 1, 2
# comp with numpy

ht_array = ht.random.rand(20, 20, 20, split=None)
arb_slice = [0] * 3
for dim in range(3): # loop over 3 dimensions
arb_slice[dim] = slice(None)
for ax in range(dim + 1): # loop over the possible axis values
for sp in range(dim + 1): # loop over the possible split values
for nl in range(1, 4): # loop to 3 for the number of times to do the diff
lp_array = ht_array[arb_slice].resplit(sp) # only generating the number once and then
np_array = ht_array[arb_slice].numpy()

ht_diff = ht.diff(lp_array, n=nl, axis=ax)
np_diff = ht.array(np.diff(np_array, n=nl, axis=ax))
self.assertTrue(ht.equal(ht_diff, np_diff))
self.assertEqual(ht_diff.split, sp)
self.assertEqual(ht_diff.dtype, lp_array.dtype)

# lp_array = ht.array(ht_array, split=None) # only generating the number once and then
np_array = ht_array.numpy()
ht_diff = ht.diff(ht_array, n=2)
np_diff = ht.array(np.diff(np_array, n=2))
self.assertTrue(ht.equal(ht_diff, np_diff))
self.assertEqual(ht_diff.split, None)
self.assertEqual(ht_diff.dtype, ht_array.dtype)

ht_array = ht.random.rand(20, 20, 20, split=1, dtype=ht.float64)
ht_diff = ht.diff(ht_array, n=2)
np_diff = ht.array(np.diff(np_array, n=2))
self.assertTrue(ht.equal(ht_diff, np_diff))
self.assertEqual(ht_diff.split, 1)
self.assertEqual(ht_diff.dtype, ht_array.dtype)
# raises
# with self.assertRaises()
pass

def test_div(self):
result = ht.array([
[0.5, 1.0],
Expand Down