
Fix format for single prediction (for multinom_reg() and prediction type "prob") #612

Closed
@JB304245

Description

When calling stats::predict() with type = "prob" on a multinom_reg() model (engine = "nnet"), the output format differs depending on whether new_data has a single row or more than one row. I searched for whether this is intentional but could not find anything, so I'm not sure whether this is a bug or a feature.

I personally like the way it behaves for inputs with more than one row, where the column names represent the labels. When predicting on a single row, it is not clear which labels the probabilities correspond to. I assume one could get them from model_trained$lvl, but I would be nervous about whether the order is guaranteed to be the same, especially since this model is going to production in my use case.

So my suggestion is to always use the output format that is currently used when new_data has more than one row.

Here is a reproducible example:

library(parsnip)
library(magrittr)
library(nnet)

# Set the seed before drawing random data so the example is reproducible
set.seed(123)

possible_outcomes = paste0("V", 1:10)

df = data.frame(training_target = as.factor(rep(possible_outcomes, 1000)))
df$predictor = rnorm(nrow(df), 0.5)

model = parsnip::multinom_reg(mode = "classification") %>%
  parsnip::set_engine("nnet")

model_trained = model %>%
  parsnip::fit(training_target ~ predictor,
               data = df)

# Single-row input: one row per class, in a single .pred_value column
out1 = stats::predict(model_trained,
                      new_data = data.frame(predictor = c(0.6)),
                      type = 'prob')

# Two-row input: one row per observation, one .pred_<label> column per class
out2 = stats::predict(model_trained,
                      new_data = data.frame(predictor = c(0.6, 0.3)),
                      type = 'prob')

> out1
# A tibble: 10 × 1
   .pred_value
         <dbl>
 1      0.100 
 2      0.100 
 3      0.100 
 4      0.0999
 5      0.0992
 6      0.0998
 7      0.100 
 8      0.0992
 9      0.100 
10      0.101 

> out2
# A tibble: 2 × 10
  .pred_V1 .pred_V10 .pred_V2 .pred_V3 .pred_V4 .pred_V5 .pred_V6 .pred_V7 .pred_V8 .pred_V9
     <dbl>     <dbl>    <dbl>    <dbl>    <dbl>    <dbl>    <dbl>    <dbl>    <dbl>    <dbl>
1   0.100      0.100   0.100    0.0999   0.0992   0.0998   0.100    0.0992    0.100   0.101 
2   0.0992     0.100   0.0993   0.100    0.101    0.100    0.0994   0.101     0.100   0.0988
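
Until the single-row format is changed, one possible workaround is to route single-row inputs through the multi-row code path. The sketch below assumes only the behavior shown above (multi-row input yields one .pred_<label> column per class). The helper name predict_prob_wide() is hypothetical and not part of parsnip, and the three-class data set is made up purely to keep the example small.

```r
library(parsnip)

# Hypothetical helper (not part of parsnip): always returns one row per
# observation with one .pred_<label> column per class.
predict_prob_wide <- function(fit, new_data) {
  if (nrow(new_data) != 1) {
    return(predict(fit, new_data = new_data, type = "prob"))
  }
  # Duplicate the single row so the multi-row code path is used,
  # then keep only the first row of the result.
  doubled <- rbind(new_data, new_data)
  predict(fit, new_data = doubled, type = "prob")[1, ]
}

# Small illustrative data set (assumption: three classes, noise predictor)
set.seed(123)
df <- data.frame(
  training_target = factor(rep(paste0("V", 1:3), 200)),
  predictor = rnorm(600, 0.5)
)

model_trained <- fit(
  set_engine(multinom_reg(mode = "classification"), "nnet"),
  training_target ~ predictor,
  data = df
)

out <- predict_prob_wide(model_trained, data.frame(predictor = 0.6))
names(out)
```

Because the labels travel with the column names, this avoids relying on the order of model_trained$lvl. It costs one redundant prediction per single-row call.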

Labels

bug (an unexpected problem or unintended behavior), next release 🚀
