Skip to content

Commit 3999ab0

Browse files
authored
Add timeout field to learners (#556)
1 parent 45ecb52 commit 3999ab0

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

DESCRIPTION

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ Imports:
6868
lgr (>= 0.3.4),
6969
mlbench,
7070
mlr3measures (>= 0.3.0),
71-
mlr3misc (>= 0.5.0),
71+
mlr3misc (>= 0.6.0),
7272
paradox (>= 0.4.0),
7373
uuid
7474
Suggests:

R/Learner.R

+7
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ Learner = R6Class("Learner",
120120
#' @template field_predict_sets
121121
predict_sets = "test",
122122

123+
#' @field timeout (named `numeric(2)`)\cr
124+
#' Timeout for the learner's train and predict steps, in seconds.
125+
#' This works differently for different encapsulation methods, see
126+
#' [mlr3misc::encapsulate()].
127+
#' Default is `c(train = Inf, predict = Inf)`.
128+
timeout = c(train = Inf, predict = Inf),
129+
123130
#' @field fallback ([Learner])\cr
124131
#' Learner which is fitted to impute predictions in case that either the model fitting or the prediction of the top learner is not successful.
125132
#' Requires you to enable encapsulation, otherwise errors are not caught and the execution is terminated before the fallback learner kicks in.

R/worker.R

+4-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ learner_train = function(learner, task, row_ids = NULL) {
4646
.f = train_wrapper,
4747
.args = list(learner = learner, task = task),
4848
.pkgs = learner$packages,
49-
.seed = NA_integer_
49+
.seed = NA_integer_,
50+
.timeout = learner$timeout["train"]
5051
)
5152

5253
learner$state = insert_named(learner$state, list(
@@ -142,7 +143,8 @@ learner_predict = function(learner, task, row_ids = NULL) {
142143
.f = predict_wrapper,
143144
.args = list(task = task, learner = learner),
144145
.pkgs = learner$packages,
145-
.seed = NA_integer_
146+
.seed = NA_integer_,
147+
.timeout = learner$timeout["predict"]
146148
)
147149

148150
prediction = result$result

0 commit comments

Comments
 (0)