Merge pull request #846 from helmholtz-analytics/enhancement/836-norm
norm implementation
coquelin77 authored Aug 20, 2021
2 parents 634a2bd + 26a9eef commit e0af0e5
Showing 5 changed files with 568 additions and 38 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -1,12 +1,14 @@
# Pending additions

## Bug Fixes
- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `_reduce_op` when both `axis` and `keepdim` were set.
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `min` and `max` where `DNDarray`s with empty processes could not be computed.

## Feature Additions

### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features: `norm`, `vector_norm`, `matrix_norm` (usage sketch below)
### Manipulations
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`
- [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes`
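The changelog entry above introduces the new norm routines. Below is a minimal usage sketch; the function names come from the changelog, but the exact signatures and defaults are assumed to mirror `numpy.linalg` / the array API rather than taken from this diff.

```python
# Usage sketch only: signatures and defaults of the new norm routines are assumed,
# not verified against the implementation in this PR.
import heat as ht

x = ht.arange(9, dtype=ht.float32, split=0)   # vector distributed across processes
print(ht.linalg.vector_norm(x))               # Euclidean norm, sqrt(0^2 + ... + 8^2)

A = ht.ones((3, 3), split=0)                  # distributed 3x3 matrix
print(ht.linalg.matrix_norm(A))               # Frobenius norm (assumed default), 3.0
print(ht.linalg.norm(A))                      # general entry point covering both cases
```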
7 changes: 5 additions & 2 deletions heat/core/_operations.py
@@ -419,7 +419,10 @@ def __reduce_op(
     else:
         output_shape = x.gshape
         for dim in axis:
-            partial = partial_op(partial, dim=dim, keepdim=True)
+            if not (
+                partial.shape.numel() == 0 and partial_op.__name__ in ("local_max", "local_min")
+            ):  # no neutral element for max/min
+                partial = partial_op(partial, dim=dim, keepdim=True)
             output_shape = output_shape[:dim] + (1,) + output_shape[dim + 1 :]
         if not keepdim and not len(partial.shape) == 1:
             gshape_losedim = tuple(x.gshape[dim] for dim in range(len(x.gshape)) if dim not in axis)
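The added guard skips the local reduction on process-local chunks that hold no elements, because `max`/`min` have no neutral element to fall back on. A small sketch of the underlying behaviour in plain PyTorch (error message wording may differ between versions):

```python
# Sketch of why the guard is needed: reducing an empty chunk with max/min fails,
# while reductions with a neutral element (e.g. sum) do not.
import torch

empty = torch.empty(0, 3)  # a process-local chunk holding no data
try:
    empty.max(dim=0, keepdim=True)
except RuntimeError as err:
    print("max on an empty chunk fails:", err)

print(empty.sum(dim=0, keepdim=True))  # works: shape (1, 3), filled with the neutral element 0
```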
@@ -439,7 +442,7 @@
         balanced = True
         if x.comm.is_distributed():
             x.comm.Allreduce(MPI.IN_PLACE, partial, reduction_op)
-    elif axis is not None:
+    elif axis is not None and not keepdim:
         down_dims = len(tuple(dim for dim in axis if dim < x.split))
         split -= down_dims
         balanced = x.balanced
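The second hunk only shifts the split index when `keepdim=False`: with `keepdim=True` the reduced dimensions are kept as size-1 axes, so the position of the split dimension does not move. A hypothetical illustration with plain PyTorch shapes (the `DNDarray`/split bookkeeping itself is heat-internal):

```python
# Hypothetical illustration of the keepdim fix; `split` here just mimics heat's
# bookkeeping of which dimension a DNDarray is distributed along.
import torch

x = torch.zeros(4, 5, 6)  # imagine this array is split along dim 2
split = 2
axis = (0, 1)             # reduce over the first two dimensions

kept = x.amax(dim=axis, keepdim=True)       # shape (1, 1, 6): split stays at 2
squeezed = x.amax(dim=axis, keepdim=False)  # shape (6,): the split dim slides to 0

# Only in the keepdim=False case must the split index be shifted down,
# which is what the added `and not keepdim` condition accounts for:
down_dims = len(tuple(dim for dim in axis if dim < split))
print(split - down_dims)  # 0
```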
