Skip to content

Commit

Permalink
utilities: enhance method to recursively convert items to np.ndarray,…
Browse files Browse the repository at this point in the history
… handling hard cases for np.array
  • Loading branch information
kmantel committed Dec 12, 2017
1 parent 24a7c78 commit 6c3760e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
13 changes: 12 additions & 1 deletion psyneulink/globals/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,4 +1063,15 @@ def convert_all_elements_to_np_array(arr):
else:
return arr

return np.asarray([convert_all_elements_to_np_array(x) for x in arr])
subarr = [convert_all_elements_to_np_array(x) for x in arr]
try:
return np.array(subarr)
except ValueError:
# numpy cannot easily create arrays with subarrays of certain dimensions, workaround here
# https://stackoverflow.com/q/26885508/3131666
len_subarr = len(subarr)
elementwise_subarr = np.empty(len_subarr, dtype=np.ndarray)
for i in range(len_subarr):
elementwise_subarr[i] = subarr[i]

return elementwise_subarr
35 changes: 35 additions & 0 deletions tests/misc/test_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import collections
import numpy as np
import pytest

from psyneulink.globals.utilities import convert_all_elements_to_np_array


@pytest.mark.parametrize(
'arr, expected',
[
([[0], [0, 0]], np.array([np.array([0]), np.array([0, 0])])),
# should test these but numpy cannot easily create an array from them
# [np.ones((2,2)), np.zeros((2,1))]
# [np.array([[0]]), np.array([[[ 1., 1., 1.], [ 1., 1., 1.]]])]
]
)
def test_convert_all_elements_to_np_array(arr, expected):
converted = convert_all_elements_to_np_array(arr)

# no current numpy methods can test this
def check_equality_recursive(arr, expected):
if (
not isinstance(arr, collections.Iterable)
or (isinstance(arr, np.ndarray) and arr.ndim == 0)
):
assert arr == expected
else:
assert isinstance(expected, type(arr))
assert len(arr) == len(expected)

for i in range(len(arr)):
check_equality_recursive(arr[i], expected[i])

check_equality_recursive(converted, expected)

0 comments on commit 6c3760e

Please sign in to comment.