Skip to content

bart model predictions do not provide the standard error for the predictions #976

Closed
@jdberson

Description

@jdberson

The problem

I'm having trouble getting the standard error of the predictions from a parsnip::bart() model in the same way that I can for a stan model.

I think this should be possible and is a bug rather than a feature request.

Thanks!

Reproducible example

# Load packages and prepare data ------------------------------------------

# Code adapted from: https://parsnip.tidymodels.org/articles/Examples.html

library(tidymodels)
# dbarts is the engine behind parsnip::bart(). As its startup message below
# shows, attaching it masks tidyr::extract() and parsnip::bart().
library(dbarts)
#> 
#> Attaching package: 'dbarts'
#> The following object is masked from 'package:tidyr':
#> 
#>     extract
#> The following object is masked from 'package:parsnip':
#> 
#>     bart
# rstanarm provides the "stan" engine used for the logistic regression
# comparison model further down.
library(rstanarm)
#> Loading required package: Rcpp
#> 
#> Attaching package: 'Rcpp'
#> The following object is masked from 'package:rsample':
#> 
#>     populate
#> This is rstanarm version 2.21.4
#> - See https://mc-stan.org/rstanarm/articles/priors for changes to default priors!
#> - Default priors may change, so it's safest to specify priors, even if equivalent to the defaults.
#> - For execution on a local, multicore CPU with excess RAM we recommend calling
#>   options(mc.cores = parallel::detectCores())

# Declare the tidymodels packages as preferred for the name conflicts reported
# above (e.g. bart(), extract(), populate()).
tidymodels_prefer()

# Data: the first 10 rows of two_class_dat form a small test set and all
# remaining rows are used for training. head() / tail(n = -10) select exactly
# the same rows (and keep the same row names) as positive / negative indexing.
data(two_class_dat)
data_test  <- head(two_class_dat, 10)
data_train <- tail(two_class_dat, -10)


# BART model --------------------------------------------------------------

# Example showing that the BART model does not return a standard error for its
# predictions.

# BART model specification and fit. parsnip::bart() is namespaced explicitly
# because dbarts also exports a bart() function.
set.seed(1)
bt_cls_fit <- 
  parsnip::bart() %>% 
  set_mode("classification") %>% 
  set_engine("dbarts") %>% 
  fit(Class ~ ., data = data_train)

# Make predictions. std_error = TRUE is passed together with
# type = "pred_int", but once select() drops the interval bounds only the two
# class-probability columns remain: there is no .std_error column.
# The seed is fixed so the printed values are reproducible.
set.seed(2)
bind_cols(
  predict(bt_cls_fit, data_test, type = "prob"),
  predict(bt_cls_fit, data_test, type = "pred_int", std_error = TRUE)
) %>%
  select(-contains(c("lower", "upper")))
#> # A tibble: 10 × 2
#>    .pred_Class1 .pred_Class2
#>           <dbl>        <dbl>
#>  1       0.344         0.656
#>  2       0.82          0.18 
#>  3       0.562         0.438
#>  4       0.608         0.392
#>  5       0.438         0.562
#>  6       0.234         0.766
#>  7       0.632         0.368
#>  8       0.448         0.552
#>  9       0.971         0.029
#> 10       0.0780        0.922

# Check whether the dbarts fit object contains the information needed to
# calculate a standard error for each prediction (i.e. whether calling
# predict() with std_error = TRUE could return one). It looks like the
# information is there, but I'm not 100% sure that the calculation below is
# correct.

# Extract the underlying dbarts fit from the parsnip model
bt_cls_eng <- 
  extract_fit_engine(bt_cls_fit)

# Make predictions with the engine's own predict() method, using the same seed
# as the parsnip prediction above.
set.seed(2)
bt_cls_pred <- 
  predict(bt_cls_eng, data_test)

