
Commit

Merge branch 'enhancement/836-norm' of github.com:helmholtz-analytics/heat into enhancement/836-norm
mtar committed Aug 20, 2021
2 parents 6190b93 + 5bcaa34 commit 26a9eef
Showing 4 changed files with 167 additions and 7 deletions.
9 changes: 5 additions & 4 deletions CHANGELOG.md
@@ -3,16 +3,17 @@
 - [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
 - [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `_reduce_op` when axis and keepdim were set.
 - [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `min`, `max` where DNDarrays with empty processes can't be computed.
-## Feature additions
-### Linear Algebra
-- [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm`
+
+## Feature Additions
+
 ### Linear Algebra
 - [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
-## Manipulations
+- [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm`
+### Manipulations
 - [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`
 - [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes`
 - [#854](https://github.com/helmholtz-analytics/heat/pull/854) New Feature: `moveaxis`


 # v1.1.0

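The #846 entry above adds three norm routines to the linear-algebra module. A minimal usage sketch follows; the call signatures here are an assumption (modeled on `numpy.linalg`), so consult the PR for the exact API:

```python
import heat as ht

x = ht.array([[3.0, 4.0], [8.0, 6.0]], split=0)

# Assumed NumPy-like signatures; see PR #846 for the actual parameter names.
print(ht.linalg.norm(x))                 # Frobenius norm of the whole matrix
print(ht.linalg.vector_norm(x, axis=1))  # per-row Euclidean norms: [5., 10.]
print(ht.linalg.matrix_norm(x))          # explicit matrix (Frobenius) norm
```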
22 changes: 19 additions & 3 deletions heat/core/logical.py
@@ -237,13 +237,29 @@ def isclose(
         output_gshape = stride_tricks.broadcast_shape(t1.gshape, t2.gshape)
         res = torch.empty(output_gshape, device=t1.device.torch_device).bool()
         t1.comm.Allgather(_local_isclose, res)
-        result = factories.array(res, dtype=types.bool, device=t1.device, split=t1.split)
+        result = DNDarray(
+            res,
+            gshape=output_gshape,
+            dtype=types.bool,
+            split=t1.split,
+            device=t1.device,
+            comm=t1.comm,
+            balanced=t1.is_balanced,
+        )
     else:
         if _local_isclose.dim() == 0:
             # both x and y are scalars, return a single boolean value
-            result = bool(factories.array(_local_isclose).item())
+            result = bool(_local_isclose.item())
         else:
-            result = factories.array(_local_isclose, dtype=types.bool, device=t1.device)
+            result = DNDarray(
+                _local_isclose,
+                gshape=tuple(_local_isclose.shape),
+                dtype=types.bool,
+                split=None,
+                device=t1.device,
+                comm=t1.comm,
+                balanced=t1.is_balanced,
+            )

     return result
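The rewrite above constructs the `DNDarray` result directly instead of routing the already-computed torch tensor back through `factories.array`, presumably to skip the redundant sanitation and metadata inference that the factory performs. The user-facing semantics are unchanged; a minimal sketch, assuming the NumPy-style tolerance rule `|a - b| <= atol + rtol * |b|`:

```python
import heat as ht

a = ht.array([1.0, 2.0, 3.0], split=0)
b = ht.array([1.0, 2.00001, 3.5], split=0)

# Element-wise closeness within tolerances, NumPy-style.
print(ht.isclose(a, b))            # [True, True, False]
print(ht.isclose(a, b, atol=0.6))  # [True, True, True]
```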

116 changes: 116 additions & 0 deletions heat/core/manipulations.py
@@ -36,6 +36,7 @@
"flipud",
"hsplit",
"hstack",
"moveaxis",
"pad",
"ravel",
"redistribute",
@@ -50,6 +51,7 @@
"split",
"squeeze",
"stack",
"swapaxes",
"topk",
"unique",
"vsplit",
@@ -1055,6 +1057,71 @@ def hstack(arrays: Sequence[DNDarray, ...]) -> DNDarray:
     return concatenate(arrays, axis=axis)


def moveaxis(
    x: DNDarray, source: Union[int, Sequence[int]], destination: Union[int, Sequence[int]]
) -> DNDarray:
    """
    Moves axes at the positions in `source` to new positions.

    Parameters
    ----------
    x : DNDarray
        The input array.
    source : int or Sequence[int, ...]
        Original positions of the axes to move. These must be unique.
    destination : int or Sequence[int, ...]
        Destination positions for each of the original axes. These must also be unique.

    See Also
    --------
    :func:`~heat.core.linalg.basics.transpose`
        Permute the dimensions of an array.

    Raises
    ------
    TypeError
        If `source` or `destination` are not ints, lists or tuples.
    ValueError
        If `source` and `destination` do not have the same number of elements.

    Examples
    --------
    >>> x = ht.zeros((3, 4, 5))
    >>> ht.moveaxis(x, 0, -1).shape
    (4, 5, 3)
    >>> ht.moveaxis(x, -1, 0).shape
    (5, 3, 4)
    """
    if isinstance(source, int):
        source = (source,)
    if isinstance(source, list):
        source = tuple(source)
    try:
        source = stride_tricks.sanitize_axis(x.shape, source)
    except TypeError:
        raise TypeError("'source' must be ints, lists or tuples.")

    if isinstance(destination, int):
        destination = (destination,)
    if isinstance(destination, list):
        destination = tuple(destination)
    try:
        destination = stride_tricks.sanitize_axis(x.shape, destination)
    except TypeError:
        raise TypeError("'destination' must be ints, lists or tuples.")

    if len(source) != len(destination):
        raise ValueError("'source' and 'destination' must have the same number of elements.")

    order = [n for n in range(x.ndim) if n not in source]

    for dest, src in sorted(zip(destination, source)):
        order.insert(dest, src)

    return linalg.transpose(x, order)
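Since the permutation construction is terse, here is a plain-Python trace of the same logic (a standalone sketch, not part of the diff):

```python
# Trace moveaxis's permutation building for ndim=3, source=(0,), destination=(2,).
ndim, source, destination = 3, (0,), (2,)

# Axes that do not move keep their original order: [1, 2]
order = [n for n in range(ndim) if n not in source]

# Insert each moved axis at its destination, lowest destination first,
# so later insertions cannot shift earlier ones.
for dest, src in sorted(zip(destination, source)):
    order.insert(dest, src)

print(order)  # [1, 2, 0] -> transpose((1, 2, 0)) maps shape (3, 4, 5) to (4, 5, 3)
```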


def pad(
    array: DNDarray,
    pad_width: Union[int, Sequence[Sequence[int, int], ...]],
@@ -2957,6 +3024,55 @@ def stack(
     return stacked


def swapaxes(x: DNDarray, axis1: int, axis2: int) -> DNDarray:
    """
    Interchanges two axes of an array.

    Parameters
    ----------
    x : DNDarray
        Input array.
    axis1 : int
        First axis.
    axis2 : int
        Second axis.

    See Also
    --------
    :func:`~heat.core.linalg.basics.transpose`
        Permute the dimensions of an array.

    Examples
    --------
    >>> x = ht.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
    >>> ht.swapaxes(x, 0, 1)
    DNDarray([[[0, 1],
               [4, 5]],
              [[2, 3],
               [6, 7]]], dtype=ht.int64, device=cpu:0, split=None)
    >>> ht.swapaxes(x, 0, 2)
    DNDarray([[[0, 4],
               [2, 6]],
              [[1, 5],
               [3, 7]]], dtype=ht.int64, device=cpu:0, split=None)
    """
    axes = list(range(x.ndim))
    try:
        axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
    except TypeError:
        raise TypeError(
            "'axis1' and 'axis2' must be of type int, found {} and {}".format(
                type(axis1), type(axis2)
            )
        )

    return linalg.transpose(x, axes)


DNDarray.swapaxes = lambda self, axis1, axis2: swapaxes(self, axis1, axis2)
DNDarray.swapaxes.__doc__ = swapaxes.__doc__
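The lambda assignment above attaches the free function as a `DNDarray` method, so both call forms are interchangeable. Note also the design choice in the `try` block: a non-`int` axis surfaces as a `TypeError` from list indexing, which the function re-raises with a clearer message instead of an explicit `isinstance` check. A minimal check:

```python
import heat as ht

x = ht.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])

# Method form (enabled by the binding above) matches the free function.
assert ht.equal(x.swapaxes(0, 1), ht.swapaxes(x, 0, 1))
```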


def unique(
    a: DNDarray, sorted: bool = False, return_inverse: bool = False, axis: int = None
) -> Tuple[DNDarray, torch.tensor]:
27 changes: 27 additions & 0 deletions heat/core/tests/test_manipulations.py
@@ -1189,6 +1189,22 @@ def test_hstack(self):
        res = ht.hstack((a, b))
        self.assertEqual(res.shape, (24,))

    def test_moveaxis(self):
        a = ht.zeros((3, 4, 5))

        moved = ht.moveaxis(a, 0, -1)
        self.assertEqual(moved.shape, (4, 5, 3))

        moved = ht.moveaxis(a, [0, 1], [-1, -2])
        self.assertEqual(moved.shape, (5, 4, 3))

        with self.assertRaises(TypeError):
            ht.moveaxis(a, source="r", destination=3)
        with self.assertRaises(TypeError):
            ht.moveaxis(a, source=2, destination=3.6)
        with self.assertRaises(ValueError):
            ht.moveaxis(a, source=[0, 1, 2], destination=[0, 1])

    def test_pad(self):
        # ======================================
        # test padding of non-distributed tensor
@@ -3240,6 +3256,17 @@ def test_stack(self):
        with self.assertRaises(ValueError):
            ht.stack((ht_a_split, ht_b_split, ht_c_split), out=out_wrong_split)

    def test_swapaxes(self):
        x = ht.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
        swapped = ht.swapaxes(x, 0, 1)

        self.assertTrue(
            ht.equal(swapped, ht.array([[[0, 1], [4, 5]], [[2, 3], [6, 7]]], dtype=ht.int64))
        )

        with self.assertRaises(TypeError):
            ht.swapaxes(x, 4.9, "abc")

    def test_topk(self):
        size = ht.MPI_WORLD.size
        if size == 1:
