Skip to content

Commit 086ad63

Browse files
authored
add standard error for interval predictions (#978)
1 parent 3bf1da9 commit 086ad63

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
* `augment()` now works for censored regression models.
66

7+
* For BART models with the `dbarts` engine, `predict()` can now also return the standard error for confidence and prediction intervals (#976).
8+
9+
710
# parsnip 1.1.0
811

912
This release of parsnip contains a number of new features and bug fixes, accompanied by several optimizations that substantially decrease the time to `fit()` and `predict()` with the package.

R/bart.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,9 @@ dbart_predict_calc <- function(obj, new_data, type, level = 0.95, std_err = FALS
232232
)
233233
)
234234
}
235+
if (std_err) {
236+
res$.std_error <- apply(post_dist, 2, stats::sd, na.rm = TRUE)
237+
}
235238
}
236239
res
237240
}

R/bart_data.R

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ set_pred(
140140
obj = expr(object),
141141
new_data = expr(new_data),
142142
type = "conf_int",
143-
level = expr(level)
143+
level = expr(level),
144+
std_err = expr(std_error)
144145
)
145146
)
146147
)
@@ -158,7 +159,8 @@ set_pred(
158159
obj = expr(object),
159160
new_data = expr(new_data),
160161
type = "pred_int",
161-
level = expr(level)
162+
level = expr(level),
163+
std_err = expr(std_error)
162164
)
163165
)
164166
)
@@ -215,7 +217,8 @@ set_pred(
215217
obj = expr(object),
216218
new_data = expr(new_data),
217219
type = "conf_int",
218-
level = expr(level)
220+
level = expr(level),
221+
std_err = expr(std_error)
219222
)
220223
)
221224
)
@@ -233,7 +236,8 @@ set_pred(
233236
obj = expr(object),
234237
new_data = expr(new_data),
235238
type = "pred_int",
236-
level = expr(level)
239+
level = expr(level),
240+
std_err = expr(std_error)
237241
)
238242
)
239243
)
@@ -248,7 +252,9 @@ set_pred(
248252
post = NULL,
249253
func = c(pkg = "parsnip", fun = "dbart_predict_calc"),
250254
args =
251-
list(obj = quote(object),
252-
new_data = quote(new_data))
255+
list(
256+
obj = quote(object),
257+
new_data = quote(new_data)
258+
)
253259
)
254260
)

0 commit comments

Comments
 (0)