Skip to content

Commit 3fac1e2

Browse files
authored
set_threads() helper (#605)
1 parent ac8c939 commit 3fac1e2

10 files changed

+137
-3
lines changed

DESCRIPTION

+2
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ Imports:
6969
mlbench,
7070
mlr3measures (>= 0.3.0),
7171
mlr3misc (>= 0.7.0),
72+
parallelly,
7273
palmerpenguins,
7374
paradox (>= 0.6.0),
7475
uuid
@@ -189,6 +190,7 @@ Collate:
189190
'predict.R'
190191
'reexports.R'
191192
'resample.R'
193+
'set_threads.R'
192194
'task_converters.R'
193195
'worker.R'
194196
'zzz.R'

NAMESPACE

+4
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ S3method(predict,Learner)
6262
S3method(print,PredictionData)
6363
S3method(rd_info,Learner)
6464
S3method(rd_info,Task)
65+
S3method(set_threads,Learner)
66+
S3method(set_threads,list)
6567
export(BenchmarkResult)
6668
export(DataBackend)
6769
export(DataBackendDataTable)
@@ -163,6 +165,8 @@ export(msrs)
163165
export(resample)
164166
export(rsmp)
165167
export(rsmps)
168+
export(set.threads.default)
169+
export(set_threads)
166170
export(tgen)
167171
export(tgens)
168172
export(tsk)

R/LearnerClassifDebug.R

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#' \item{segfault_predict:}{Probability to provoke a segfault during predict.}
1919
#' \item{predict_missing}{Ratio of predictions which will be NA.}
2020
#' \item{save_tasks:}{Saves input task in `model` slot during training and prediction.}
21+
#' \item{threads:}{Number of threads to use. Has no effect.}
2122
#' \item{x:}{Numeric tuning parameter. Has no effect.}
2223
#' }
2324
#' Note that segfaults may not be triggered on your operating system.
@@ -66,6 +67,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
6667
ParamDbl$new("segfault_predict", lower = 0, upper = 1, default = 0, tags = "predict"),
6768
ParamDbl$new("predict_missing", lower = 0, upper = 1, default = 0, tags = "predict"),
6869
ParamLgl$new("save_tasks", default = FALSE, tags = c("train", "predict")),
70+
ParamInt$new("threads", lower = 1, tags = c("train", "threads")),
6971
ParamDbl$new("x", lower = 0, upper = 1, tags = "train")
7072
)
7173
),

R/set_threads.R

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#' @title Set the Number of Threads
2+
#'
3+
#' @description
4+
#' Control the parallelism via threading while calling external packages from \CRANpkg{mlr3}.
5+
#'
6+
#' For example, the random forest implementation in package \CRANpkg{ranger} (connected
7+
#' via \CRANpkg{mlr3learners}) supports threading via OpenMP.
8+
#' The number of threads to use can be set via hyperparameter `num.threads`, and
9+
#' defaults to 1. By calling `set_threads(x, 4)` with `x` being a ranger learner, the
10+
#' hyperparameter is changed so that 4 cores are used.
11+
#'
12+
#' If the object `x` does not support threading, `x` is returned as-is.
13+
#' If applied to a list, recurses through all list elements.
14+
#'
15+
#' Note that threading is incompatible with other parallelization techniques such as forking
16+
#' via the [future::multicore] backend of [future::plan]. For this reason all learners connected to \CRANpkg{mlr3}
17+
#' have threading disabled in their defaults.
18+
#'
19+
#' @param x (`any`)\cr
20+
#' Object to set threads for, e.g. a [Learner].
21+
#' This object is modified in-place.
22+
#' @param n (`integer(1)`)\cr
23+
#' Number of threads to use.
24+
#'
25+
#' @return Same object as input `x` (changed in-place),
26+
#' with possibly updated parameter values.
27+
#' @export
28+
set_threads = function(x, n = parallelly::availableCores()) {
  # S3 generic: `n` is validated once here (a single positive count), before
  # dispatching on the class of `x`.
  assert_count(n, positive = TRUE)

  UseMethod("set_threads")
}
32+
33+
#' @rdname set_threads
34+
#' @export
35+
set_threads.default = function(x, n = parallelly::availableCores()) { # nolint
  # Fallback S3 method: objects without threading support are returned as-is.
  #
  # The name must be `set_threads.default` (underscore, matching the generic).
  # Under the previous name `set.threads.default` the function was never found
  # by S3 dispatch, so set_threads() failed with "no applicable method" for any
  # object that is neither a Learner nor a list.
  #
  # NOTE(review): NAMESPACE is generated by roxygen2; regenerate it so that
  # `export(set.threads.default)` becomes `S3method(set_threads,default)`.
  #
  # `n` is accepted for signature compatibility with the generic but unused.
  x
}
38+
39+
#' @rdname set_threads
40+
#' @export
41+
set_threads.Learner = function(x, n = parallelly::availableCores()) { # nolint
  # Ids of all hyperparameters of the learner that carry the tag "threads".
  thread_params = x$param_set$ids(tags = "threads")

  # No such parameter: hand the learner back untouched.
  if (!length(thread_params)) {
    return(x)
  }

  # Set every "threads"-tagged parameter to `n`, keeping all other values.
  new_values = named_list(thread_params, n)
  x$param_set$values = insert_named(x$param_set$values, new_values)
  x
}
48+
49+
#' @rdname set_threads
50+
#' @export
51+
set_threads.list = function(x, n = parallelly::availableCores()) { # nolint
  # Apply set_threads() to each element; re-dispatches on the element's class,
  # so nested lists are handled recursively. Returns the list built by lapply().
  lapply(x, function(element) set_threads(element, n = n))
}

