Skip to content

Commit

Permalink
add aggr function to ParamSet, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Apr 16, 2024
1 parent 441d754 commit 8fc9f5f
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 8 deletions.
2 changes: 1 addition & 1 deletion R/Domain.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
#' value upon construction.
#' @param aggr (`function`)\cr
#' Function with one argument, which is a list of parameter values.
#' The function specifies how this list of parameter values is aggregated to form one parameter value.
#' The function specifies how a list of parameter values is aggregated to form one parameter value.
#' This is used in the context of inner tuning. The default is to aggregate the values.
#'
#' @return A `Domain` object.
Expand Down
22 changes: 22 additions & 0 deletions R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,28 @@ ParamSet = R6Class("ParamSet",
x
},

#' @description
#'
#' Aggregate parameter values according to the aggregation rules.
#'
#' @param x (named `list()` of `list()`s)\cr
#' The value(s) to be aggregated. Names are parameter values.
#' The aggregation function is selected accordingly for each parameter.
#' @return (named `list()`)
aggr = function(x) {
assert_list(x, types = "list")
assert_permutation(names(x), private$.aggrs$id)
if (!(length(unique(lengths(x))) == 1L)) {
stopf("The same number of values are required for each parameter")
}
if (nrow(private$.aggrs) && !length(x[[1L]])) {
stopf("More than one value is required to aggregate them")
}
imap(x, function(value, .id) {
aggr = private$.aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value)
})
},

#' @description
#' \pkg{checkmate}-like test-function. Takes a named list.
#' Return `FALSE` if the given `$constraint` is not satisfied, `TRUE` otherwise.
Expand Down
7 changes: 3 additions & 4 deletions R/to_tune.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ to_tune = function(...) {
#' See [`mlr3::Learner`] for more information.
#' @inheritParams to_tune
#' @param aggr (`function`)\cr
#' The aggregator function that determines how to aggregate a list of parameter values into one value.
#' a single parameter value. The default is to average them.
#' The aggregator function that determines how to aggregate a list of parameter values into a single parameter value.
#' The default is to average the values and round them up.
#' @export
in_tune = function(..., aggr = NULL) {
if (is.null(aggr)) {
Expand Down Expand Up @@ -241,8 +241,7 @@ tunetoken_to_ps = function(tt, param) {

tunetoken_to_ps.InnerTuneToken = function(tt, params) {
ps = NextMethod()
browser()
ps$tags = map(ps$tags, function(tags) union(tags, "inner_tune"))
ps$tags = map(ps$tags, function(tags) union(tags, "inner_tuning"))
return(ps)
}

Expand Down
2 changes: 1 addition & 1 deletion man/Domain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions man/ParamSet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/ParamSetCollection.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/in_tune.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions tests/testthat/test_ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,22 @@ test_that("set_values allows to unset parameters by setting them to NULL", {
param_set$set_values(.values = list(a = NULL), .insert = FALSE)
expect_identical(param_set$values, list(a = NULL))
})

test_that("aggr", {
param_set = ps(
a = p_uty(aggr = function(x) "a"),
b = p_fct(levels = c("a", "b"), aggr = function(x) "b"),
c = p_lgl(aggr = function(x) "c"),
d = p_int(aggr = function(x) "d"),
e = p_dbl(aggr = function(x) "e")
)
expect_class(param_set, "ParamSet")

vals = param_set$aggr(list(a = list(1), b = list(1), c = list(1), d = list(1), e = list(1)))
expect_equal(vals, list(a = "a", b = "b", c = "c", d = "d", e = "e"))

expect_error(param_set$aggr(1), "list")
expect_error(param_set$aggr(list(1)), "list")
expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list())), "permutation")
expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "More than one")
})
13 changes: 13 additions & 0 deletions tests/testthat/test_domain.R
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,16 @@ test_that("$extra_trafo flag works", {
search_space = pps$search_space()
expect_false(search_space$has_extra_trafo)
})

test_that("in_tune", {
it = in_tune(1)
expect_class(it, "InnerTuneToken")
expect_function(it$aggr)
tt = to_tune(1)
expect_equal(it$content, tt$content)
expect_equal(it$aggr(list(1, 2)), 2L)

it1 = in_tune(aggr = function(x) min(unlist(x)))
expect_equal(it1$aggr(list(1, 2)), 1)
expect_true("inner_tuning" %in% ps(a = p_dbl(1, 10))$set_values(a = in_tune())$search_space()$tags)
})

0 comments on commit 8fc9f5f

Please sign in to comment.