Skip to content

Commit

Permalink
Logical vector (#1177)
Browse files Browse the repository at this point in the history
* Make logical indexing consistent independent of the size of the vector.

* Add a few test cases
  • Loading branch information
dfalbel authored Jun 20, 2024
1 parent 29735f9 commit 9e3928b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
8 changes: 1 addition & 7 deletions src/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,6 @@ index_info index_append_sexp(XPtrTorchTensorIndex& index, SEXP slice,
}
}

// scalar boolean
if (TYPEOF(slice) == LGLSXP && LENGTH(slice) == 1) {
index_append_scalar_bool(index, slice);
return {1, false, false};
}

// the fill sybol was passed. in this case we add the ellipsis ...
if (Rf_inherits(slice, "fill")) {
index_append_ellipsis(index);
Expand All @@ -259,7 +253,7 @@ index_info index_append_sexp(XPtrTorchTensorIndex& index, SEXP slice,
return {1, true, false};
}

if (TYPEOF(slice) == LGLSXP && LENGTH(slice) > 1) {
if (TYPEOF(slice) == LGLSXP) {
index_append_bool_vector(index, slice);
return {1, true, false};
}
Expand Down
8 changes: 6 additions & 2 deletions tests/testthat/test-indexing.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,13 @@ test_that("subset assignment", {

test_that("indexing with R boolean vectors", {
x <- torch_tensor(c(1, 2))
expect_equal_to_r(x[TRUE], matrix(c(1, 2), nrow = 1))
expect_equal_to_r(x[FALSE], matrix(data = 1, ncol = 2, nrow = 0))
expect_equal_to_r(x[c(TRUE, FALSE)], 1)
x <- torch_tensor(c(1))
expect_equal_to_r(x[TRUE], 1)

x <- torch_zeros(2, 2)
expect_equal(dim(x[c(TRUE, FALSE),]), c(1,2))
expect_equal(dim(x[c(TRUE, FALSE),c(TRUE, FALSE)]), c(1,1))
})

test_that("indexing with long tensors", {
Expand Down

0 comments on commit 9e3928b

Please sign in to comment.