Skip to content

Commit 3bf1da9

Browse files
authored
check for mode first (#972)
because regression and classification models also have `$censor_probs` even though its just an empty list
1 parent d104fa1 commit 3bf1da9

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.1.0.9002
3+
Version: 1.1.0.9003
44
Authors@R: c(
55
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),

R/survival-censoring-weights.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ trunc_probs <- function(probs, trunc = 0.01) {
6363
}
6464

6565
.check_censor_model <- function(x) {
66+
if (x$spec$mode != "censored regression") {
67+
cli::cli_abort(
68+
"The model needs to be for mode 'censored regression', not for mode '{x$spec$mode}'."
69+
)
70+
}
6671
nms <- names(x)
6772
if (!any(nms == "censor_probs")) {
6873
rlang::abort("Please refit the model with parsnip version 1.0.4 or greater.")
@@ -245,14 +250,17 @@ add_graf_weights_vec <- function(object, .pred, surv_obj, trunc = 0.05, eps = 10
245250
num_times <- vctrs::list_sizes(.pred)
246251
y <- vctrs::list_unchop(.pred)
247252
y$surv_obj <- vctrs::vec_rep_each(surv_obj, times = num_times)
253+
248254
names(y)[names(y) == ".time"] <- ".eval_time" # Temporary
255+
249256
# Compute the actual time of evaluation
250257
y$.weight_time <- graf_weight_time_vec(y$surv_obj, y$.eval_time, eps = eps)
251258
# Compute the corresponding probability of being censored
252259
y$.pred_censored <- predict(object$censor_probs, time = y$.weight_time, as_vector = TRUE)
253260
y$.pred_censored <- trunc_probs(y$.pred_censored, trunc = trunc)
254261
# Invert the probabilities to create weights
255262
y$.weight_censored = 1 / y$.pred_censored
263+
256264
# Convert back the list column format
257265
y$surv_obj <- NULL
258266
vctrs::vec_chop(y, sizes = num_times)

0 commit comments

Comments
 (0)