README.Rmd

+6
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ Also, many helpful R libraries did not exist at the time [mlr](https://github.co
142142
All user input is checked with [`checkmate`](https://cran.r-project.org/package=checkmate).
143143
Return types are documented, and mechanisms popular in base R which "simplify" the result unpredictably (e.g., `sapply()` or `drop` argument in `[.data.frame`) are avoided.
144144
* Be light on dependencies. `mlr3` requires the following packages at runtime:
145+
- [`parallelly`](https://cran.r-project.org/package=parallelly):
146+
Helper functions for parallelization.
147+
No extra recursive dependencies.
145148
- [`future.apply`](https://cran.r-project.org/package=future.apply):
146149
Resampling and benchmarking is parallelized with the [`future`](https://cran.r-project.org/package=future) abstraction interfacing many parallel backends.
147150
- [`backports`](https://cran.r-project.org/package=backports):
@@ -178,6 +181,9 @@ Also, many helpful R libraries did not exist at the time [mlr](https://github.co
178181
- [`mlbench`](https://cran.r-project.org/package=mlbench):
179182
A collection of machine learning data sets.
180183
No dependencies.
184+
- [`palmerpenguins`](https://cran.r-project.org/package=palmerpenguins):
185+
A classification data set about penguins, used in examples and provided as a
186+
toy task. No dependencies.
181187
* [Reflections](https://en.wikipedia.org/wiki/Reflection_%28computer_programming%29): Objects are queryable for properties and capabilities, allowing you to program on them.
182188
* Additional functionality that comes with extra dependencies:
183189
- To capture output, warnings and exceptions, [`evaluate`](https://cran.r-project.org/package=evaluate) and [`callr`](https://cran.r-project.org/package=callr) can be used.

README.md

+9-3
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ rr$score(measure)
148148
```
149149

150150
## task task_id learner learner_id
151-
## 1: <TaskClassif[45]> penguins <LearnerClassifRpart[34]> classif.rpart
152-
## 2: <TaskClassif[45]> penguins <LearnerClassifRpart[34]> classif.rpart
153-
## 3: <TaskClassif[45]> penguins <LearnerClassifRpart[34]> classif.rpart
151+
## 1: <TaskClassif[46]> penguins <LearnerClassifRpart[34]> classif.rpart
152+
## 2: <TaskClassif[46]> penguins <LearnerClassifRpart[34]> classif.rpart
153+
## 3: <TaskClassif[46]> penguins <LearnerClassifRpart[34]> classif.rpart
154154
## resampling resampling_id iteration prediction
155155
## 1: <ResamplingCV[19]> cv 1 <PredictionClassif[19]>
156156
## 2: <ResamplingCV[19]> cv 2 <PredictionClassif[19]>
@@ -217,6 +217,9 @@ would result in non-trivial API changes.
217217
argument in `[.data.frame`) are avoided.
218218
- Be light on dependencies. `mlr3` requires the following packages at
219219
runtime:
220+
- [`parallelly`](https://cran.r-project.org/package=parallelly):
221+
Helper functions for parallelization. No extra recursive
222+
dependencies.
220223
- [`future.apply`](https://cran.r-project.org/package=future.apply):
221224
Resampling and benchmarking is parallelized with the
222225
[`future`](https://cran.r-project.org/package=future)
@@ -248,6 +251,9 @@ would result in non-trivial API changes.
248251
Performance measures. No extra recursive dependencies.
249252
- [`mlbench`](https://cran.r-project.org/package=mlbench): A
250253
collection of machine learning data sets. No dependencies.
254+
- [`palmerpenguins`](https://cran.r-project.org/package=palmerpenguins):
255+
A classification data set about penguins, used in examples and
256+
provided as a toy task. No dependencies.
251257
- [Reflections](https://en.wikipedia.org/wiki/Reflection_%28computer_programming%29):
252258
Objects are queryable for properties and capabilities, allowing you
253259
to program on them.

inst/testthat/helper_expectations.R

+1
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ expect_learner = function(lrn, task = NULL) {
314314
checkmate::expect_choice(lrn$task_type, mlr3::mlr_reflections$task_types$type)
315315
checkmate::expect_character(lrn$packages, any.missing = FALSE, min.chars = 1L, unique = TRUE)
316316
checkmate::expect_class(lrn$param_set, "ParamSet")
317+
testthat::expect_lte(length(lrn$param_set$ids(tags = "threads")), 1L)
317318
checkmate::expect_character(lrn$properties, any.missing = FALSE, min.chars = 1L, unique = TRUE)
318319
if (is.null(private(lrn)$.train)) {
319320
checkmate::expect_function(lrn$train_internal, args = "task", nargs = 1L)

man/mlr_learners_classif.debug.Rd

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/set_threads.Rd

+45
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_set_threads.R

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
test_that("set_threads", {
  # Learner without a "threads"-tagged parameter: result must still be a valid learner.
  featureless = lrn("classif.featureless")
  expect_learner(set_threads(featureless))

  # The debug learner has a "threads" parameter, which is unset initially.
  debug = lrn("classif.debug")
  expect_null(debug$param_set$values$threads)
  expect_learner(set_threads(debug, 1))
  expect_equal(debug$param_set$values$threads, 1)

  # The list method returns a list and updates the contained learner in place.
  learners = list(featureless, debug)
  expect_list(set_threads(learners, 2))
  expect_equal(debug$param_set$values$threads, 2)
})

0 commit comments

Comments
 (0)