Closed
Description
The problem
I can't get standard errors for the predictions from a `parsnip::bart()` model in the same way that I can for a Stan model (i.e. via `predict(..., type = "pred_int", std_error = TRUE)`).
I think this should be possible, so I'm reporting it as a bug rather than a feature request.
Thanks!
Reproducible example
# Setup: packages and data -------------------------------------------------
# Adapted from the parsnip examples article:
# https://parsnip.tidymodels.org/articles/Examples.html
library(tidymodels)
library(dbarts)
#>
#> Attaching package: 'dbarts'
#> The following object is masked from 'package:tidyr':
#>
#> extract
#> The following object is masked from 'package:parsnip':
#>
#> bart
library(rstanarm)
#> Loading required package: Rcpp
#>
#> Attaching package: 'Rcpp'
#> The following object is masked from 'package:rsample':
#>
#> populate
#> This is rstanarm version 2.21.4
#> - See https://mc-stan.org/rstanarm/articles/priors for changes to default priors!
#> - Default priors may change, so it's safest to specify priors, even if equivalent to the defaults.
#> - For execution on a local, multicore CPU with excess RAM we recommend calling
#> options(mc.cores = parallel::detectCores())

# dbarts masks parsnip's bart() (see the message above), so prefer the
# tidymodels versions of any conflicting function names.
tidymodels_prefer()

# Data: the first 10 rows of two_class_dat form a small test set; every other
# row is used for training.
data(two_class_dat)
data_test <- two_class_dat[1:10, ]
data_train <- two_class_dat[-(1:10), ]
# BART model --------------------------------------------------------------
# Demonstrates that the BART predictions do not include a .std_error column,
# even though std_error = TRUE is requested.

# Specify and fit the BART classifier. The parsnip:: prefix is explicit
# because dbarts also exports a bart() function.
set.seed(1)
bt_cls_fit <- parsnip::bart() %>%
  set_mode("classification") %>%
  set_engine("dbarts") %>%
  fit(Class ~ ., data = data_train)

# Combine the class probabilities with the prediction-interval output
# (std_error requested), then drop the interval bounds.
# Expected: a .std_error column. Observed: only the probability columns.
set.seed(2)
predict(bt_cls_fit, data_test, type = "prob") %>%
  bind_cols(
    predict(bt_cls_fit, data_test, type = "pred_int", std_error = TRUE)
  ) %>%
  select(!contains(c("lower", "upper")))
#> # A tibble: 10 × 2
#> .pred_Class1 .pred_Class2
#> <dbl> <dbl>
#> 1 0.344 0.656
#> 2 0.82 0.18
#> 3 0.562 0.438
#> 4 0.608 0.392
#> 5 0.438 0.562
#> 6 0.234 0.766
#> 7 0.632 0.368
#> 8 0.448 0.552
#> 9 0.971 0.029
#> 10 0.0780 0.922
# Check whether the dbarts fit object holds enough information to compute a
# standard error for each prediction, i.e. whether predict(..., std_error =
# TRUE) could return one for this engine. The summarised probabilities below
# are close to those from predict(type = "prob") above, so the draws look
# usable, but this has not been verified.

# Extract the underlying dbarts fit from the parsnip model_fit
bt_cls_eng <-
  extract_fit_engine(bt_cls_fit)

# Draws from dbarts' own predict() method. The apply(..., 2, ...) calls below
# treat each column as one test observation (10 columns -> 10 output rows);
# presumably each row is one posterior draw -- TODO confirm in ?dbarts::bart.
set.seed(2)
bt_cls_pred <-
  predict(bt_cls_eng, data_test)

