Skip to content

Commit ac45d52

Browse files
author
Holger Kohr
committed
TST: add tensor space getitem tests
1 parent 2bd1d1c commit ac45d52

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

odl/test/space/npy_tensors_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,41 @@ def test_pdist(exponent):
535535
assert x.dist(y) == pytest.approx(correct_dist)
536536

537537

538+
def test_space_getitem(getitem_indices):
539+
"""Check if space indexing works as expected."""
540+
space = odl.tensor_space((2, 3, 4), dtype=complex, exponent=1, weighting=2)
541+
542+
# Ellipsis not supported
543+
try:
544+
iter(getitem_indices)
545+
except TypeError:
546+
pass
547+
else:
548+
if Ellipsis in getitem_indices:
549+
with pytest.raises(TypeError):
550+
space[getitem_indices]
551+
return
552+
553+
sliced_space = space[getitem_indices]
554+
x = np.empty(space.shape)
555+
sliced_x = x[getitem_indices]
556+
assert sliced_space.shape == sliced_x.shape
557+
558+
assert sliced_space.exponent == space.exponent
559+
assert sliced_space.dtype == space.dtype
560+
assert sliced_space.weighting == space.weighting
561+
562+
563+
def test_space_getitem_array_weighting():
564+
"""Check that array weighting is propagated correctly when slicing."""
565+
shape = (2, 3)
566+
weight_arr = np.arange(1, np.prod(shape) + 1).reshape(shape)
567+
space = odl.tensor_space(shape, weighting=weight_arr)
568+
569+
sliced_space = space[0, ::2]
570+
assert all_equal(sliced_space.weighting.array, weight_arr[0, ::2])
571+
572+
538573
def test_element_getitem(getitem_indices):
539574
"""Check if getitem produces correct values, shape and other stuff."""
540575
space = odl.tensor_space((2, 3, 4), dtype=complex, exponent=1, weighting=2)

0 commit comments

Comments
 (0)