|
21 | 21 | #' @param activation A single character string denoting the type of relationship
|
22 | 22 | #' between the original predictors and the hidden unit layer. The activation
|
23 | 23 | #' function between the hidden and output layers is automatically set to either
|
24 |
| -#' "linear" or "softmax" depending on the type of outcome. Possible values are: |
25 |
| -#' "linear", "softmax", "relu", and "elu" |
| 24 | +#' "linear" or "softmax" depending on the type of outcome. Possible values |
| 25 | +#' depend on the engine being used. |
26 | 26 | #'
|
27 | 27 | #' @templateVar modeltype mlp
|
28 | 28 | #' @template spec-details
|
@@ -142,24 +142,6 @@ check_args.mlp <- function(object) {
|
142 | 142 | if (args$dropout > 0 & args$penalty > 0)
|
143 | 143 | rlang::abort("Both weight decay and dropout should not be specified.")
|
144 | 144 |
|
145 |
| - |
146 |
| - if (object$engine == "brulee") { |
147 |
| - act_funs <- c("linear", "relu", "elu", "tanh") |
148 |
| - } else if (object$engine == "keras") { |
149 |
| - act_funs <- c("linear", "softmax", "relu", "elu") |
150 |
| - } else if (object$engine == "h2o") { |
151 |
| - act_funs <- c("relu", "tanh") |
152 |
| - } |
153 |
| - |
154 |
| - if (is.character(args$activation)) { |
155 |
| - if (!any(args$activation %in% c(act_funs))) { |
156 |
| - rlang::abort( |
157 |
| - glue::glue("`activation` should be one of: ", |
158 |
| - glue::glue_collapse(glue::glue("'{act_funs}'"), sep = ", ")) |
159 |
| - ) |
160 |
| - } |
161 |
| - } |
162 |
| - |
163 | 145 | invisible(object)
|
164 | 146 | }
|
165 | 147 |
|
@@ -210,6 +192,9 @@ keras_mlp <-
|
210 | 192 | seeds = sample.int(10^5, size = 3),
|
211 | 193 | ...) {
|
212 | 194 |
|
| 195 | + act_funs <- c("linear", "softmax", "relu", "elu") |
| 196 | + rlang::arg_match(activation, act_funs,) |
| 197 | + |
213 | 198 | if (penalty > 0 & dropout > 0) {
|
214 | 199 | rlang::abort("Please use either dropoput or weight decay.", call. = FALSE)
|
215 | 200 | }
|
|
0 commit comments