diff --git a/tests/manual_checks/fancy_indexing.py b/tests/manual_checks/fancy_indexing.py index e4321715a..3d6d29ea5 100644 --- a/tests/manual_checks/fancy_indexing.py +++ b/tests/manual_checks/fancy_indexing.py @@ -24,6 +24,22 @@ def masked_select(): [ 1, 8, 7], [ 8, 6, 8]]) + print(x) + print('--------------------------') + print('x[x > 50]') + print(x[x > 50]) + print('--------------------------') + print('x[x < 50]') + print(x[x < 50]) + +def masked_axis_select(): + print('Masked axis select') + print('--------------------------') + x = np.array([[ 4, 99, 2], + [ 3, 4, 99], + [ 1, 8, 7], + [ 8, 6, 8]]) + print(x) print('--------------------------') print('x[:, np.sum(x, axis = 0) > 50]') diff --git a/tests/tensor/test_fancy_indexing.nim b/tests/tensor/test_fancy_indexing.nim index 4c3a58b02..790c1d37d 100644 --- a/tests/tensor/test_fancy_indexing.nim +++ b/tests/tensor/test_fancy_indexing.nim @@ -46,6 +46,17 @@ suite "Fancy indexing": check: r == exp test "Masked selection via fancy indexing": + block: + let r = x[x >. 50] + let exp = [99, 99].toTensor() + check: r == exp + + block: + let r = x[x <. 50] + let exp = [4, 2, 3, 4, 1, 8, 7, 8, 6, 8].toTensor() + check: r == exp + + test "Masked axis selection via fancy indexing": block: # print('x[:, np.sum(x, axis = 0) > 50]') let r = x[_, x.sum(axis = 0) >. 50]