# Summarise the posterior predictions for each observation. The
# apply(..., 2, ...) calls work column-wise, giving one value per row of
# data_test.
#
# The column means appear to be probabilities of the second class level:
# mapping them to .pred_Class2 reproduces parsnip's predictions above to within
# Monte Carlo error. They are computed once and reused for .pred_Class1.
# Column names follow parsnip's `.pred_{level}` convention so the two tibbles
# line up for comparison.
#
# NOTE(review): this .std_error is the posterior SD of the predicted
# probability. The rstanarm .std_error values further below are roughly
# sqrt(p * (1 - p)) (e.g. 0.909 -> 0.287), which suggests they describe 0/1
# posterior predictive draws. The two are therefore not the same quantity.
# dbarts' predict(..., type = "ppd") may give the comparable draws -- TODO
# confirm.
bt_cls_prob2 <- apply(bt_cls_pred, 2, base::mean, na.rm = TRUE)
tibble(
  .pred_Class1 = 1 - bt_cls_prob2,
  .pred_Class2 = bt_cls_prob2,
  .std_error = apply(bt_cls_pred, 2, stats::sd, na.rm = TRUE)
)
#> # A tibble: 10 × 3
#>    .pred_Class1 .pred_Class2 .std_error
#>           <dbl>        <dbl>      <dbl>
#>  1       0.335        0.665      0.0947
#>  2       0.830        0.170      0.0743
#>  3       0.586        0.414      0.0969
#>  4       0.606        0.394      0.123 
#>  5       0.434        0.566      0.104 
#>  6       0.231        0.769      0.0876
#>  7       0.649        0.351      0.111 
#>  8       0.448        0.552      0.110 
#>  9       0.977        0.0228     0.0265
#> 10       0.0785       0.922      0.0486


# Stan model --------------------------------------------------------------

# Example showing that a Stan model does return the standard error for its
# predictions (what I'm hoping the bart model can provide).

# Stan model specification and fit
set.seed(1)
logreg_cls_fit <- 
  logistic_reg() %>% 
  set_engine("stan") %>% 
  fit(Class ~ ., data = data_train)

# Make predictions - output includes the .std_error column.
# NOTE(review): the .std_error values printed below are approximately
# sqrt(p * (1 - p)) for the predicted probability p (e.g. 0.909 -> 0.287,
# 0.518 -> 0.500). That suggests they measure the spread of 0/1 posterior
# predictive draws rather than the uncertainty in p itself. Confirm against
# parsnip's stan prediction code before comparing with a posterior SD of p.
bind_cols(
  predict(logreg_cls_fit, data_test, type = "prob"),
  predict(logreg_cls_fit, data_test, type = "pred_int", std_error = TRUE)
) %>%
  select(-contains(c("lower", "upper")))
#> # A tibble: 10 × 3
#>    .pred_Class1 .pred_Class2 .std_error
#>           <dbl>        <dbl>      <dbl>
#>  1        0.518      0.482       0.500 
#>  2        0.909      0.0909      0.287 
#>  3        0.650      0.350       0.474 
#>  4        0.609      0.391       0.491 
#>  5        0.443      0.557       0.497 
#>  6        0.206      0.794       0.402 
#>  7        0.708      0.292       0.454 
#>  8        0.568      0.432       0.497 
#>  9        0.994      0.00580     0.0834
#> 10        0.108      0.892       0.313

Created on 2023-05-29 with reprex v2.0.2

