|
| 1 | +#' @title Set the Number of Threads |
| 2 | +#' |
| 3 | +#' @description |
| 4 | +#' Control the parallelism via threading while calling external packages from \CRANpkg{mlr3}. |
| 5 | +#' |
| 6 | +#' For example, the random forest implementation in package \CRANpkg{ranger} (connected |
| 7 | +#' via \CRANpkg{mlr3learners}) supports threading via OpenMP. |
| 8 | +#' The number of threads to use can be set via hyperparameter `num.threads`, and |
| 9 | +#' defaults to 1. By calling `set_threads(x, 4)` with `x` being a ranger learner, the |
| 10 | +#' hyperparameter is changed so that 4 cores are used. |
| 11 | +#' |
| 12 | +#' If the object `x` does not support threading, `x` is returned as-is. |
| 13 | +#' If applied to a list, recurses through all list elements. |
| 14 | +#' |
| 15 | +#' Note that threading is incompatible with other parallelization techniques such as forking |
| 16 | +#' via the [future::plan] [future::multicore]. For this reason all learners connected to \CRANpkg{mlr3} |
| 17 | +#' have threading disabled in their defaults. |
| 18 | +#' |
| 19 | +#' @param x (`any`)\cr |
| 20 | +#' Object to set threads for, e.g. a [Learner]. |
| 21 | +#' This object is modified in-place. |
| 22 | +#' @param n (`integer(1)`)\cr |
| 23 | +#' Number of threads to use. |
| 24 | +#' |
| 25 | +#' @return Same object as input `x` (changed in-place), |
| 26 | +#' with possibly updated parameter values. |
| 27 | +#' @export |
| 28 | +set_threads = function(x, n = parallelly::availableCores()) { |
| 29 | + assert_count(n, positive = TRUE) |
| 30 | + UseMethod("set_threads") |
| 31 | +} |
| 32 | + |
| 33 | +#' @rdname set_threads |
| 34 | +#' @export |
| 35 | +set.threads.default = function(x, n = parallelly::availableCores()) { # nolint |
| 36 | + x |
| 37 | +} |
| 38 | + |
| 39 | +#' @rdname set_threads |
| 40 | +#' @export |
| 41 | +set_threads.Learner = function(x, n = parallelly::availableCores()) { # nolint |
| 42 | + id = x$param_set$ids(tags = "threads") |
| 43 | + if (length(id)) { |
| 44 | + x$param_set$values = insert_named(x$param_set$values, named_list(id, n)) |
| 45 | + } |
| 46 | + x |
| 47 | +} |
| 48 | + |
| 49 | +#' @rdname set_threads |
| 50 | +#' @export |
| 51 | +set_threads.list = function(x, n = parallelly::availableCores()) { # nolint |
| 52 | + lapply(x, set_threads, n = n) |
| 53 | +} |
0 commit comments