Merge branch 'master' into features/817-dot-vectors
coquelin77 authored Jun 30, 2021
2 parents 48c2964 + 83ec24a commit 5e84bd0
Showing 8 changed files with 307 additions and 57 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,7 @@ Example on 2 processes:
```

## Bug Fixes
- [#796](https://github.com/helmholtz-analytics/heat/pull/796) `heat.reshape(a, shape, new_split)` now always returns a distributed `DNDarray` if `new_split is not None` (including when the original input `a` is not distributed; see the sketch after this list)
- [#758](https://github.com/helmholtz-analytics/heat/pull/758) Fix indexing inconsistencies in `DNDarray.__getitem__()`
- [#768](https://github.com/helmholtz-analytics/heat/pull/768) Fixed an issue where `deg2rad` and `rad2deg` were not working with the `out` parameter.
- [#785](https://github.com/helmholtz-analytics/heat/pull/785) Removed `storage_offset` when finding the mpi buffer (`communication.MPICommunication.as_mpi_memory()`).
@@ -44,6 +45,7 @@ Example on 2 processes:
- [#790](https://github.com/helmholtz-analytics/heat/pull/790) catch incorrect device after `bcast` in `DNDarray.__getitem__`
- [#811](https://github.com/helmholtz-analytics/heat/pull/811) Fixed memory leak in `DNDarray.larray`
- [#820](https://github.com/helmholtz-analytics/heat/pull/820) `randn` values are pushed away from 0 by the minimum value of the given dtype before being transformed into the Gaussian shape
- [#821](https://github.com/helmholtz-analytics/heat/pull/821) Fixed `__getitem__` handling of distributed `DNDarray` key element
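
A minimal sketch of the `#796` reshape behavior flagged above; assumes an MPI launch on at least 2 processes (e.g. `mpirun -n 2 python demo.py`):

```python
import heat as ht

# a replicated (non-distributed) array, identical on every process
a = ht.zeros((4, 5))                   # a.split is None

# passing new_split now always yields a distributed result,
# even though the input was not distributed
b = ht.reshape(a, (5, 4), new_split=0)
print(b.split)                         # 0 -> rows spread across processes
```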

## Feature additions
### Exponential
@@ -54,12 +56,17 @@ Example on 2 processes:
- [#768](https://github.com/helmholtz-analytics/heat/pull/768) New feature: unary positive and negative operations
- [#820](https://github.com/helmholtz-analytics/heat/pull/820) `dot` can now handle matrix-vector operations (sketched below)
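
A short sketch of the new matrix-vector case; the shapes here are illustrative only:

```python
import heat as ht

m = ht.arange(12, dtype=ht.float32).reshape((3, 4))  # matrix, shape (3, 4)
v = ht.ones(4)                                       # vector, shape (4,)

# matrix @ vector -> vector of shape (3,)
print(ht.dot(m, v))
```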

### Manipulations
- [#796](https://github.com/helmholtz-analytics/heat/pull/796) `DNDarray.reshape(shape)`: the method now allows shape elements to be passed as single arguments (sketched below).
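
Both call styles side by side; a sketch of the equivalent spellings this entry describes:

```python
import heat as ht

a = ht.arange(20)
b = a.reshape((4, 5))      # shape passed as a tuple, as before
c = a.reshape(4, 5)        # shape elements passed as single arguments
print(b.shape == c.shape)  # True
```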

### Trigonometrics / Arithmetic
- [#806](https://github.com/helmholtz-analytics/heat/pull/806) New feature: `square`
- [#809](https://github.com/helmholtz-analytics/heat/pull/809) New feature: `acosh`, `asinh`, `atanh`

### Misc.
- [#761](https://github.com/helmholtz-analytics/heat/pull/761) New feature: `result_type`
- [#794](https://github.com/helmholtz-analytics/heat/pull/794) New feature: `meshgrid`
- [#821](https://github.com/helmholtz-analytics/heat/pull/821) Enhancement: it is no longer necessary to load-balance an imbalanced `DNDarray` before gathering it onto all processes. In short: `ht.resplit(array, None)` now works on imbalanced arrays as well.
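
A sketch of the `#821` enhancement above; assumes an MPI launch on several processes, and that slicing leaves the array imbalanced:

```python
import heat as ht

a = ht.arange(10, split=0)
b = a[3:]                  # slicing can leave b unevenly distributed

# no explicit b.balance_() is needed anymore before gathering:
c = ht.resplit(b, None)    # gather onto every process
print(c.split, c.shape)    # None (7,)
```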

# v1.0.0

21 changes: 11 additions & 10 deletions doc/source/tutorial_clustering.rst
@@ -68,9 +68,9 @@ initial centroids.
c1.balance_()
c2.balance_()
print("""Number of points assigned to c1: {}
Number of points assigned to c2: {}
Centroids = {}""".format(c1.shape[0], c2.shape[0], centroids))
print(f"Number of points assigned to c1: {c1.shape[0]} "
f"Number of points assigned to c2: {c2.shape[0]} "
f"Centroids = {centroids}")
.. code:: text
@@ -110,8 +110,9 @@ We can also cluster the data with kmedians. The respective advanced initial centroid
c1.balance_()
c2.balance_()
print("""Number of points assigned to c1: {}
Number of points assigned to c2: {}""".format(c1.shape[0], c2.shape[0]))
print(f"Number of points assigned to c1: {c1.shape[0]}"
f"Number of points assigned to c2: {c2.shape[0]}")
Plotting the assigned clusters and the respective centroids:

.. code:: python
@@ -131,12 +132,12 @@ The Iris Dataset
------------------------------
The _iris_ dataset is a well-known example for clustering analysis. It contains 4 measured features for samples from
three different types of iris flowers. A subset of 150 samples is included in formats h5, csv and netcdf in Heat,
located under 'heat/heat/datasets/data/iris.h5', and can be loaded in a distributed manner with Heat's parallel
located under 'heat/heat/datasets/iris.h5', and can be loaded in a distributed manner with Heat's parallel
dataloader

.. code:: python
iris = ht.load("heat/datasets/data/iris.csv", sep=";", split=0)
iris = ht.load("heat/datasets/iris.csv", sep=";", split=0)
Fitting the dataset with kmeans:

.. code:: python
@@ -160,6 +161,6 @@ Let's see what the results are. In theory, there are 50 samples of each of the 3
c2.balance_()
c3.balance_()
print("Number of points assigned to c1: {} \n
Number of points assigned to c2: {} \n
Number of points assigned to c3: {} ".format(c1.shape[0], c2.shape[0], c3.shape[0]))
print(f"Number of points assigned to c1: {c1.shape[0]} \n"
f"Number of points assigned to c2: {c2.shape[0]} \n"
f"Number of points assigned to c3: {c3.shape[0]}")
12 changes: 6 additions & 6 deletions heat/core/dndarray.py
@@ -542,7 +542,7 @@ def __complex__(self) -> DNDarray:
"""
return self.__cast(complex)

def counts_displs(self) -> Tuple[torch.Tensor, torch.Tensor]:
def counts_displs(self) -> Tuple[Tuple[int], Tuple[int]]:
"""
Returns actual counts (number of items per process) and displacements (offsets) of the DNDarray.
Does not assume load balance.
@@ -555,7 +555,7 @@ def counts_displs(self) -> Tuple[torch.Tensor, torch.Tensor]:
torch.cumsum(counts, dim=0)[:-1],
)
)
return (counts, displs)
return (tuple(counts.tolist()), tuple(displs.tolist()))
else:
raise ValueError("Non-distributed DNDarray. Cannot calculate counts and displacements.")

@@ -678,7 +678,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray:
""" if the key is a DNDarray and it has as many dimensions as self, then each of the entries in the 0th
dim refer to a single element. To handle this, the key is split into the torch tensors for each dimension.
This signals that advanced indexing is to be used. """
key.balance_()
key = manipulations.resplit(key.copy())
if key.ndim > 1:
key = list(key.larray.split(1, dim=1))
@@ -694,7 +693,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray:
lists mean advanced indexing will be used"""
h = [slice(None, None, None)] * self.ndim
if isinstance(key, DNDarray):
key.balance_()
key = manipulations.resplit(key.copy())
h[0] = key.larray.tolist()
elif isinstance(key, torch.Tensor):
@@ -709,14 +707,15 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray:
for i, k in enumerate(key):
if isinstance(k, DNDarray):
# extract torch tensor
k = manipulations.resplit(k.copy())
key[i] = k.larray.type(torch.int64)
key = tuple(key)

# assess final global shape
self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim)
gout_full = list(self_proxy[key].shape)

# ellipsis stuff
# ellipsis
key = list(key)
key_classes = [type(n) for n in key]
# if any(isinstance(n, ellipsis) for n in key):
@@ -755,6 +754,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray:
arr = torch.tensor([], dtype=self.__array.dtype, device=self.__array.device)
rank = self.comm.rank
counts, chunk_starts = self.counts_displs()
counts, chunk_starts = torch.tensor(counts), torch.tensor(chunk_starts)
chunk_ends = chunk_starts + counts
chunk_start = chunk_starts[rank]
chunk_end = chunk_ends[rank]
@@ -1241,7 +1241,7 @@ def resplit_(self, axis: int = None):
gathered = torch.empty(
self.shape, dtype=self.dtype.torch_type(), device=self.device.torch_device
)
counts, displs, _ = self.comm.counts_displs_shape(self.shape, self.split)
counts, displs = self.counts_displs()
self.comm.Allgatherv(self.__array, (gathered, counts, displs), recv_axis=self.split)
self.__array = gathered
self.__split = axis
87 changes: 86 additions & 1 deletion heat/core/factories.py
@@ -4,12 +4,13 @@
import torch
import warnings

from typing import Callable, Iterable, Optional, Sequence, Tuple, Type, Union
from typing import Callable, Iterable, Optional, Sequence, Tuple, Type, Union, List

from .communication import MPI, sanitize_comm, Communication
from .devices import Device
from .dndarray import DNDarray
from .memory import sanitize_memory_layout
from .sanitation import sanitize_in, sanitize_sequence
from .stride_tricks import sanitize_axis, sanitize_shape
from .types import datatype

@@ -28,6 +29,7 @@
"full_like",
"linspace",
"logspace",
"meshgrid",
"ones",
"ones_like",
"zeros",
@@ -1043,6 +1045,89 @@ def logspace(
return pow(base, y).astype(dtype, copy=False)


def meshgrid(*arrays: Sequence[DNDarray], indexing: str = "xy") -> List[DNDarray]:
"""
Returns coordinate matrices from coordinate vectors.
Parameters
----------
arrays : Sequence[ DNDarray ]
one-dimensional arrays representing grid coordinates. If exactly one vector is distributed, the returned matrices will
be distributed along the axis equal to the index of this vector in the input list.
indexing : str, optional
    Cartesian 'xy' or matrix 'ij' indexing of the output. It is ignored if zero or one one-dimensional arrays are provided. Default: 'xy'.
Raises
------
ValueError
If `indexing` is not 'xy' or 'ij'.
ValueError
If more than one input vector is distributed.
Examples
--------
>>> x = ht.arange(4)
>>> y = ht.arange(3)
>>> xx, yy = ht.meshgrid(x,y)
>>> xx
DNDarray([[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3]], dtype=ht.int32, device=cpu:0, split=None)
>>> yy
DNDarray([[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2]], dtype=ht.int32, device=cpu:0, split=None)
"""
splitted = None

if indexing not in ["xy", "ij"]:
raise ValueError("Valid values for `indexing` are 'xy' and 'ij'.")

if len(arrays) == 0:
return []

arrays = sanitize_sequence(arrays)

for idx, array in enumerate(arrays):
sanitize_in(array)
if array.split is not None:
if splitted is not None:
raise ValueError("split != None are not supported.")
splitted = idx

# pytorch does not support the indexing keyword: switch vectors
if indexing == "xy" and len(arrays) > 1:
arrays[0], arrays[1] = arrays[1], arrays[0]
if splitted == 0:
arrays[0] = arrays[0].resplit(0)
arrays[1] = arrays[1].resplit(None)
elif splitted == 1:
arrays[0] = arrays[0].resplit(None)
arrays[1] = arrays[1].resplit(0)

grids = torch.meshgrid(*(array.larray for array in arrays))

# pytorch does not support indexing keyword: switch back
if indexing == "xy" and len(arrays) > 1:
grids = list(grids)
grids[0], grids[1] = grids[1], grids[0]

shape = tuple(array.size for array in arrays)

return list(
DNDarray(
array=grid,
gshape=shape,
dtype=types.heat_type_of(grid),
split=splitted,
device=devices.sanitize_device(grid.device.type),
comm=sanitize_comm(None),
balanced=True,
)
for grid in grids
)
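
A sketch of the distributed case handled above: with exactly one split input vector, the returned grids are split along that vector's position in the argument list (shown for 'ij' indexing; assumes an MPI launch on several processes):

```python
import heat as ht

x = ht.arange(4, split=0)   # the only distributed input, at position 0
y = ht.arange(3)

xx, yy = ht.meshgrid(x, y, indexing="ij")
print(xx.shape, xx.split)   # (4, 3) 0 -> both grids split along axis 0
```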


def ones(
shape: Union[int, Sequence[int]],
dtype: Type[datatype] = types.float32,
Expand Down