Skip to content

Commit

Permalink
fix kfold_split_stratified when a group has 1 observation
Browse files Browse the repository at this point in the history
fixes #277

Co-Authored-By: Jouni Helske <jouni.helske@iki.fi>
  • Loading branch information
jgabry and helske committed Oct 10, 2024
1 parent 93cdae8 commit 117f030
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
7 changes: 6 additions & 1 deletion R/kfold-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ kfold_split_stratified <- function(K = 10, x = NULL) {
N <- length(x)
xids <- numeric()
for (l in 1:Nlev) {
xids <- c(xids, sample(which(x==l)))
idx <- which(x == l)
if (length(idx) > 1) {
xids <- c(xids, sample(idx))
} else {
xids <- c(xids, idx)
}
}
bins <- rep(NA, N)
bins[xids] <- rep(1:K, ceiling(N/K))[1:N]
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/test_kfold_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ test_that("kfold_split_stratified works", {
y <- mtcars$cyl
fold_strat <- kfold_split_stratified(10, y)
expect_equal(range(table(fold_strat)), c(3, 4))

# test when a group has 1 observation
# https://github.com/stan-dev/loo/issues/277
y <- rep(c(1, 2, 3), times = c(20, 40, 1))
expect_silent(fold_strat <- kfold_split_stratified(5, y)) # used to be a warning before fixing issue #277
tab <- table(fold_strat, y)
expect_equal(tab[1, ], c("1" = 4, "2" = 8, "3" = 1))
for (i in 2:nrow(tab)) expect_equal(tab[i, ], c("1" = 4, "2" = 8, "3" = 0))
})

test_that("kfold_split_grouped works", {
Expand Down

0 comments on commit 117f030

Please sign in to comment.