Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Apr 16, 2024
1 parent 8fc9f5f commit cc5c828
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 26 deletions.
3 changes: 1 addition & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# dev

* feat: added `aggr`(egation function) to `Domain` which can be used for inner
tuning.
* feat: added support for `aggr`(egation function) which can be used for inner tuning.

# paradox 0.12.0
* Removed `Param` objects. `ParamSet` now uses a `data.table` internally; individual parameters are more like `Domain` objects now. `ParamSets` should be constructed using the `ps()` shorthand and `Domain` objects. This entails the following major changes:
Expand Down
3 changes: 2 additions & 1 deletion R/Design.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ Design = R6Class("Design",
# set fixed param vals to their constant values
# FIXME: this might also be problematic for LHS
# do we still create an LHS like this?
imap(param_set$values, function(v, n) set(data, j = n, value = v))

imap(param_set$values, function(v, n) {set(data, j = n, value = list(v))})
self$data = data
if (param_set$has_deps) {
private$set_deps_to_na()
Expand Down
6 changes: 1 addition & 5 deletions R/Domain.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
#' @param aggr (`function`)\cr
#' Function with one argument, which is a list of parameter values.
#' The function specifies how a list of parameter values is aggregated to form one parameter value.
#' This is used in the context of inner tuning. The default is to aggregate the values.
#' This is used in the context of inner tuning. The default is to aggregate the values and round up.
#'
#' @return A `Domain` object.
#'
Expand Down Expand Up @@ -153,10 +153,6 @@ Domain = function(cls, grouping,
assert_function(trafo, null.ok = TRUE)
assert_function(aggr, null.ok = TRUE, nargs = 1L)

if (is.null(aggr) && "inner_tuning" %in% tags) {
aggr = default_aggr
}

# depends may be an expression, but may also be quote() or expression()
if (length(depends_expr) == 1) {
depends_expr = eval(depends_expr, envir = parent.frame(2))
Expand Down
10 changes: 8 additions & 2 deletions R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ ParamSet = R6Class("ParamSet",
if (".requirements" %in% names(paramtbl)) {
requirements = paramtbl$.requirements
private$.params = paramtbl # self$add_dep needs this
for (row in seq_len(nrow(paramtbl))) {
for (row in seq_len(nrow(paramtbl))) {
for (req in requirements[[row]]) {
invoke(self$add_dep, id = paramtbl$id[[row]], allow_dangling_dependencies = allow_dangling_dependencies,
.args = req)
Expand Down Expand Up @@ -276,6 +276,7 @@ ParamSet = R6Class("ParamSet",
if (nrow(private$.aggrs) && !length(x[[1L]])) {
stopf("More than one value is required to aggregate them")
}

imap(x, function(value, .id) {
aggr = private$.aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value)
})
Expand Down Expand Up @@ -528,6 +529,7 @@ ParamSet = R6Class("ParamSet",
.trafo = private$.trafos[id, trafo],
.requirements = list(if (nrow(depstbl)) transpose_list(depstbl)), # NULL if no deps
.init_given = id %in% names(vals),
.aggr = private$.aggrs[id, get("aggr")],
.init = unname(vals[id]))
]

Expand Down Expand Up @@ -562,6 +564,7 @@ ParamSet = R6Class("ParamSet",

result$.__enclos_env__$private$.params = setindexv(private$.params[ids, on = "id"], c("id", "cls", "grouping"))
result$.__enclos_env__$private$.trafos = setkeyv(private$.trafos[ids, on = "id", nomatch = NULL], "id")
result$.__enclos_env__$private$.aggrs = setkeyv(private$.aggrs[ids, on = "id", nomatch = NULL], "id")
result$.__enclos_env__$private$.tags = setkeyv(private$.tags[ids, on = "id", nomatch = NULL], "id")
result$assert_values = FALSE
result$deps = deps[ids, on = "id", nomatch = NULL]
Expand Down Expand Up @@ -589,6 +592,7 @@ ParamSet = R6Class("ParamSet",
result$.__enclos_env__$private$.params = setindexv(private$.params[get_id, on = "id"], c("id", "cls", "grouping"))
# setkeyv not strictly necessary since get_id is scalar, but we do it for consistency
result$.__enclos_env__$private$.trafos = setkeyv(private$.trafos[get_id, on = "id", nomatch = NULL], "id")
result$.__enclos_env__$private$.aggrs = setkeyv(private$.aggrs[get_id, on = "id", nomatch = NULL], "id")
result$.__enclos_env__$private$.tags = setkeyv(private$.tags[get_id, on = "id", nomatch = NULL], "id")
result$assert_values = FALSE
result$values = values[match(get_id, names(values), nomatch = 0)]
Expand Down Expand Up @@ -740,6 +744,7 @@ ParamSet = R6Class("ParamSet",
result = copy(private$.params)
result[, .tags := list(self$tags)]
result[private$.trafos, .trafo := list(trafo), on = "id"]
result[private$.aggrs, .aggr := list(aggr), on = "id"]
result[self$deps, .requirements := transpose_list(.(on, cond)), on = "id"]
vals = self$values
result[, `:=`(
Expand Down Expand Up @@ -904,13 +909,14 @@ ParamSet = R6Class("ParamSet",
values = keep(values, inherits, "TuneToken")
if (!length(values)) return(ParamSet$new())
params = map(names(values), function(pn) {
domain = private$.params[pn, on = "id"]
domain = self$params[pn, on = "id"]
set_class(domain, c(domain$cls, "Domain", class(domain)))
})
names(params) = names(values)

# package-internal S3 fails if we don't call the function indirectly here
partsets = pmap(list(values, params), function(...) tunetoken_to_ps(...))

pars = ps_union(partsets) # partsets does not have names here, wihch is what we want.

names(partsets) = names(values)
Expand Down
5 changes: 5 additions & 0 deletions R/ParamSetCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
private$.tags = paramtbl[, .(tag = unique(unlist(.tags))), keyby = "id"]

private$.trafos = setkeyv(paramtbl[!map_lgl(.trafo, is.null), .(id, trafo = .trafo)], "id")
private$.aggrs = setkeyv(paramtbl[!map_lgl(.aggr, is.null), .(id, aggr = .aggr)], "id")

private$.translation = paramtbl[, c("id", "original_id", "owner_ps_index", "owner_name"), with = FALSE]
setkeyv(private$.translation, "id")
Expand Down Expand Up @@ -125,6 +126,10 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
if (nrow(newtrafos)) {
private$.trafos = setkeyv(rbind(private$.trafos, newtrafos), "id")
}
newaggrs = paramtbl[!map_lgl(.aggr, is.null), .(id, trafo = .aggr)]
if (nrow(newaggrs)) {
private$.aggrs = setkeyv(rbind(private$.aggrs, newaggrs), "id")
}

private$.translation = rbind(private$.translation, paramtbl[, c("id", "original_id", "owner_ps_index", "owner_name"), with = FALSE])
setkeyv(private$.translation, "id")
Expand Down
21 changes: 12 additions & 9 deletions R/to_tune.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,9 @@ to_tune = function(...) {
#' The default is to average the values and round them up.
#' @export
in_tune = function(..., aggr = NULL) {
if (is.null(aggr)) {
aggr = default_aggr
} else {
test_function(aggr, nargs = 1L)
}
test_function(aggr, nargs = 1L, null.ok = TRUE)
tt = to_tune(...)
tt$aggr = aggr
if (!is.null(aggr)) tt$content$aggr = aggr
tt = set_class(tt, classes = c("InnerTuneToken", class(tt)))
return(tt)
}
Expand Down Expand Up @@ -239,7 +235,11 @@ tunetoken_to_ps = function(tt, param) {
UseMethod("tunetoken_to_ps")
}

tunetoken_to_ps.InnerTuneToken = function(tt, params) {
tunetoken_to_ps.InnerTuneToken = function(tt, param) {
tt$content$aggr = tt$content$aggr %??% param$.aggr
if (is.null(tt$content$aggr)) {
stopf("%s (%s): Provide an aggregation function for inner tuning.", tt$call, param$id)
}
ps = NextMethod()
ps$tags = map(ps$tags, function(tags) union(tags, "inner_tuning"))
return(ps)
Expand All @@ -251,7 +251,7 @@ tunetoken_to_ps.FullTuneToken = function(tt, param) {
}
if (isTRUE(tt$content$logscale)) {
if (!domain_is_number(param)) stop("%s (%s): logscale only valid for numeric / integer parameters.", tt$call, param$id)
tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale), tt$call), param)
tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale, aggr = tt$content$aggr), tt$call), param)
} else {
pslike_to_ps(param, tt$call, param)
}
Expand All @@ -264,6 +264,7 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param) {
}
invalidpoints = discard(tt$content, function(x) is.null(x) || domain_test(param, set_names(list(x), param$id)))
invalidpoints$logscale = NULL
invalidpoints$aggr = NULL
if (length(invalidpoints)) {
stopf("%s range not compatible with param %s.\nBad value(s):\n%s\nParameter:\n%s",
tt$call, param$id, repr(invalidpoints), repr(param))
Expand All @@ -279,7 +280,9 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param) {
# create p_int / p_dbl object. Doesn't work if there is a numeric param class that we don't know about :-/
constructor = switch(param$cls, ParamInt = p_int, ParamDbl = p_dbl,
stopf("%s: logscale for parameter %s of class %s not supported", tt$call, param$id, param$class))
content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale)
content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale,
aggr = tt$content$aggr)

pslike_to_ps(content, tt$call, param)
}

Expand Down
2 changes: 1 addition & 1 deletion man/Domain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test_ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,14 @@ test_that("aggr", {
expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list())), "permutation")
expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "More than one")
})

test_that("in_tune", {
param_set = ps(a = p_dbl(lower = 1, upper = 2))
param_set$set_values(
a = in_tune(lower = 1, upper = 2, aggr = function(x) 1.5)
)

ss = param_set$search_space()

ss$aggr(list(a = list(1, 2)))
})
13 changes: 8 additions & 5 deletions tests/testthat/test_domain.R
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,15 @@ test_that("$extra_trafo flag works", {
test_that("in_tune", {
it = in_tune(1)
expect_class(it, "InnerTuneToken")
expect_function(it$aggr)
tt = to_tune(1)
expect_null(it$aggr)
tt = in_tune(1)
expect_equal(it$content, tt$content)
expect_equal(it$aggr(list(1, 2)), 2L)

it1 = in_tune(aggr = function(x) min(unlist(x)))
expect_equal(it1$aggr(list(1, 2)), 1)
expect_true("inner_tuning" %in% ps(a = p_dbl(1, 10))$set_values(a = in_tune())$search_space()$tags)
expect_equal(it1$content$aggr(list(1, 2)), 1)
param_set = ps(
a = p_dbl(1, 10, aggr = default_aggr)
)
param_set$set_values(a = in_tune())
expect_class(param_set$values$a, "InnerTuneToken")
})
2 changes: 1 addition & 1 deletion tests/testthat/test_to_tune.R
Original file line number Diff line number Diff line change
Expand Up @@ -396,5 +396,5 @@ test_that("logscale in tunetoken", {
expect_output(print(to_tune(lower = 1, logscale = TRUE)), "range \\[1, \\.\\.\\.] \\(log scale\\)")
expect_output(print(to_tune(upper = 1, logscale = TRUE)), "range \\[\\.\\.\\., 1] \\(log scale\\)")
expect_output(print(to_tune(lower = 0, upper = 1, logscale = TRUE)), "range \\[0, 1] \\(log scale\\)")

expect_output(print(in_tune()), "Inner")
})

0 comments on commit cc5c828

Please sign in to comment.