Skip to content

Commit

Permalink
Merge branch 'master' into enhancement/839-vecdot
Browse files Browse the repository at this point in the history
  • Loading branch information
coquelin77 authored Aug 2, 2021
2 parents a94c884 + f0afedf commit cd6881a
Show file tree
Hide file tree
Showing 3 changed files with 365 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension

## Feature Additions

### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`

### Manipulations
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`

# v1.1.0

Expand Down
168 changes: 168 additions & 0 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"repeat",
"reshape",
"resplit",
"roll",
"rot90",
"row_stack",
"shape",
Expand Down Expand Up @@ -1908,6 +1909,173 @@ def reshape_argsort_counts_displs(
DNDarray.reshape.__doc__ = reshape.__doc__


def roll(
    x: DNDarray, shift: Union[int, Tuple[int]], axis: Optional[Union[int, Tuple[int]]] = None
) -> DNDarray:
    """
    Rolls array elements along a specified axis. Array elements that roll beyond the last position are re-introduced at the first position.
    Array elements that roll beyond the first position are re-introduced at the last position.

    Parameters
    ----------
    x : DNDarray
        input array
    shift : Union[int, Tuple[int, ...]]
        number of places by which the elements are shifted. If 'shift' is a tuple, then 'axis' must be a tuple of the same size, and each of
        the given axes is shifted by the corresponding element in 'shift'. If 'shift' is an `int` and 'axis' a `tuple`, then the same shift
        is used for all specified axes.
    axis : Optional[Union[int, Tuple[int, ...]]]
        axis (or axes) along which elements to shift. If 'axis' is `None`, the array is flattened, shifted, and then restored to its original shape.
        Default: `None`.

    Raises
    ------
    TypeError
        If 'shift' or 'axis' is not of type `int`, `list` or `tuple`.
    ValueError
        If 'shift' and 'axis' are tuples with different sizes.

    Examples
    --------
    >>> a = ht.arange(20).reshape((4,5))
    >>> a
    DNDarray([[ 0,  1,  2,  3,  4],
              [ 5,  6,  7,  8,  9],
              [10, 11, 12, 13, 14],
              [15, 16, 17, 18, 19]], dtype=ht.int32, device=cpu:0, split=None)
    >>> ht.roll(a, 1)
    DNDarray([[19,  0,  1,  2,  3],
              [ 4,  5,  6,  7,  8],
              [ 9, 10, 11, 12, 13],
              [14, 15, 16, 17, 18]], dtype=ht.int32, device=cpu:0, split=None)
    >>> ht.roll(a, -1, 0)
    DNDarray([[ 5,  6,  7,  8,  9],
              [10, 11, 12, 13, 14],
              [15, 16, 17, 18, 19],
              [ 0,  1,  2,  3,  4]], dtype=ht.int32, device=cpu:0, split=None)
    """
    sanitation.sanitize_in(x)

    if axis is None:
        # Flatten, roll along the single remaining axis (0), then restore the
        # original shape and split via the recursive call.
        return roll(x.flatten(), shift, 0).reshape(x.shape, new_split=x.split)

    # inputs are ints
    if isinstance(shift, int):
        if isinstance(axis, int):
            if x.split is not None and (axis == x.split or (axis + x.ndim) == x.split):
                # roll along split axis: elements must move between processes,
                # so exchange them explicitly via point-to-point MPI messages
                size = x.comm.Get_size()
                rank = x.comm.Get_rank()

                # local elements along axis:
                lshape_map = x.create_lshape_map(force_check=False)[:, x.split]
                cumsum_map = torch.cumsum(lshape_map, dim=0)  # cumulate along axis
                indices = torch.arange(size, device=x.device.torch_device)
                # NOTE Can be removed when min version>=1.9
                # (older torch requires int64 repeats for repeat_interleave)
                if "1.7." in torch.__version__ or "1.8." in torch.__version__:
                    lshape_map = lshape_map.to(torch.int64)
                # index_map[g] is the rank owning global index g along the split axis
                index_map = torch.repeat_interleave(indices, lshape_map)  # index -> process

                # compute index positions
                # index_old: global indices of the locally held slices
                index_old = torch.arange(lshape_map[rank], device=x.device.torch_device)
                if rank > 0:
                    index_old += cumsum_map[rank - 1]

                # destination (send) and source (recv) global indices after the roll;
                # modulo implements the wrap-around
                send_index = (index_old + shift) % x.gshape[x.split]
                recv_index = (index_old - shift) % x.gshape[x.split]

                # exchange arrays: one message per size-1 slice along the split axis,
                # tagged with the slice's destination global index so sends and
                # receives pair up unambiguously
                recv = torch.empty_like(x.larray)
                recv_splits = torch.split(recv, 1, dim=x.split)
                recv_requests = [None for i in range(x.lshape[x.split])]

                # post all receives first to avoid unbuffered-send deadlock
                for i in range(x.lshape[x.split]):
                    recv_requests[i] = x.comm.Irecv(
                        recv_splits[i], index_map[recv_index[i]], index_old[i]
                    )

                send_splits = torch.split(x.larray, 1, dim=x.split)
                send_requests = [None for i in range(x.lshape[x.split])]

                for i in range(x.lshape[x.split]):
                    send_requests[i] = x.comm.Isend(
                        send_splits[i], index_map[send_index[i]], send_index[i]
                    )

                for i in range(x.lshape[x.split]):
                    recv_requests[i].Wait()
                for i in range(x.lshape[x.split]):
                    send_requests[i].Wait()

                # distribution is unchanged, only the values moved
                return DNDarray(recv, x.gshape, x.dtype, x.split, x.device, x.comm, x.balanced)

        else:  # pytorch does not support int / sequence combo at the time, make shift a list instead
            try:
                axis = sanitation.sanitize_sequence(axis)
            except TypeError:
                raise TypeError("axis must be a int, list or a tuple, got {}".format(type(axis)))

            # broadcast the single shift to every requested axis and recurse
            shift = [shift] * len(axis)

            return roll(x, shift, axis)

    else:  # input must be tuples now
        try:
            shift = sanitation.sanitize_sequence(shift)
        except TypeError:
            raise TypeError("shift must be an integer, list or a tuple, got {}".format(type(shift)))

        try:
            axis = sanitation.sanitize_sequence(axis)
        except TypeError:
            raise TypeError("axis must be an integer, list or a tuple, got {}".format(type(axis)))

        if len(shift) != len(axis):
            raise ValueError(
                "shift and axis length must be the same, got {} and {}".format(
                    len(shift), len(axis)
                )
            )

        for i in range(len(shift)):
            if not isinstance(shift[i], int):
                raise TypeError(
                    "Element {} in shift is not an integer, got {}".format(i, type(shift[i]))
                )
            if not isinstance(axis[i], int):
                raise TypeError(
                    "Element {} in axis is not an integer, got {}".format(i, type(axis[i]))
                )

        if x.split is not None and (x.split in axis or (x.split - x.ndim) in axis):
            # remove split axis elements
            # accumulate all shifts that refer to the split axis (positive or
            # negative notation) into one shift, then strip them from the lists
            shift_split = 0
            for y in (x.split, x.split - x.ndim):
                idx = [i for i in range(len(axis)) if axis[i] == y]
                for i in idx:
                    shift_split += shift[i]
                for i in reversed(idx):
                    axis.remove(y)
                    del shift[i]

            # compute new array along split axis
            x = roll(x, shift_split, x.split)
            if len(axis) == 0:
                return x

    # use PyTorch for all other axes
    # (non-split axes are purely local, so torch.roll on the local chunk suffices)
    rolled = torch.roll(x.larray, shift, axis)
    return DNDarray(
        rolled,
        gshape=x.shape,
        dtype=x.dtype,
        split=x.split,
        device=x.device,
        comm=x.comm,
        balanced=x.balanced,
    )


def rot90(m: DNDarray, k: int = 1, axes: Sequence[int, int] = (0, 1)) -> DNDarray:
"""
Rotate an array by 90 degrees in the plane specified by `axes`.
Expand Down
194 changes: 194 additions & 0 deletions heat/core/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2306,6 +2306,200 @@ def test_reshape(self):
with self.assertRaises(TypeError):
ht.reshape(ht.zeros((4, 3)), (3.4, 3.2))

def test_roll(self):
    """Tests for ht.roll on split and non-split arrays of several dtypes."""

    def check_meta(original, shifted):
        # device, size, dtype and split must all survive the roll unchanged
        self.assertEqual(shifted.device, original.device)
        self.assertEqual(shifted.size, original.size)
        self.assertEqual(shifted.dtype, original.dtype)
        self.assertEqual(shifted.split, original.split)

    # no split
    # vector
    vec = ht.arange(5)
    for s, values in ((1, [4, 0, 1, 2, 3]), (-1, [1, 2, 3, 4, 0])):
        shifted = ht.roll(vec, s)
        check_meta(vec, shifted)
        self.assertTrue(ht.equal(shifted, ht.array(values)))

    # matrix: verify against numpy.roll on the (non-distributed) local data
    mat = ht.arange(20.0).reshape((4, 5))
    for s, ax in ((-1, None), (1, 0), (-2, (0, 1)), ((1, 2, 1), (0, 1, -2))):
        shifted = ht.roll(mat, s, ax)
        expected = np.roll(mat.larray.cpu().numpy(), s, ax)
        check_meta(mat, shifted)
        self.assertTrue(np.array_equal(shifted.larray.cpu().numpy(), expected))

    # split
    # vector
    vec = ht.arange(5, dtype=ht.uint8, split=0)
    for s, values in ((1, [4, 0, 1, 2, 3]), (-1, [1, 2, 3, 4, 0])):
        shifted = ht.roll(vec, s)
        check_meta(vec, shifted)
        self.assertTrue(ht.equal(shifted, ht.array(values, dtype=ht.uint8, split=0)))

    # matrices split along either axis plus a quick 3D functionality check,
    # all verified against numpy.roll on the gathered data
    cases = (
        ht.arange(20).reshape((4, 5), dtype=ht.int16, new_split=0),
        ht.arange(20, dtype=ht.complex64).reshape((4, 5), new_split=1),
        ht.arange(4 * 5 * 6, dtype=ht.complex64).reshape((4, 5, 6), new_split=2),
    )
    for arr in cases:
        for s, ax in ((-1, None), (1, 0), (-2, [0, 1]), ([1, 2, 1], [0, 1, -2])):
            shifted = ht.roll(arr, s, ax)
            expected = np.roll(arr.numpy(), s, ax)
            check_meta(arr, shifted)
            self.assertTrue(np.array_equal(shifted.numpy(), expected))

    # invalid shift / axis types
    arr = cases[-1]
    for bad_shift, bad_axis in (
        (1.0, 0),
        (1, 1.0),
        (1, (1.0, 0.0)),
        ((-1, 1), 0.0),
        ((-1.0, 1.0), (0, 0)),
    ):
        with self.assertRaises(TypeError):
            ht.roll(arr, bad_shift, bad_axis)
    # mismatched shift / axis lengths
    with self.assertRaises(ValueError):
        ht.roll(arr, [1, 1, 1], [0, 0])

def test_rot90(self):
size = ht.MPI_WORLD.size
m = ht.arange(size ** 3, dtype=ht.int).reshape((size, size, size))
Expand Down

0 comments on commit cd6881a

Please sign in to comment.