# Summarise the posterior draws for each observation. Column names follow
# parsnip's .pred_<level> convention so the result lines up with the
# predict(type = "prob") output above. The column mean is computed once and
# reused for both class probabilities.
# NOTE(review): the draw means are close to .pred_Class2 above, so
# .std_error here appears to be the posterior SD of the Class2 *probability*.
# The Stan .std_error values further down are close to sqrt(p * (1 - p)), which
# suggests a different quantity (SD of simulated 0/1 outcomes) -- confirm
# before comparing the two.
prob_class2 <- apply(bt_cls_pred, 2, base::mean, na.rm = TRUE)
tibble(
  .pred_Class1 = 1 - prob_class2,
  .pred_Class2 = prob_class2,
  .std_error = apply(bt_cls_pred, 2, stats::sd, na.rm = TRUE)
)
#> # A tibble: 10 × 3
#> .pred_Class1 .pred_Class2 .std_error
#> <dbl> <dbl> <dbl>
#> 1 0.335 0.665 0.0947
#> 2 0.830 0.170 0.0743
#> 3 0.586 0.414 0.0969
#> 4 0.606 0.394 0.123
#> 5 0.434 0.566 0.104
#> 6 0.231 0.769 0.0876
#> 7 0.649 0.351 0.111
#> 8 0.448 0.552 0.110
#> 9 0.977 0.0228 0.0265
#> 10 0.0785 0.922 0.0486
# Stan model --------------------------------------------------------------
# For comparison: a Stan logistic regression does return a .std_error column
# (this is the behaviour hoped for from the BART model).

# Specify and fit the Stan logistic regression
set.seed(1)
logreg_cls_fit <-
  logistic_reg() %>%
  set_engine("stan") %>%
  fit(Class ~ ., data = data_train)

# Make predictions; the output includes the .std_error column.
# Seed the RNG before predicting, as the BART section does, so the pred_int
# results are reproducible (they are presumably based on random posterior
# predictive draws -- confirm in parsnip's stan engine registration).
# NOTE(review): the output below was recorded before this set.seed() call was
# added, so the .std_error digits may differ slightly on a rerun.
# NOTE(review): these .std_error values are close to sqrt(p * (1 - p)) with
# p = .pred_Class1 (e.g. 0.518 -> 0.500, 0.909 -> 0.287), which suggests they
# are the SD of simulated 0/1 outcomes rather than of the probability itself,
# and so are not directly comparable to the dbarts SDs computed above.
set.seed(2)
bind_cols(
  predict(logreg_cls_fit, data_test, type = "prob"),
  predict(logreg_cls_fit, data_test, type = "pred_int", std_error = TRUE)
) %>%
  select(-contains(c("lower", "upper")))
