Skip to content

Commit

Permalink
Merge pull request #581 from helmholtz-analytics/enhancement/579-tolist
Browse files Browse the repository at this point in the history
tolist() implementation
  • Loading branch information
Markus-Goetz authored Jun 16, 2020
2 parents 547d33f + fd9210b commit 4b00106
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- [#577](https://github.com/helmholtz-analytics/heat/pull/577) Add ndim property in dndarray
- [#578](https://github.com/helmholtz-analytics/heat/pull/578) Bugfix: Bad variable in reshape
- [#580](https://github.com/helmholtz-analytics/heat/pull/580) New feature: fliplr()
- [#581](https://github.com/helmholtz-analytics/heat/pull/581) New Feature: DNDarray.tolist()
- [#593](https://github.com/helmholtz-analytics/heat/pull/593) New feature arctan2()
- [#594](https://github.com/helmholtz-analytics/heat/pull/594) New feature: Advanced indexing
- [#594](https://github.com/helmholtz-analytics/heat/pull/594) Bugfix: getitem and setitem memory consumption heavily reduced
Expand Down
33 changes: 33 additions & 0 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import numpy as np
import math
import torch
import warnings
from typing import List

from . import arithmetics
from . import devices
Expand Down Expand Up @@ -3459,6 +3462,36 @@ def tanh(self, out=None):
"""
return trigonometrics.tanh(self, out)

def tolist(self, keepsplit=False) -> List:
"""
Return a copy of the local array data as a (nested) Python list. For scalars, a standard Python number is returned.
Parameters
----------
keepsplit: bool
Whether the list should be returned locally or globally.
Examples
--------
>>> a = ht.array([[0,1],[2,3]])
>>> a.tolist()
[[0, 1], [2, 3]]
>>> a = ht.array([[0,1],[2,3]], split=0)
>>> a.tolist()
[[0, 1], [2, 3]]
>>> a = ht.array([[0,1],[2,3]], split=1)
>>> a.tolist(keepsplit=True)
(1/2) [[0], [2]]
(2/2) [[1], [3]]
"""

if not keepsplit:
return manipulations.resplit(self, axis=None).__array.tolist()

return self.__array.tolist()

def transpose(self, axes=None):
"""
Permute the dimensions of an array.
Expand Down
35 changes: 35 additions & 0 deletions heat/core/tests/test_dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,41 @@ def test_stride_and_strides(self):
)
self.assertEqual(heat_float64_F_split.strides, numpy_float64_F_split_strides)

def test_tolist(self):
a = ht.zeros([ht.MPI_WORLD.size, ht.MPI_WORLD.size, ht.MPI_WORLD.size], dtype=ht.int32)
res = [
[[0 for z in range(ht.MPI_WORLD.size)] for y in range(ht.MPI_WORLD.size)]
for x in range(ht.MPI_WORLD.size)
]
self.assertListEqual(a.tolist(), res)

a = ht.zeros(
[ht.MPI_WORLD.size, ht.MPI_WORLD.size, ht.MPI_WORLD.size], dtype=ht.int32, split=0
)
res = [
[[0 for z in range(ht.MPI_WORLD.size)] for y in range(ht.MPI_WORLD.size)]
for x in range(ht.MPI_WORLD.size)
]
self.assertListEqual(a.tolist(), res)

a = ht.zeros(
[ht.MPI_WORLD.size, ht.MPI_WORLD.size, ht.MPI_WORLD.size], dtype=ht.float32, split=1
)
res = [
[[0.0 for z in range(ht.MPI_WORLD.size)] for y in [ht.MPI_WORLD.rank]]
for x in range(ht.MPI_WORLD.size)
]
self.assertListEqual(a.tolist(keepsplit=True), res)

a = ht.zeros(
[ht.MPI_WORLD.size, ht.MPI_WORLD.size, ht.MPI_WORLD.size], dtype=ht.bool, split=2
)
res = [
[[False for z in [ht.MPI_WORLD.rank]] for y in range(ht.MPI_WORLD.size)]
for x in range(ht.MPI_WORLD.size)
]
self.assertListEqual(a.tolist(keepsplit=True), res)

def test_xor(self):
int16_tensor = ht.array([[1, 1], [2, 2]], dtype=ht.int16, device=ht_device)
int16_vector = ht.array([[3, 4]], dtype=ht.int16, device=ht_device)
Expand Down

0 comments on commit 4b00106

Please sign in to comment.