Skip to content

Commit 525baee

Browse files
topepo‘topepo’
and
‘topepo’
authored
Activation checks (#1065)
* remove mlp activation check for #1019 * unit test for #1019 * update news * version requirement in skip --------- Co-authored-by: ‘topepo’ <‘mxkuhn@gmail.com’>
1 parent 25fc2c6 commit 525baee

File tree

3 files changed

+32
-20
lines changed

3 files changed

+32
-20
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* parsnip now lets the engines for [mlp()] check for acceptable values of the activation function (#1019)
4+
35
* Tightened logic for outcome checking. This resolves issues—some errors and some silent failures—when atomic outcome variables have an attribute (#1060, #1061).
46

57
* `rpart_train()` has been deprecated in favor of using `decision_tree()` with the `"rpart"` engine or `rpart::rpart()` directly (#1044).

R/mlp.R

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
#' @param activation A single character string denoting the type of relationship
2222
#' between the original predictors and the hidden unit layer. The activation
2323
#' function between the hidden and output layers is automatically set to either
24-
#' "linear" or "softmax" depending on the type of outcome. Possible values are:
25-
#' "linear", "softmax", "relu", and "elu"
24+
#' "linear" or "softmax" depending on the type of outcome. Possible values
25+
#' depend on the engine being used.
2626
#'
2727
#' @templateVar modeltype mlp
2828
#' @template spec-details
@@ -142,24 +142,6 @@ check_args.mlp <- function(object) {
142142
if (args$dropout > 0 & args$penalty > 0)
143143
rlang::abort("Both weight decay and dropout should not be specified.")
144144

145-
146-
if (object$engine == "brulee") {
147-
act_funs <- c("linear", "relu", "elu", "tanh")
148-
} else if (object$engine == "keras") {
149-
act_funs <- c("linear", "softmax", "relu", "elu")
150-
} else if (object$engine == "h2o") {
151-
act_funs <- c("relu", "tanh")
152-
}
153-
154-
if (is.character(args$activation)) {
155-
if (!any(args$activation %in% c(act_funs))) {
156-
rlang::abort(
157-
glue::glue("`activation` should be one of: ",
158-
glue::glue_collapse(glue::glue("'{act_funs}'"), sep = ", "))
159-
)
160-
}
161-
}
162-
163145
invisible(object)
164146
}
165147

@@ -210,6 +192,9 @@ keras_mlp <-
210192
seeds = sample.int(10^5, size = 3),
211193
...) {
212194

195+
act_funs <- c("linear", "softmax", "relu", "elu")
196+
rlang::arg_match(activation, act_funs,)
197+
213198
if (penalty > 0 & dropout > 0) {
214199
rlang::abort("Please use either dropoput or weight decay.", call. = FALSE)
215200
}

tests/testthat/test_mlp.R

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,28 @@ test_that("nnet_softmax", {
2323
expect_equal(res$b, 1 - res$a)
2424
})
2525

26+
test_that("more activations for brulee", {
27+
skip_if_not_installed("brulee", minimum_version = "0.3.0")
28+
skip_on_cran()
29+
30+
data(ames, package = "modeldata")
31+
32+
ames$Sale_Price <- log10(ames$Sale_Price)
33+
34+
set.seed(122)
35+
in_train <- sample(1:nrow(ames), 2000)
36+
ames_train <- ames[ in_train,]
37+
ames_test <- ames[-in_train,]
38+
39+
set.seed(1)
40+
fit <-
41+
try(
42+
mlp(penalty = 0.10, activation = "softplus") %>%
43+
set_mode("regression") %>%
44+
set_engine("brulee") %>%
45+
fit_xy(x = as.matrix(ames_train[, c("Longitude", "Latitude")]),
46+
y = ames_train$Sale_Price),
47+
silent = TRUE)
48+
expect_true(inherits(fit$fit, "brulee_mlp"))
49+
})
50+

0 commit comments

Comments
 (0)