Closed
Description
`predict()` returns `.pred_class` as a factor, but `multi_predict()` returns `.pred_class` as a character vector:
library(tidymodels)
#> Registered S3 method overwritten by 'xts':
#> method from
#> as.zoo.xts zoo
#> ── Attaching packages ───────────────────────────────────────────────────────────────────────────────────────── tidymodels 0.0.3 ──
#> ✔ broom 0.5.2 ✔ purrr 0.3.3
#> ✔ dials 0.0.3.9001 ✔ recipes 0.1.7.9001
#> ✔ dplyr 0.8.3 ✔ rsample 0.0.5
#> ✔ ggplot2 3.2.1 ✔ tibble 2.1.3
#> ✔ infer 0.5.0 ✔ yardstick 0.0.4
#> ✔ parsnip 0.0.3.9001
#> ── Conflicts ──────────────────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ dplyr::lag() masks stats::lag()
#> ✖ ggplot2::margin() masks dials::margin()
#> ✖ dials::offset() masks stats::offset()
#> ✖ recipes::step() masks stats::step()
library(tune)
library(glmnet)
#> Loading required package: Matrix
#>
#> Attaching package: 'Matrix'
#> The following objects are masked from 'package:tidyr':
#>
#> expand, pack, unpack
#> Loaded glmnet 3.0
library(mlbench)
data("Satellite")
mod <- multinom_reg() %>%
set_engine("glmnet")
fit <- mod %>% fit(classes ~ ., data = Satellite[-(1:10),])
predict(fit, new_data = Satellite[1:10, -37], penalty = .01)
#> # A tibble: 10 x 1
#> .pred_class
#> <fct>
#> 1 grey soil
#> 2 grey soil
#> 3 grey soil
#> 4 grey soil
#> 5 grey soil
#> 6 grey soil
#> 7 grey soil
#> 8 grey soil
#> 9 damp grey soil
#> 10 damp grey soil
multi_predict(fit, new_data = Satellite[1:10, -37], penalty = c(.1, 1))$.pred[[1]]
#> # A tibble: 2 x 2
#> .pred_class penalty
#> <chr> <dbl>
#> 1 grey soil 0.1
#> 2 red soil 1
Created on 2019-10-23 by the reprex package (v0.3.0)
Metadata
Assignees
Labels
No labels