@@ -15,15 +15,17 @@ make_classes <- function(prefix) {
15
15
# ' @return If an error is not thrown (from non-empty ellipses), a NULL list.
16
16
# ' @keywords internal
17
17
# ' @export
18
- check_empty_ellipse <- function (... ) {
18
+ check_empty_ellipse <- function (... ) {
19
19
terms <- quos(... )
20
- if (! is_empty(terms ))
20
+ if (! is_empty(terms )) {
21
21
rlang :: abort(" Please pass other arguments to the model function via `set_engine()`." )
22
+ }
22
23
terms
23
24
}
24
25
25
- is_missing_arg <- function (x )
26
+ is_missing_arg <- function (x ) {
26
27
identical(x , quote(missing_arg()))
28
+ }
27
29
28
30
model_info_table <-
29
31
utils :: read.delim(system.file(" models.tsv" , package = " parsnip" ))
@@ -38,7 +40,11 @@ has_loaded_implementation <- function(spec_, engine_, mode_) {
38
40
if (isFALSE(mode_ %in% c(" regression" , " censored regression" , " classification" ))) {
39
41
mode_ <- c(" regression" , " censored regression" , " classification" )
40
42
}
41
- eng_cond <- if (is.null(engine_ )) {TRUE } else {quote(engine == engine_ )}
43
+ eng_cond <- if (is.null(engine_ )) {
44
+ TRUE
45
+ } else {
46
+ quote(engine == engine_ )
47
+ }
42
48
43
49
avail <-
44
50
get_from_env(spec_ ) %> %
@@ -56,7 +62,7 @@ has_loaded_implementation <- function(spec_, engine_, mode_) {
56
62
57
63
is_printable_spec <- function (x ) {
58
64
! is.null(x $ method $ fit $ args ) &&
59
- has_loaded_implementation(class(x )[1 ], x $ engine , x $ mode )
65
+ has_loaded_implementation(class(x )[1 ], x $ engine , x $ mode )
60
66
}
61
67
62
68
# construct a message informing the user that there are no
@@ -67,31 +73,52 @@ is_printable_spec <- function(x) {
67
73
inform_missing_implementation <- function (spec_ , engine_ , mode_ ) {
68
74
avail <-
69
75
show_engines(spec_ ) %> %
70
- dplyr :: filter(mode == mode_ , engine == engine_ )
76
+ dplyr :: filter(engine == engine_ )
71
77
all <-
72
78
model_info_table %> %
73
- dplyr :: filter(model == spec_ , mode == mode_ , engine == engine_ , ! is.na(pkg )) %> %
79
+ dplyr :: filter(model == spec_ , engine == engine_ , ! is.na(pkg )) %> %
74
80
dplyr :: select(- model )
75
81
76
- if (identical(mode_ , " unknown" )) {
77
- mode_ <- " "
82
+ if (! identical(mode_ , " unknown" )) {
83
+ avail <- avail %> % dplyr :: filter(mode == mode_ )
84
+ all <- all %> % dplyr :: filter(mode == mode_ )
85
+ msg <- glue :: glue(
86
+ paste0(
87
+ " parsnip could not locate an implementation for {spec_} {mode_} " ,
88
+ " model specifications using the `{engine_}` engine. "
89
+ )
90
+ )
91
+ } else {
92
+ msg <- glue :: glue(
93
+ paste0(
94
+ " parsnip could not locate an implementation for {spec_} " ,
95
+ " model specifications using the `{engine_}` engine. "
96
+ )
97
+ )
78
98
}
79
99
80
- msg <-
81
- glue :: glue(
82
- " parsnip could not locate an implementation for `{spec_}` {mode_} model \\
83
- specifications using the `{engine_}` engine."
84
- )
85
100
86
101
if (nrow(avail ) == 0 && nrow(all ) > 0 ) {
87
- msg <-
88
- c(
89
- msg ,
90
- i = paste0(" The parsnip extension package " , all $ pkg [[1 ]],
91
- " implements support for this specification." ),
92
- i = " Please install (if needed) and load to continue." ,
93
- " "
102
+ if (nrow(all ) == 1 ) {
103
+ msg <-
104
+ c(
105
+ i = paste0(
106
+ msg ,
107
+ glue :: glue(" Please install `{all$pkg[[1]]}` (if needed) and load to continue." ),
108
+ " \n "
109
+ )
110
+ )
111
+ } else {
112
+ msg <- c(
113
+ i = paste0(
114
+ msg ,
115
+ " The following parsnip extension packages " ,
116
+ " implement support for this specification."
117
+ ),
118
+ i = paste0(unique(all $ pkg ), collapse = " , " ),
119
+ i = " Please install one of them (if needed) and load to continue.\n "
94
120
)
121
+ }
95
122
}
96
123
97
124
msg
@@ -109,22 +136,25 @@ show_call <- function(object) {
109
136
map(object $ method $ fit $ args , convert_arg )
110
137
111
138
call2(object $ method $ fit $ func [" fun" ],
112
- !!! object $ method $ fit $ args ,
113
- .ns = object $ method $ fit $ func [" pkg" ])
139
+ !!! object $ method $ fit $ args ,
140
+ .ns = object $ method $ fit $ func [" pkg" ]
141
+ )
114
142
}
115
143
116
144
convert_arg <- function (x ) {
117
- if (is_quosure(x ))
145
+ if (is_quosure(x )) {
118
146
quo_get_expr(x )
119
- else
147
+ } else {
120
148
x
149
+ }
121
150
}
122
151
123
152
levels_from_formula <- function (f , dat ) {
124
- if (inherits(dat , " tbl_spark" ))
153
+ if (inherits(dat , " tbl_spark" )) {
125
154
res <- NULL
126
- else
155
+ } else {
127
156
res <- levels(eval_tidy(f [[2 ]], dat ))
157
+ }
128
158
res
129
159
}
130
160
@@ -134,7 +164,7 @@ levels_from_formula <- function(f, dat) {
134
164
show_fit <- function (model , eng ) {
135
165
mod <- translate(x = model , engine = eng )
136
166
fit_call <- show_call(mod )
137
- call_text <- deparse(fit_call )
167
+ call_text <- deparse(fit_call )
138
168
call_text <- paste0(call_text , collapse = " \n " )
139
169
paste0(
140
170
" \\ preformatted{\n " ,
@@ -157,9 +187,10 @@ check_args.default <- function(object) {
157
187
158
188
# copied form recipes
159
189
160
- names0 <- function (num , prefix = " x" ) {
161
- if (num < 1 )
190
+ names0 <- function (num , prefix = " x" ) {
191
+ if (num < 1 ) {
162
192
rlang :: abort(" `num` should be > 0." )
193
+ }
163
194
ind <- format(1 : num )
164
195
ind <- gsub(" " , " 0" , ind )
165
196
paste0(prefix , ind )
@@ -172,16 +203,16 @@ names0 <- function (num, prefix = "x") {
172
203
# ' @keywords internal
173
204
# ' @rdname add_on_exports
174
205
update_dot_check <- function (... ) {
175
-
176
206
dots <- enquos(... )
177
207
178
- if (length(dots ) > 0 )
208
+ if (length(dots ) > 0 ) {
179
209
rlang :: abort(
180
210
glue :: glue(
181
211
" Extra arguments will be ignored: " ,
182
212
glue :: glue_collapse(glue :: glue(" `{names(dots)}`" ), sep = " , " )
183
213
)
184
214
)
215
+ }
185
216
invisible (NULL )
186
217
}
187
218
@@ -192,15 +223,16 @@ update_dot_check <- function(...) {
192
223
# ' @rdname add_on_exports
193
224
new_model_spec <- function (cls , args , eng_args , mode , method , engine ,
194
225
check_missing_spec = TRUE ) {
195
-
196
226
check_spec_mode_engine_val(cls , engine , mode )
197
227
198
228
if ((! has_loaded_implementation(cls , engine , mode )) && check_missing_spec ) {
199
229
rlang :: inform(inform_missing_implementation(cls , engine , mode ))
200
230
}
201
231
202
- out <- list (args = args , eng_args = eng_args ,
203
- mode = mode , method = method , engine = engine )
232
+ out <- list (
233
+ args = args , eng_args = eng_args ,
234
+ mode = mode , method = method , engine = engine
235
+ )
204
236
class(out ) <- make_classes(cls )
205
237
out
206
238
}
@@ -211,8 +243,9 @@ check_outcome <- function(y, spec) {
211
243
if (spec $ mode == " unknown" ) {
212
244
return (invisible (NULL ))
213
245
} else if (spec $ mode == " regression" ) {
214
- if (! all(map_lgl(y , is.numeric )))
246
+ if (! all(map_lgl(y , is.numeric ))) {
215
247
rlang :: abort(" For a regression model, the outcome should be numeric." )
248
+ }
216
249
} else if (spec $ mode == " classification" ) {
217
250
if (! all(map_lgl(y , is.factor ))) {
218
251
rlang :: abort(" For a classification model, the outcome should be a factor." )
@@ -250,7 +283,6 @@ check_final_param <- function(x) {
250
283
# ' @keywords internal
251
284
# ' @rdname add_on_exports
252
285
update_main_parameters <- function (args , param ) {
253
-
254
286
if (length(param ) == 0 ) {
255
287
return (args )
256
288
}
@@ -263,8 +295,10 @@ update_main_parameters <- function(args, param) {
263
295
extra_args <- names(param )[has_extra_args ]
264
296
if (any(has_extra_args )) {
265
297
rlang :: abort(
266
- paste(" At least one argument is not a main argument:" ,
267
- paste0(" `" , extra_args , " `" , collapse = " , " ))
298
+ paste(
299
+ " At least one argument is not a main argument:" ,
300
+ paste0(" `" , extra_args , " `" , collapse = " , " )
301
+ )
268
302
)
269
303
}
270
304
param <- param [! has_extra_args ]
@@ -276,7 +310,6 @@ update_main_parameters <- function(args, param) {
276
310
# ' @keywords internal
277
311
# ' @rdname add_on_exports
278
312
update_engine_parameters <- function (eng_args , fresh , ... ) {
279
-
280
313
dots <- enquos(... )
281
314
282
315
# # only update from dots when there are eng args in original model spec
@@ -303,16 +336,20 @@ update_engine_parameters <- function(eng_args, fresh, ...) {
303
336
stan_conf_int <- function (object , newdata ) {
304
337
check_installs(list (method = list (libs = " rstanarm" )))
305
338
if (utils :: packageVersion(" rstanarm" ) > = " 2.21.1" ) {
306
- fn <- rlang :: call2(" posterior_epred" , .ns = " rstanarm" ,
307
- object = expr(object ),
308
- newdata = expr(newdata ),
309
- seed = expr(sample.int(10 ^ 5 , 1 )))
339
+ fn <- rlang :: call2(" posterior_epred" ,
340
+ .ns = " rstanarm" ,
341
+ object = expr(object ),
342
+ newdata = expr(newdata ),
343
+ seed = expr(sample.int(10 ^ 5 , 1 ))
344
+ )
310
345
} else {
311
- fn <- rlang :: call2(" posterior_linpred" , .ns = " rstanarm" ,
312
- object = expr(object ),
313
- newdata = expr(newdata ),
314
- transform = TRUE ,
315
- seed = expr(sample.int(10 ^ 5 , 1 )))
346
+ fn <- rlang :: call2(" posterior_linpred" ,
347
+ .ns = " rstanarm" ,
348
+ object = expr(object ),
349
+ newdata = expr(newdata ),
350
+ transform = TRUE ,
351
+ seed = expr(sample.int(10 ^ 5 , 1 ))
352
+ )
316
353
}
317
354
rlang :: eval_tidy(fn )
318
355
}
@@ -357,30 +394,31 @@ stan_conf_int <- function(object, newdata) {
357
394
# ' @keywords internal
358
395
# ' @export
359
396
.check_glmnet_penalty_predict <- function (penalty = NULL , object , multi = FALSE ) {
360
-
361
397
if (is.null(penalty )) {
362
398
penalty <- object $ fit $ lambda
363
399
}
364
400
365
401
# when using `predict()`, allow for a single lambda
366
402
if (! multi ) {
367
- if (length(penalty ) != 1 )
403
+ if (length(penalty ) != 1 ) {
368
404
rlang :: abort(
369
405
glue :: glue(
370
406
" `penalty` should be a single numeric value. `multi_predict()` " ,
371
407
" can be used to get multiple predictions per row of data." ,
372
408
)
373
409
)
410
+ }
374
411
}
375
412
376
- if (length(object $ fit $ lambda ) == 1 && penalty != object $ fit $ lambda )
413
+ if (length(object $ fit $ lambda ) == 1 && penalty != object $ fit $ lambda ) {
377
414
rlang :: abort(
378
415
glue :: glue(
379
416
" The glmnet model was fit with a single penalty value of " ,
380
417
" {object$fit$lambda}. Predicting with a value of {penalty} " ,
381
418
" will give incorrect results from `glmnet()`."
382
419
)
383
420
)
421
+ }
384
422
385
423
penalty
386
424
}
0 commit comments