#> # A tibble: 10 × 3
#> .pred_Class1 .pred_Class2 .std_error
#> <dbl> <dbl> <dbl>
#> 1 0.518 0.482 0.500
#> 2 0.909 0.0909 0.287
#> 3 0.650 0.350 0.474
#> 4 0.609 0.391 0.491
#> 5 0.443 0.557 0.497
#> 6 0.206 0.794 0.402
#> 7 0.708 0.292 0.454
#> 8 0.568 0.432 0.497
#> 9 0.994 0.00580 0.0834
#> 10 0.108 0.892 0.313
Created on 2023-05-29 with reprex v2.0.2
Session info
# Record the R version, platform, and loaded package versions for the report.
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.2.3 (2023-03-15 ucrt)
#> os Windows 10 x64 (build 19042)
#> system x86_64, mingw32
#> ui RTerm
#> language (EN)
#> collate English_Australia.utf8
#> ctype English_Australia.utf8
#> tz Australia/Perth
#> date 2023-05-29
#> pandoc 2.19.2 @ C:/program files/rstudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> ! package * version date (UTC) lib source
#> backports 1.4.1 2021-12-13 [2] CRAN (R 4.2.0)
#> base64enc 0.1-3 2015-07-28 [2] CRAN (R 4.2.0)
#> bayesplot 1.10.0 2022-11-16 [1] CRAN (R 4.2.3)
#> boot 1.3-28.1 2022-11-22 [2] CRAN (R 4.2.3)
#> broom * 1.0.4 2023-03-11 [2] CRAN (R 4.2.3)
#> cachem 1.0.8 2023-05-01 [1] CRAN (R 4.2.3)
#> callr 3.7.3 2022-11-02 [2] CRAN (R 4.2.3)
#> class 7.3-21 2023-01-23 [2] CRAN (R 4.2.3)
#> cli 3.6.1 2023-03-23 [1] CRAN (R 4.2.3)
#> codetools 0.2-19 2023-02-01 [2] CRAN (R 4.2.3)
#> colorspace 2.1-0 2023-01-23 [2] CRAN (R 4.2.3)
#> colourpicker 1.2.0 2022-10-28 [1] CRAN (R 4.2.3)
#> conflicted 1.2.0 2023-02-01 [2] CRAN (R 4.2.3)
#> crayon 1.5.2 2022-09-29 [2] CRAN (R 4.2.3)
#> crosstalk 1.2.0 2021-11-04 [2] CRAN (R 4.2.3)
#> curl 5.0.0 2023-01-12 [2] CRAN (R 4.2.3)
#> data.table 1.14.8 2023-02-17 [2] CRAN (R 4.2.3)
#> dbarts * 0.9-23 2023-01-23 [1] CRAN (R 4.2.3)
#> dials * 1.2.0 2023-04-03 [2] CRAN (R 4.2.3)
#> DiceDesign 1.9 2021-02-13 [2] CRAN (R 4.2.3)
#> digest 0.6.31 2022-12-11 [2] CRAN (R 4.2.3)
#> dplyr * 1.1.2 2023-04-20 [1] CRAN (R 4.2.3)
#> DT 0.27 2023-01-17 [1] CRAN (R 4.2.3)
#> dygraphs 1.1.1.6 2018-07-11 [1] CRAN (R 4.2.3)
#> ellipsis 0.3.2 2021-04-29 [2] CRAN (R 4.2.3)
#> evaluate 0.21 2023-05-05 [1] CRAN (R 4.2.3)
#> fansi 1.0.4 2023-01-22 [1] CRAN (R 4.2.2)
#> fastmap 1.1.1 2023-02-24 [2] CRAN (R 4.2.3)
#> foreach 1.5.2 2022-02-02 [2] CRAN (R 4.2.3)
#> fs 1.6.2 2023-04-25 [1] CRAN (R 4.2.3)
#> furrr 0.3.1 2022-08-15 [2] CRAN (R 4.2.3)
#> future 1.32.0 2023-03-07 [2] CRAN (R 4.2.3)
#> future.apply 1.10.0 2022-11-05 [2] CRAN (R 4.2.3)
#> generics 0.1.3 2022-07-05 [2] CRAN (R 4.2.3)
#> ggplot2 * 3.4.2 2023-04-03 [1] CRAN (R 4.2.3)
#> globals 0.16.2 2022-11-21 [2] CRAN (R 4.2.2)
#> glue 1.6.2 2022-02-24 [2] CRAN (R 4.2.3)
#> gower 1.0.1 2022-12-22 [2] CRAN (R 4.2.2)
#> GPfit 1.0-8 2019-02-08 [2] CRAN (R 4.2.3)
#> gridExtra 2.3 2017-09-09 [2] CRAN (R 4.2.3)
#> gtable 0.3.3 2023-03-21 [2] CRAN (R 4.2.3)
#> gtools 3.9.4 2022-11-27 [1] CRAN (R 4.2.3)
#> hardhat 1.3.0 2023-03-30 [2] CRAN (R 4.2.3)
#> htmltools 0.5.5 2023-03-23 [2] CRAN (R 4.2.3)
#> htmlwidgets 1.6.2 2023-03-17 [2] CRAN (R 4.2.3)
#> httpuv 1.6.11 2023-05-11 [1] CRAN (R 4.2.3)
#> igraph 1.4.3 2023-05-22 [1] CRAN (R 4.2.3)
#> infer * 1.0.4 2022-12-02 [2] CRAN (R 4.2.3)
#> inline 0.3.19 2021-05-31 [1] CRAN (R 4.2.3)
#> ipred 0.9-14 2023-03-09 [2] CRAN (R 4.2.3)
#> iterators 1.0.14 2022-02-05 [2] CRAN (R 4.2.3)
#> jsonlite 1.8.4 2022-12-06 [1] CRAN (R 4.2.2)
#> knitr 1.42 2023-01-25 [2] CRAN (R 4.2.3)
#> later 1.3.1 2023-05-02 [1] CRAN (R 4.2.3)
#> lattice 0.20-45 2021-09-22 [2] CRAN (R 4.2.3)
#> lava 1.7.2.1 2023-02-27 [2] CRAN (R 4.2.3)
#> lhs 1.1.6 2022-12-17 [2] CRAN (R 4.2.3)
#> lifecycle 1.0.3 2022-10-07 [2] CRAN (R 4.2.3)
#> listenv 0.9.0 2022-12-16 [2] CRAN (R 4.2.3)
#> lme4 1.1-33 2023-04-25 [1] CRAN (R 4.2.3)
#> loo 2.6.0 2023-03-31 [1] CRAN (R 4.2.3)
#> lubridate 1.9.2 2023-02-10 [1] CRAN (R 4.2.2)
#> magrittr 2.0.3 2022-03-30 [2] CRAN (R 4.2.3)
#> markdown 1.7 2023-05-16 [1] CRAN (R 4.2.3)
#> MASS 7.3-58.2 2023-01-23 [2] CRAN (R 4.2.3)
#> Matrix 1.5-3 2022-11-11 [1] CRAN (R 4.2.2)
#> matrixStats 0.63.0 2022-11-18 [1] CRAN (R 4.2.3)
#> memoise 2.0.1 2021-11-26 [2] CRAN (R 4.2.3)
#> mime 0.12 2021-09-28 [2] CRAN (R 4.2.0)
#> miniUI 0.1.1.1 2018-05-18 [2] CRAN (R 4.2.3)
#> minqa 1.2.5 2022-10-19 [2] CRAN (R 4.2.3)
#> modeldata * 1.1.0 2023-01-25 [2] CRAN (R 4.2.3)
#> munsell 0.5.0 2018-06-12 [2] CRAN (R 4.2.3)
#> nlme 3.1-162 2023-01-31 [2] CRAN (R 4.2.3)
#> nloptr 2.0.3 2022-05-26 [2] CRAN (R 4.2.3)
#> nnet 7.3-18 2022-09-28 [2] CRAN (R 4.2.3)
#> parallelly 1.35.0 2023-03-23 [2] CRAN (R 4.2.3)
#> parsnip * 1.1.0 2023-04-12 [2] CRAN (R 4.2.3)
#> pillar 1.9.0 2023-03-22 [2] CRAN (R 4.2.3)
#> pkgbuild 1.4.0 2022-11-27 [2] CRAN (R 4.2.3)
#> pkgconfig 2.0.3 2019-09-22 [2] CRAN (R 4.2.3)
#> plyr 1.8.8 2022-11-11 [2] CRAN (R 4.2.3)
#> prettyunits 1.1.1 2020-01-24 [2] CRAN (R 4.2.3)
#> processx 3.8.1 2023-04-18 [1] CRAN (R 4.2.3)
#> prodlim 2023.03.31 2023-04-02 [2] CRAN (R 4.2.3)
#> promises 1.2.0.1 2021-02-11 [2] CRAN (R 4.2.3)
#> ps 1.7.5 2023-04-18 [1] CRAN (R 4.2.3)
#> purrr * 1.0.1 2023-01-10 [1] CRAN (R 4.2.2)
#> R6 2.5.1 2021-08-19 [2] CRAN (R 4.2.3)
#> Rcpp * 1.0.10 2023-01-22 [1] CRAN (R 4.2.2)
#> D RcppParallel 5.1.7 2023-02-27 [1] CRAN (R 4.2.3)
#> recipes * 1.0.6 2023-04-25 [2] CRAN (R 4.2.3)
#> reprex 2.0.2 2022-08-17 [1] CRAN (R 4.2.3)
#> reshape2 1.4.4 2020-04-09 [1] CRAN (R 4.2.3)
#> rlang 1.1.1 2023-04-28 [1] CRAN (R 4.2.3)
#> rmarkdown 2.21 2023-03-26 [2] CRAN (R 4.2.3)
#> rpart 4.1.19 2022-10-21 [2] CRAN (R 4.2.3)
#> rsample * 1.1.1 2022-12-07 [2] CRAN (R 4.2.3)
#> rstan 2.26.15 2023-02-11 [1] local
#> rstanarm * 2.21.4 2023-04-08 [1] CRAN (R 4.2.3)
#> rstantools 2.3.1 2023-03-30 [1] CRAN (R 4.2.3)
#> rstudioapi 0.14 2022-08-22 [2] CRAN (R 4.2.3)
#> scales * 1.2.1 2022-08-20 [2] CRAN (R 4.2.3)
#> sessioninfo 1.2.2 2021-12-06 [2] CRAN (R 4.2.3)
#> shiny 1.7.4 2022-12-15 [2] CRAN (R 4.2.3)
#> shinyjs 2.1.0 2021-12-23 [1] CRAN (R 4.2.3)
#> shinystan 2.6.0 2022-03-03 [1] CRAN (R 4.2.3)
#> shinythemes 1.2.0 2021-01-25 [1] CRAN (R 4.2.3)
#> StanHeaders 2.26.15 2023-02-11 [1] local
#> stringi 1.7.12 2023-01-11 [1] CRAN (R 4.2.2)
#> stringr 1.5.0 2022-12-02 [1] CRAN (R 4.2.2)
#> survival 3.5-3 2023-02-12 [2] CRAN (R 4.2.3)
#> threejs 0.3.3 2020-01-21 [1] CRAN (R 4.2.3)
#> tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.2.3)
#> tidymodels * 1.1.0 2023-05-01 [1] CRAN (R 4.2.3)
#> tidyr * 1.3.0 2023-01-24 [1] CRAN (R 4.2.2)
#> tidyselect 1.2.0 2022-10-10 [2] CRAN (R 4.2.3)
#> timechange 0.2.0 2023-01-11 [1] CRAN (R 4.2.2)
#> timeDate 4022.108 2023-01-07 [2] CRAN (R 4.2.3)
#> tune * 1.1.1 2023-04-11 [2] CRAN (R 4.2.3)
#> utf8 1.2.3 2023-01-31 [1] CRAN (R 4.2.2)
#> V8 4.3.0 2023-04-08 [1] CRAN (R 4.2.3)
#> vctrs 0.6.2 2023-04-19 [1] CRAN (R 4.2.3)
#> withr 2.5.0 2022-03-03 [2] CRAN (R 4.2.3)
#> workflows * 1.1.3 2023-02-22 [2] CRAN (R 4.2.3)
#> workflowsets * 1.0.1 2023-04-06 [2] CRAN (R 4.2.3)
#> xfun 0.39 2023-04-20 [1] CRAN (R 4.2.3)
#> xtable 1.8-4 2019-04-21 [2] CRAN (R 4.2.3)
#> xts 0.13.1 2023-04-16 [1] CRAN (R 4.2.3)
#> yaml 2.3.7 2023-01-23 [2] CRAN (R 4.2.3)
#> yardstick * 1.2.0 2023-04-21 [2] CRAN (R 4.2.3)
#> zoo 1.8-12 2023-04-13 [1] CRAN (R 4.2.3)
#>
#> [1] C:/Users/00055815/AppData/Local/R/win-library/4.2
#> [2] C:/Program Files/R/R-4.2.3/library
#>
#> D ── DLL MD5 mismatch, broken installation.
#>
#> ──────────────────────────────────────────────────────────────────────────────