Skip to content

Commit 1ad87b3

Browse files
pat-s authored and larskotthoff committed
getResamplingIndices(): Translate inner resampling indices to outer indices (#2413)
* update getResamplingIndices()
* add explicit return statement
* add test checking translation of inner indices
* move purrr from suggests to imports and remove explicit importing notation
* no pipes
* remove duplicate fields
* fix description
* fix typo in tests
* set_names() needs to be one level higher
* update NEWS.md
1 parent dd983c3 commit 1ad87b3

File tree

5 files changed

+63
-8
lines changed

5 files changed

+63
-8
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Authors@R:
3939
email = "masonagallo@gmail.com"),
4040
person(given = "Patrick",
4141
family = "Schratz",
42-
role = c("aut"),
42+
role = "aut",
4343
email = "patrick.schratz@gmail.com",
4444
comment = c(ORCID = "0000-0003-0748-6624")),
4545
person(given = "Jakob",
@@ -147,6 +147,7 @@ Imports:
147147
ggplot2,
148148
methods,
149149
parallelMap (>= 1.3),
150+
purrr,
150151
stats,
151152
stringi,
152153
survival,
@@ -223,7 +224,6 @@ Suggests:
223224
penalized (>= 0.9-47),
224225
pls,
225226
PMCMR (>= 4.1),
226-
purrr,
227227
randomForest,
228228
randomForestSRC (>= 2.7.0),
229229
ranger (>= 0.8.0),

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
## general
44
* add option to use fully predefined indices in resampling (`makeResampleDesc(fixed = TRUE)`)
55

6+
## functions - general
7+
* `getResamplingIndices(inner = TRUE)` now correctly returns the inner indices as indices of the task (previously, the inner indices referred to positions within the respective outer-level train set)
8+
69
## learners - new
710
* classif.liquidSVM
811
* regr.liquidSVM

R/createSpatialResamplingPlots.R

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#'
33
#' @description Visualize partitioning of resample objects with spatial information.
44
#' @import ggplot2
5+
#' @importFrom purrr map_int flatten imap
56
#' @family plot
67
#' @author Patrick Schratz
78
#' @param task [Task] \cr
@@ -152,9 +153,9 @@ createSpatialResamplingPlots = function(task = NULL, resample = NULL, crs = NULL
152153
}
153154

154155
# create plot list with length = folds
155-
nfolds = purrr::map_int(resample, ~ .x$pred$instance$desc$folds)[1]
156+
nfolds = map_int(resample, ~ .x$pred$instance$desc$folds)[1]
156157

157-
plot.list.out.all = purrr::map(resample, function(.r) {
158+
plot.list.out.all = map(resample, function(.r) {
158159

159160
# bind coordinates to data
160161
data = cbind(task$env$data, task$coordinates)
@@ -163,9 +164,9 @@ createSpatialResamplingPlots = function(task = NULL, resample = NULL, crs = NULL
163164
data = sf::st_as_sf(data, coords = names(task$coordinates), crs = crs)
164165

165166
# create plot list with length = folds
166-
plot.list = purrr::map(1:(nfolds * repetitions), ~ data)
167+
plot.list = map(1:(nfolds * repetitions), ~ data)
167168

168-
plot.list.out = purrr::imap(plot.list, ~ ggplot(.x) +
169+
plot.list.out = imap(plot.list, ~ ggplot(.x) +
169170
geom_sf(data = subset(.x, as.integer(rownames(.x)) %in%
170171
.r$pred$instance[["train.inds"]][[.y]]),
171172
color = color.train, size = point.size, ) +
@@ -183,7 +184,7 @@ createSpatialResamplingPlots = function(task = NULL, resample = NULL, crs = NULL
183184
return(plot.list.out)
184185
})
185186

186-
plot.list = purrr::flatten(plot.list.out.all)
187+
plot.list = flatten(plot.list.out.all)
187188

188189
# more than 1 repetition?
189190
if (repetitions > 1) {

R/getResamplingIndices.R

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#' with `resample(..., extract = getTuneResult)` or `resample(..., extract = getFeatSelResult)` this helper returns a `list` with
66
#' the resampling indices used for the respective method.
77
#'
8+
#' @importFrom purrr map flatten set_names
89
#' @param object ([ResampleResult]) \cr
910
#' The result of resampling of a tuning or feature selection wrapper.
1011
#' @param inner ([logical]) \cr
@@ -38,7 +39,32 @@ getResamplingIndices = function(object, inner = FALSE) {
3839
stopf("No object of class 'TuneResult' or 'FeatuSelResult' found in slot 'extract'.
3940
Did you run 'resample()' with 'extract = getTuneResult' or 'extract = getFeatSelResult'?")
4041
}
41-
lapply(object$extract, function(x) x$resampling[c("train.inds", "test.inds")])
42+
inner_inds = lapply(object$extract, function(x) x$resampling[c("train.inds", "test.inds")])
43+
44+
outer_inds = object$pred$instance[c("train.inds", "test.inds")]
45+
46+
# now translate the inner inds back to the outer inds so we have the correct indices https://github.com/mlr-org/mlr/issues/2409
47+
48+
inner_inds_translated = map(1:length(inner_inds), function(z) # map over number of outer folds
49+
50+
set_names(
51+
map(c("train.inds", "test.inds"), function(u) # map over train/test level
52+
53+
# list() -> create list for "train.inds" and "test.inds"
54+
# flatten() -> reduce by one level
55+
# set_names(c("train.inds", "test.inds")) -> now set list names
56+
flatten(
57+
map(inner_inds[[z]][[u]], ~ # map over number of inner folds
58+
list(outer_inds[["train.inds"]][[z]][.x]) # the inner test.inds are a subset of the outer train.inds! That's why "train.inds" is hardcoded here
59+
)
60+
)
61+
),
62+
c("train.inds", "test.inds")
63+
)
64+
)
65+
66+
return(inner_inds_translated)
67+
4268
} else {
4369
object$pred$instance[c("train.inds", "test.inds")]
4470
}

tests/testthat/test_base_resample_getResamplingIndices.R

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,28 @@ test_that("getResamplingIndices works with getFeatSelResult", {
4545
# check if inner test.inds are retrieved correctly
4646
expect_length(unique(unlist(getResamplingIndices(r, inner = TRUE)[[1]]$test.inds)), 25)
4747
})
48+
49+
test_that("getResamplingIndices(inner = TRUE) correctly translates the inner inds to indices of the task", {
50+
51+
# this test is from "test_base_fixed_indices_cv.R"
52+
df = multiclass.df
53+
fixed = as.factor(rep(1:5, rep(30, 5)))
54+
ct = makeClassifTask(target = multiclass.target, data = df, blocking = fixed)
55+
lrn = makeLearner("classif.lda")
56+
ctrl = makeTuneControlRandom(maxit = 2)
57+
ps = makeParamSet(makeNumericParam("nu", lower = 2, upper = 20))
58+
inner = makeResampleDesc("CV", iters = 4, fixed = TRUE)
59+
outer = makeResampleDesc("CV", iters = 5, fixed = TRUE)
60+
tune_wrapper = makeTuneWrapper(lrn, resampling = inner, par.set = ps,
61+
control = ctrl, show.info = FALSE)
62+
p = resample(tune_wrapper, ct, outer, show.info = FALSE,
63+
extract = getTuneResult)
64+
65+
inner_inds = getResamplingIndices(p, inner = TRUE)
66+
67+
# To verify the translation, we expect an inner fold to contain indices that exceed obs - (obs / nfolds) = 150 - 30 = 120.
68+
# 120 is the largest index the untranslated inner resampling can use (with 150 obs and 5 outer folds), because the inner level has one fold less than the outer level.
69+
inds = sort(inner_inds[[2]][["test.inds"]][[1]])
70+
71+
expect_equal(length(inds[inds > 120]), 30)
72+
})

0 commit comments

Comments
 (0)