Session info
# Record the R version and package versions used to produce this reprex.
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.3 (2023-03-15 ucrt)
#>  os       Windows 10 x64 (build 19042)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_Australia.utf8
#>  ctype    English_Australia.utf8
#>  tz       Australia/Perth
#>  date     2023-05-29
#>  pandoc   2.19.2 @ C:/program files/rstudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package      * version    date (UTC) lib source
#>    backports      1.4.1      2021-12-13 [2] CRAN (R 4.2.0)
#>    base64enc      0.1-3      2015-07-28 [2] CRAN (R 4.2.0)
#>    bayesplot      1.10.0     2022-11-16 [1] CRAN (R 4.2.3)
#>    boot           1.3-28.1   2022-11-22 [2] CRAN (R 4.2.3)
#>    broom        * 1.0.4      2023-03-11 [2] CRAN (R 4.2.3)
#>    cachem         1.0.8      2023-05-01 [1] CRAN (R 4.2.3)
#>    callr          3.7.3      2022-11-02 [2] CRAN (R 4.2.3)
#>    class          7.3-21     2023-01-23 [2] CRAN (R 4.2.3)
#>    cli            3.6.1      2023-03-23 [1] CRAN (R 4.2.3)
#>    codetools      0.2-19     2023-02-01 [2] CRAN (R 4.2.3)
#>    colorspace     2.1-0      2023-01-23 [2] CRAN (R 4.2.3)
#>    colourpicker   1.2.0      2022-10-28 [1] CRAN (R 4.2.3)
#>    conflicted     1.2.0      2023-02-01 [2] CRAN (R 4.2.3)
#>    crayon         1.5.2      2022-09-29 [2] CRAN (R 4.2.3)
#>    crosstalk      1.2.0      2021-11-04 [2] CRAN (R 4.2.3)
#>    curl           5.0.0      2023-01-12 [2] CRAN (R 4.2.3)
#>    data.table     1.14.8     2023-02-17 [2] CRAN (R 4.2.3)
#>    dbarts       * 0.9-23     2023-01-23 [1] CRAN (R 4.2.3)
#>    dials        * 1.2.0      2023-04-03 [2] CRAN (R 4.2.3)
#>    DiceDesign     1.9        2021-02-13 [2] CRAN (R 4.2.3)
#>    digest         0.6.31     2022-12-11 [2] CRAN (R 4.2.3)
#>    dplyr        * 1.1.2      2023-04-20 [1] CRAN (R 4.2.3)
#>    DT             0.27       2023-01-17 [1] CRAN (R 4.2.3)
#>    dygraphs       1.1.1.6    2018-07-11 [1] CRAN (R 4.2.3)
#>    ellipsis       0.3.2      2021-04-29 [2] CRAN (R 4.2.3)
#>    evaluate       0.21       2023-05-05 [1] CRAN (R 4.2.3)
#>    fansi          1.0.4      2023-01-22 [1] CRAN (R 4.2.2)
#>    fastmap        1.1.1      2023-02-24 [2] CRAN (R 4.2.3)
#>    foreach        1.5.2      2022-02-02 [2] CRAN (R 4.2.3)
#>    fs             1.6.2      2023-04-25 [1] CRAN (R 4.2.3)
#>    furrr          0.3.1      2022-08-15 [2] CRAN (R 4.2.3)
#>    future         1.32.0     2023-03-07 [2] CRAN (R 4.2.3)
#>    future.apply   1.10.0     2022-11-05 [2] CRAN (R 4.2.3)
#>    generics       0.1.3      2022-07-05 [2] CRAN (R 4.2.3)
#>    ggplot2      * 3.4.2      2023-04-03 [1] CRAN (R 4.2.3)
#>    globals        0.16.2     2022-11-21 [2] CRAN (R 4.2.2)
#>    glue           1.6.2      2022-02-24 [2] CRAN (R 4.2.3)
#>    gower          1.0.1      2022-12-22 [2] CRAN (R 4.2.2)
#>    GPfit          1.0-8      2019-02-08 [2] CRAN (R 4.2.3)
#>    gridExtra      2.3        2017-09-09 [2] CRAN (R 4.2.3)
#>    gtable         0.3.3      2023-03-21 [2] CRAN (R 4.2.3)
#>    gtools         3.9.4      2022-11-27 [1] CRAN (R 4.2.3)
#>    hardhat        1.3.0      2023-03-30 [2] CRAN (R 4.2.3)
#>    htmltools      0.5.5      2023-03-23 [2] CRAN (R 4.2.3)
#>    htmlwidgets    1.6.2      2023-03-17 [2] CRAN (R 4.2.3)
#>    httpuv         1.6.11     2023-05-11 [1] CRAN (R 4.2.3)
#>    igraph         1.4.3      2023-05-22 [1] CRAN (R 4.2.3)
#>    infer        * 1.0.4      2022-12-02 [2] CRAN (R 4.2.3)
#>    inline         0.3.19     2021-05-31 [1] CRAN (R 4.2.3)
#>    ipred          0.9-14     2023-03-09 [2] CRAN (R 4.2.3)
#>    iterators      1.0.14     2022-02-05 [2] CRAN (R 4.2.3)
#>    jsonlite       1.8.4      2022-12-06 [1] CRAN (R 4.2.2)
#>    knitr          1.42       2023-01-25 [2] CRAN (R 4.2.3)
#>    later          1.3.1      2023-05-02 [1] CRAN (R 4.2.3)
#>    lattice        0.20-45    2021-09-22 [2] CRAN (R 4.2.3)
#>    lava           1.7.2.1    2023-02-27 [2] CRAN (R 4.2.3)
#>    lhs            1.1.6      2022-12-17 [2] CRAN (R 4.2.3)
#>    lifecycle      1.0.3      2022-10-07 [2] CRAN (R 4.2.3)
#>    listenv        0.9.0      2022-12-16 [2] CRAN (R 4.2.3)
#>    lme4           1.1-33     2023-04-25 [1] CRAN (R 4.2.3)
#>    loo            2.6.0      2023-03-31 [1] CRAN (R 4.2.3)
#>    lubridate      1.9.2      2023-02-10 [1] CRAN (R 4.2.2)
#>    magrittr       2.0.3      2022-03-30 [2] CRAN (R 4.2.3)
#>    markdown       1.7        2023-05-16 [1] CRAN (R 4.2.3)
#>    MASS           7.3-58.2   2023-01-23 [2] CRAN (R 4.2.3)
#>    Matrix         1.5-3      2022-11-11 [1] CRAN (R 4.2.2)
#>    matrixStats    0.63.0     2022-11-18 [1] CRAN (R 4.2.3)
#>    memoise        2.0.1      2021-11-26 [2] CRAN (R 4.2.3)
#>    mime           0.12       2021-09-28 [2] CRAN (R 4.2.0)
#>    miniUI         0.1.1.1    2018-05-18 [2] CRAN (R 4.2.3)
#>    minqa          1.2.5      2022-10-19 [2] CRAN (R 4.2.3)
#>    modeldata    * 1.1.0      2023-01-25 [2] CRAN (R 4.2.3)
#>    munsell        0.5.0      2018-06-12 [2] CRAN (R 4.2.3)
#>    nlme           3.1-162    2023-01-31 [2] CRAN (R 4.2.3)
#>    nloptr         2.0.3      2022-05-26 [2] CRAN (R 4.2.3)
#>    nnet           7.3-18     2022-09-28 [2] CRAN (R 4.2.3)
#>    parallelly     1.35.0     2023-03-23 [2] CRAN (R 4.2.3)
#>    parsnip      * 1.1.0      2023-04-12 [2] CRAN (R 4.2.3)
#>    pillar         1.9.0      2023-03-22 [2] CRAN (R 4.2.3)
#>    pkgbuild       1.4.0      2022-11-27 [2] CRAN (R 4.2.3)
#>    pkgconfig      2.0.3      2019-09-22 [2] CRAN (R 4.2.3)
#>    plyr           1.8.8      2022-11-11 [2] CRAN (R 4.2.3)
#>    prettyunits    1.1.1      2020-01-24 [2] CRAN (R 4.2.3)
#>    processx       3.8.1      2023-04-18 [1] CRAN (R 4.2.3)
#>    prodlim        2023.03.31 2023-04-02 [2] CRAN (R 4.2.3)
#>    promises       1.2.0.1    2021-02-11 [2] CRAN (R 4.2.3)
#>    ps             1.7.5      2023-04-18 [1] CRAN (R 4.2.3)
#>    purrr        * 1.0.1      2023-01-10 [1] CRAN (R 4.2.2)
#>    R6             2.5.1      2021-08-19 [2] CRAN (R 4.2.3)
#>    Rcpp         * 1.0.10     2023-01-22 [1] CRAN (R 4.2.2)
#>  D RcppParallel   5.1.7      2023-02-27 [1] CRAN (R 4.2.3)
#>    recipes      * 1.0.6      2023-04-25 [2] CRAN (R 4.2.3)
#>    reprex         2.0.2      2022-08-17 [1] CRAN (R 4.2.3)
#>    reshape2       1.4.4      2020-04-09 [1] CRAN (R 4.2.3)
#>    rlang          1.1.1      2023-04-28 [1] CRAN (R 4.2.3)
#>    rmarkdown      2.21       2023-03-26 [2] CRAN (R 4.2.3)
#>    rpart          4.1.19     2022-10-21 [2] CRAN (R 4.2.3)
#>    rsample      * 1.1.1      2022-12-07 [2] CRAN (R 4.2.3)
#>    rstan          2.26.15    2023-02-11 [1] local
#>    rstanarm     * 2.21.4     2023-04-08 [1] CRAN (R 4.2.3)
#>    rstantools     2.3.1      2023-03-30 [1] CRAN (R 4.2.3)
#>    rstudioapi     0.14       2022-08-22 [2] CRAN (R 4.2.3)
#>    scales       * 1.2.1      2022-08-20 [2] CRAN (R 4.2.3)
#>    sessioninfo    1.2.2      2021-12-06 [2] CRAN (R 4.2.3)
#>    shiny          1.7.4      2022-12-15 [2] CRAN (R 4.2.3)
#>    shinyjs        2.1.0      2021-12-23 [1] CRAN (R 4.2.3)
#>    shinystan      2.6.0      2022-03-03 [1] CRAN (R 4.2.3)
#>    shinythemes    1.2.0      2021-01-25 [1] CRAN (R 4.2.3)
#>    StanHeaders    2.26.15    2023-02-11 [1] local
#>    stringi        1.7.12     2023-01-11 [1] CRAN (R 4.2.2)
#>    stringr        1.5.0      2022-12-02 [1] CRAN (R 4.2.2)
#>    survival       3.5-3      2023-02-12 [2] CRAN (R 4.2.3)
#>    threejs        0.3.3      2020-01-21 [1] CRAN (R 4.2.3)
#>    tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.2.3)
#>    tidymodels   * 1.1.0      2023-05-01 [1] CRAN (R 4.2.3)
#>    tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.2.2)
#>    tidyselect     1.2.0      2022-10-10 [2] CRAN (R 4.2.3)
#>    timechange     0.2.0      2023-01-11 [1] CRAN (R 4.2.2)
#>    timeDate       4022.108   2023-01-07 [2] CRAN (R 4.2.3)
#>    tune         * 1.1.1      2023-04-11 [2] CRAN (R 4.2.3)
#>    utf8           1.2.3      2023-01-31 [1] CRAN (R 4.2.2)
#>    V8             4.3.0      2023-04-08 [1] CRAN (R 4.2.3)
#>    vctrs          0.6.2      2023-04-19 [1] CRAN (R 4.2.3)
#>    withr          2.5.0      2022-03-03 [2] CRAN (R 4.2.3)
#>    workflows    * 1.1.3      2023-02-22 [2] CRAN (R 4.2.3)
#>    workflowsets * 1.0.1      2023-04-06 [2] CRAN (R 4.2.3)
#>    xfun           0.39       2023-04-20 [1] CRAN (R 4.2.3)
#>    xtable         1.8-4      2019-04-21 [2] CRAN (R 4.2.3)
#>    xts            0.13.1     2023-04-16 [1] CRAN (R 4.2.3)
#>    yaml           2.3.7      2023-01-23 [2] CRAN (R 4.2.3)
#>    yardstick    * 1.2.0      2023-04-21 [2] CRAN (R 4.2.3)
#>    zoo            1.8-12     2023-04-13 [1] CRAN (R 4.2.3)
#> 
#>  [1] C:/Users/00055815/AppData/Local/R/win-library/4.2
#>  [2] C:/Program Files/R/R-4.2.3/library
#> 
#>  D ── DLL MD5 mismatch, broken installation.
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Metadata

Metadata

Assignees

No one assigned

    Labels

    feature — a feature request or enhancement

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions