Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tolist() implementation #581

Merged
merged 14 commits into from
Jun 16, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- [#577](https://github.com/helmholtz-analytics/heat/pull/577) Add ndim property in dndarray
- [#578](https://github.com/helmholtz-analytics/heat/pull/578) Bugfix: Bad variable in reshape
- [#580](https://github.com/helmholtz-analytics/heat/pull/580) New feature: fliplr()
- [#581](https://github.com/helmholtz-analytics/heat/pull/581) New Feature: DNDarray.tolist()
- [#593](https://github.com/helmholtz-analytics/heat/pull/593) New feature arctan2()
- [#594](https://github.com/helmholtz-analytics/heat/pull/594) New feature: Advanced indexing
- [#594](https://github.com/helmholtz-analytics/heat/pull/594) Bugfix: getitem and setitem memory consumption heavily reduced
Expand Down
33 changes: 33 additions & 0 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import numpy as np
import math
import torch
import warnings
from typing import List

from . import arithmetics
from . import devices
Expand Down Expand Up @@ -3459,6 +3462,36 @@ def tanh(self, out=None):
"""
return trigonometrics.tanh(self, out)

def tolist(self, keepsplit=False) -> List:
"""
Return a copy of the local array data as a (nested) Python list. For scalars, a standard Python number is returned.

Parameters
----------
keepsplit: bool
Whether the list should be returned locally or globally.

Examples
--------
>>> a = ht.array([[0,1],[2,3]])
>>> a.tolist()
[[0, 1], [2, 3]]

>>> a = ht.array([[0,1],[2,3]], split=0)
>>> a.tolist()
[[0, 1], [2, 3]]

>>> a = ht.array([[0,1],[2,3]], split=1)
>>> a.tolist(keepsplit=True)
(1/2) [[0], [2]]
(2/2) [[1], [3]]
"""

if not keepsplit:
return manipulations.resplit(self, axis=None).__array.tolist()

return self.__array.tolist()
Markus-Goetz marked this conversation as resolved.
Show resolved Hide resolved

def transpose(self, axes=None):
"""
Permute the dimensions of an array.
Expand Down
35 changes: 35 additions & 0 deletions heat/core/tests/test_dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,41 @@ def test_stride_and_strides(self):
)
self.assertEqual(heat_float64_F_split.strides, numpy_float64_F_split_strides)

def test_tolist(self):
a = ht.zeros([ht.MPI_WORLD.size, ht.MPI_WORLD.size, ht.MPI_WORLD.size], dtype=ht.int32)
res = [
[[0 for z in range(ht.MPI_WORLD.size)] for y in range(ht.MPI_WORLD.size)]
for x in range(ht.MPI_WORLD.size)
]
self.assertListEqual(a.tolist(), res)

a = ht.zeros(
[ht.MPI_WORLD.size, ht.MPI_WORLD.size, ht.MPI_WORLD.size], dtype=ht.int32, split=0
)
res = [
[[0 for z in range(ht.MPI_WORLD.size)] for y in range(ht.MPI_WORLD.size)]
for x in range(ht.MPI_WORLD.size)
]
self.assertListEqual(a.tolist(), res)

a = ht.zeros(
[ht.MPI_WORLD.size, ht.MPI_WORLD.size, ht.MPI_WORLD.size], dtype=ht.float32, split=1
)
res = [
[[0.0 for z in range(ht.MPI_WORLD.size)] for y in [ht.MPI_WORLD.rank]]
for x in range(ht.MPI_WORLD.size)
]
self.assertListEqual(a.tolist(keepsplit=True), res)

a = ht.zeros(
[ht.MPI_WORLD.size, ht.MPI_WORLD.size, ht.MPI_WORLD.size], dtype=ht.bool, split=2
)
res = [
[[False for z in [ht.MPI_WORLD.rank]] for y in range(ht.MPI_WORLD.size)]
for x in range(ht.MPI_WORLD.size)
]
self.assertListEqual(a.tolist(keepsplit=True), res)

def test_xor(self):
int16_tensor = ht.array([[1, 1], [2, 2]], dtype=ht.int16, device=ht_device)
int16_vector = ht.array([[3, 4]], dtype=ht.int16, device=ht_device)
Expand Down