Skip to content

feat: Automatically set encapsulation while setting fallback #763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Imports:
backports,
checkmate (>= 2.0.0),
data.table (>= 1.14.2),
evaluate,
future,
future.apply (>= 1.5.0),
lgr (>= 0.3.4),
Expand All @@ -79,7 +80,6 @@ Suggests:
codetools,
datasets,
distr6,
evaluate,
future.callr,
mlr3data,
progressr,
Expand Down
39 changes: 29 additions & 10 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,6 @@ Learner = R6Class("Learner",
#' Also see the section on error handling the mlr3book: \url{https://mlr3book.mlr-org.com/technical.html#error-handling}
timeout = c(train = Inf, predict = Inf),

#' @field fallback ([Learner])\cr
#' Learner which is fitted to impute predictions in case that either the model fitting or the prediction of the top learner is not successful.
#' Requires you to enable encapsulation, otherwise errors are not caught and the execution is terminated before the fallback learner kicks in.
#' Also see the section on error handling the mlr3book: \url{https://mlr3book.mlr-org.com/technical.html#error-handling}
fallback = NULL,

#' @template field_man
man = NULL,

Expand All @@ -161,7 +155,6 @@ Learner = R6Class("Learner",
self$id = assert_string(id, min.chars = 1L)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
private$.param_set = assert_param_set(param_set)
private$.encapsulate = c(train = "none", predict = "none")
self$feature_types = assert_subset(feature_types, mlr_reflections$task_feature_types)
self$predict_types = assert_subset(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]), empty.ok = FALSE)
private$.predict_type = predict_types[1L]
Expand Down Expand Up @@ -460,12 +453,37 @@ Learner = R6Class("Learner",
#' Possible values are `"none"`, `"evaluate"` (requires package \CRANpkg{evaluate}) and `"callr"` (requires package \CRANpkg{callr}).
#' See [mlr3misc::encapsulate()] for more details.
encapsulate = function(rhs) {
default = c(train = "none", predict = "none")

if (missing(rhs)) {
return(private$.encapsulate)
return(insert_named(default, private$.encapsulate))
}

assert_character(rhs)
assert_names(names(rhs), subset.of = c("train", "predict"))
private$.encapsulate = insert_named(c(train = "none", predict = "none"), rhs)
private$.encapsulate = insert_named(default, rhs)
},

#' @field fallback ([Learner])\cr
#' Learner which is fitted to impute predictions in case that either the model fitting or the prediction of the top learner is not successful.
#' Requires encapsulation, otherwise errors are not caught and the execution is terminated before the fallback learner kicks in.
#' If you have not set encapsulation manually before, setting the fallback learner automatically
#' activates encapsulation using the \CRANpkg{evaluate} package.
#' Also see the section on error handling the mlr3book: \url{https://mlr3book.mlr-org.com/technical.html#error-handling}
fallback = function(rhs) {
if (missing(rhs)) {
return(private$.fallback)
}

assert_learner(rhs, task_type = self$task_type)
if (!identical(self$predict_type, rhs$predict_type)) {
warningf("The fallback learner '%s' and the base learner '%s' have different predict types",
rhs$predict_type, self$predict_type)
}
if (is.null(private$.encapsulate)) {
private$.encapsulate = c(train = "evaluate", predict = "evaluate")
}
private$.fallback = rhs
},

#' @field hotstart_stack ([HotstartStack])\cr.
Expand All @@ -481,14 +499,15 @@ Learner = R6Class("Learner",

private = list(
.encapsulate = NULL,
.fallback = NULL,
.predict_type = NULL,
.param_set = NULL,
.hotstart_stack = NULL,

deep_clone = function(name, value) {
switch(name,
.param_set = value$clone(deep = TRUE),
fallback = if (is.null(value)) NULL else value$clone(deep = TRUE),
.fallback = if (is.null(value)) NULL else value$clone(deep = TRUE),
state = {
if (!is.null(value$train_task)) {
value$train_task = value$train_task$clone(deep = TRUE)
Expand Down
12 changes: 9 additions & 3 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,16 @@ assert_tasks = function(tasks, task_type = NULL, feature_types = NULL, task_prop

#' @export
#' @param learner ([Learner]).
#' @param task_type (`character(1)`).
#' @rdname mlr_assertions
assert_learner = function(learner, task = NULL, properties = character(), .var.name = vname(learner)) {
assert_learner = function(learner, task = NULL, task_type = NULL, properties = character(), .var.name = vname(learner)) {
assert_class(learner, "Learner", .var.name = .var.name)

task_type = task_type %??% task$task_type
if (!is.null(task_type) && task_type != learner$task_type) {
stopf("Learner '%s' must have task type '%s'", learner$id, task_type)
}

if (length(properties)) {
miss = setdiff(properties, learner$properties)
if (length(miss)) {
Expand All @@ -85,8 +91,8 @@ assert_learner = function(learner, task = NULL, properties = character(), .var.n
#' @export
#' @param learners (list of [Learner]).
#' @rdname mlr_assertions
assert_learners = function(learners, task = NULL, properties = character(), .var.name = vname(learners)) {
invisible(lapply(learners, assert_learner, task = task, properties = properties, .var.name = .var.name))
assert_learners = function(learners, task = NULL, task_type = NULL, properties = character(), .var.name = vname(learners)) {
invisible(lapply(learners, assert_learner, task = task, task_type = NULL, properties = properties, .var.name = .var.name))
}

assert_task_learner = function(task, learner, cols = NULL) {
Expand Down
12 changes: 7 additions & 5 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions man/mlr_assertions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions tests/testthat/test_encapsulate.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,22 @@ task = tsk("iris")
learner = lrn("classif.debug")
learner$param_set$values = list(message_train = 1, warning_train = 1, message_predict = 1, warning_predict = 1)

test_that("encapsulation is automatically enabled", {
tmp = lrn("classif.debug")
expect_equal(tmp$encapsulate, c(train = "none", predict = "none"))
expect_null(get_private(tmp)$.encapsulate)

tmp$fallback = lrn("classif.featureless")
expect_equal(tmp$encapsulate, c(train = "evaluate", predict = "evaluate"))
expect_equal(get_private(tmp)$.encapsulate, c(train = "evaluate", predict = "evaluate"))

tmp = lrn("classif.debug")
tmp$encapsulate = c(train = "none", predict = "none")
tmp$fallback = lrn("classif.featureless")
expect_equal(tmp$encapsulate, c(train = "none", predict = "none"))
expect_equal(get_private(tmp)$.encapsulate, c(train = "none", predict = "none"))
})

test_that("evaluate / single step", {
row_ids = 1:120
expect_message(expect_warning(disable_encapsulation(learner)$train(task, row_ids)))
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test_fallback.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ test_that("fail during resample", {

test_that("incomplete predictions", {
task = tsk("iris")
learner = lrn("classif.debug", predict_type = "prob", predict_missing = 0.5, fallback = lrn("classif.featureless"))
learner = lrn("classif.debug", predict_type = "prob", predict_missing = 0.5,
fallback = lrn("classif.featureless", predict_type = "prob"))

learner$train(task)
p = learner$predict(task)
Expand Down