Skip to content

Commit

Permalink
some more changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed May 30, 2024
1 parent 773d692 commit 1b391f7
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 27 deletions.
5 changes: 3 additions & 2 deletions R/Domain.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,13 @@
#'
#' param_set = ps(
#' iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))),
#' in_tune_fn = function(domain, param_set) domain$upper, disable_in_tune = list(other_param = FALSE))
#' in_tune_fn = function(domain, param_set) domain$upper,
#' disable_in_tune = list(other_param = FALSE))
#' )
#' param_set$set_values(
#' iters = to_tune(upper = 100, internal = TRUE)
#' )
#' param_set$convert_internal_tune_tokens()
#' param_set$convert_internal_search_space(param_set$search_space())
#' param_set$aggr(list(iters = list(1, 2, 3)))
#'
#' @family ParamSet construction helpers
Expand Down
15 changes: 7 additions & 8 deletions R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -296,17 +296,16 @@ ParamSet = R6Class("ParamSet",
},

#' @description
#' Convert all `InternalTuneToken`s to parameter values as is defined by their `in_tune_fn`.
#'
#' Convert all parameters from the search space to parameter values using the transformation given by
#' `in_tune_fn`.
#' @param search_space ([`ParamSet`])\cr
#' The internal search space.
#' @return (named `list()`)
convert_internal_tune_tokens = function() {
internal_tune_tokens = self$get_values(type = "with_internal", check_required = FALSE)
internal_tune_ps = private$get_tune_ps(internal_tune_tokens)

imap(internal_tune_ps$domains, function(token, .id) {
convert_internal_search_space = function(search_space) {
imap(search_space$domains, function(token, .id) {
converter = private$.params[list(.id), "cargo", on = "id"][[1L]][[1L]]$in_tune_fn
if (!is.function(converter)) {
stopf("No converter exists for InternalTuneToken of parameters '%s'", .id)
stopf("No converter exists for parameter '%s'", .id)
}
converter(token)
})
Expand Down
5 changes: 3 additions & 2 deletions man/Domain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 15 additions & 6 deletions man/ParamSet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/ParamSetCollection.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 3 additions & 6 deletions tests/testthat/test_ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -451,18 +451,15 @@ test_that("aggr", {
expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "but there are no")
})

test_that("convert_internal_tune_tokens", {
test_that("convert_internal_search_space", {
param_set = ps(
a = p_int(lower = 1, upper = 100, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper,
aggr = function(x) round(mean(unlist(x))), disable_in_tune = list(a = 1))
)
param_set$set_values(a = to_tune(internal = TRUE))
expect_identical(param_set$convert_internal_tune_tokens(), list(a = 100))
expect_identical(param_set$convert_internal_search_space(param_set$search_space()), list(a = 100))
param_set$set_values(a = to_tune(internal = TRUE, upper = 99))
expect_identical(param_set$convert_internal_tune_tokens(), list(a = 99))

param_set$set_values(a = to_tune(internal = FALSE))
expect_identical(param_set$convert_internal_tune_tokens(), named_list())
expect_identical(param_set$convert_internal_search_space(param_set$search_space()), list(a = 99))
})

test_that("get_values works with internal_tune", {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_to_tune.R
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,11 @@ test_that("internal and aggr", {
# range + internal
param_set$set_values(a = to_tune(lower = 1.2, upper = 1.3, aggr = function(x) 1.5))
expect_equal(param_set$search_space()$aggr(list(a = list(1, 2))), list(a = 1.5))
expect_equal(param_set$convert_internal_tune_tokens(), list(a = 1.3))
expect_equal(param_set$convert_internal_search_space(param_set$search_space()), list(a = 1.3))

# full + internal
param_set$set_values(a = to_tune(internal = TRUE, aggr = function(x) 1.5))
expect_equal(param_set$convert_internal_tune_tokens(), list(a = 2))
expect_equal(param_set$convert_internal_search_space(param_set$search_space()), list(a = 2))

# domain + internal
expect_error(
Expand Down

0 comments on commit 1b391f7

Please sign in to comment.