Skip to content

Commit b64f30f

Browse files
authored
time to eval_time (#936)
* reformat * rename `time` to `eval_time`
1 parent 34f74fe commit b64f30f

File tree

6 files changed

+101
-60
lines changed

6 files changed

+101
-60
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.0.4.9004
3+
Version: 1.0.4.9005
44
Authors@R: c(
55
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),

R/predict.R

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
#' linear predictors). Default value is `FALSE`.
3232
#' \item `quantile`: for `type` equal to `quantile`, the quantiles of the
3333
#' distribution. Default is `(1:9)/10`.
34-
#' \item `time`: for `type` equal to `"survival"` or `"hazard"`, the
34+
#' \item `eval_time`: for `type` equal to `"survival"` or `"hazard"`, the
3535
#' time points at which the survival probability or hazard is estimated.
3636
#' }
3737
#' @details For `type = NULL`, `predict()` uses
@@ -48,7 +48,7 @@
4848
#'
4949
#' ## Censored regression predictions
5050
#'
51-
#' For censored regression, a numeric vector for `time` is required when
51+
#' For censored regression, a numeric vector for `eval_time` is required when
5252
#' survival or hazard probabilities are requested. Also, when
5353
#' `type = "linear_pred"`, censored regression models will by default be
5454
#' formatted such that the linear predictor _increases_ with time. This may
@@ -83,11 +83,11 @@
8383
#'
8484
#' For `type = "survival"`, the tibble has a `.pred` column, which is
8585
#' a list-column. Each list element contains a tibble with columns
86-
#' `.time` and `.pred_survival` (and perhaps other columns).
86+
#' `.eval_time` and `.pred_survival` (and perhaps other columns).
8787
#'
8888
#' For `type = "hazard"`, the tibble has a `.pred` column, which is
8989
#' a list-column. Each list element contains a tibble with columns
90-
#' `.time` and `.pred_hazard` (and perhaps other columns).
90+
#' `.eval_time` and `.pred_hazard` (and perhaps other columns).
9191
#'
9292
#' Using `type = "raw"` with `predict.model_fit()` will return
9393
#' the unadulterated results of the prediction function.
@@ -334,7 +334,8 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())
334334

335335
# ----------------------------------------------------------------------------
336336

337-
other_args <- c("interval", "level", "std_error", "quantile", "time", "increasing")
337+
other_args <- c("interval", "level", "std_error", "quantile",
338+
"time", "eval_time", "increasing")
338339
is_pred_arg <- names(the_dots) %in% other_args
339340
if (any(!is_pred_arg)) {
340341
bad_args <- names(the_dots)[!is_pred_arg]
@@ -348,7 +349,15 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())
348349
}
349350

350351
# ----------------------------------------------------------------------------
351-
# places where time should not be given
352+
# places where eval_time should not be given
353+
if (any(nms == "eval_time") & !type %in% c("survival", "hazard")) {
354+
rlang::abort(
355+
paste(
356+
"`eval_time` should only be passed to `predict()` when `type` is one of:",
357+
paste0("'", c("survival", "hazard"), "'", collapse = ", ")
358+
)
359+
)
360+
}
352361
if (any(nms == "time") & !type %in% c("survival", "hazard")) {
353362
rlang::abort(
354363
paste(
@@ -357,12 +366,12 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())
357366
)
358367
)
359368
}
360-
# when time should be passed
361-
if (!any(nms == "time") & type %in% c("survival", "hazard")) {
369+
# when eval_time should be passed
370+
if (!any(nms %in% c("eval_time", "time")) & type %in% c("survival", "hazard")) {
362371
rlang::abort(
363372
paste(
364-
"When using 'type' values of 'survival' or 'hazard' are given,",
365-
"a numeric vector 'time' should also be given."
373+
"When using `type` values of 'survival' or 'hazard',",
374+
"a numeric vector `eval_time` should also be given."
366375
)
367376
)
368377
}

R/predict_hazard.R

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,47 @@
44
#' @method predict_hazard model_fit
55
#' @export predict_hazard.model_fit
66
#' @export
7-
predict_hazard.model_fit <-
8-
function(object, new_data, time, ...) {
9-
10-
check_spec_pred_type(object, "hazard")
7+
predict_hazard.model_fit <- function(object,
8+
new_data,
9+
eval_time,
10+
time = deprecated(),
11+
...) {
12+
if (lifecycle::is_present(time)) {
13+
lifecycle::deprecate_warn(
14+
"1.0.4.9005",
15+
"predict_hazard(time)",
16+
"predict_hazard(eval_time)"
17+
)
18+
eval_time <- time
19+
}
1120

12-
if (inherits(object$fit, "try-error")) {
13-
rlang::warn("Model fit failed; cannot make predictions.")
14-
return(NULL)
15-
}
21+
check_spec_pred_type(object, "hazard")
1622

17-
new_data <- prepare_data(object, new_data)
23+
if (inherits(object$fit, "try-error")) {
24+
rlang::warn("Model fit failed; cannot make predictions.")
25+
return(NULL)
26+
}
1827

19-
# preprocess data
20-
if (!is.null(object$spec$method$pred$hazard$pre))
21-
new_data <- object$spec$method$pred$hazard$pre(new_data, object)
28+
new_data <- prepare_data(object, new_data)
2229

23-
# Pass some extra arguments to be used in post-processor
24-
object$spec$method$pred$hazard$args$time <- time
25-
pred_call <- make_pred_call(object$spec$method$pred$hazard)
30+
# preprocess data
31+
if (!is.null(object$spec$method$pred$hazard$pre))
32+
new_data <- object$spec$method$pred$hazard$pre(new_data, object)
2633

27-
res <- eval_tidy(pred_call)
34+
# Pass some extra arguments to be used in post-processor
35+
object$spec$method$pred$hazard$args$eval_time <- eval_time
36+
pred_call <- make_pred_call(object$spec$method$pred$hazard)
2837

29-
# post-process the predictions
30-
if(!is.null(object$spec$method$pred$hazard$post)) {
31-
res <- object$spec$method$pred$hazard$post(res, object)
32-
}
38+
res <- eval_tidy(pred_call)
3339

34-
res
40+
# post-process the predictions
41+
if(!is.null(object$spec$method$pred$hazard$post)) {
42+
res <- object$spec$method$pred$hazard$post(res, object)
3543
}
3644

45+
res
46+
}
47+
3748
# @export
3849
# @keywords internal
3950
# @rdname other_predict

R/predict_survival.R

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,49 @@
44
#' @method predict_survival model_fit
55
#' @export predict_survival.model_fit
66
#' @export
7-
predict_survival.model_fit <-
8-
function(object, new_data, time, interval = "none", level = 0.95, ...) {
9-
10-
check_spec_pred_type(object, "survival")
7+
predict_survival.model_fit <- function(object,
8+
new_data,
9+
eval_time,
10+
time = deprecated(),
11+
interval = "none",
12+
level = 0.95,
13+
...) {
14+
if (lifecycle::is_present(time)) {
15+
lifecycle::deprecate_warn(
16+
"1.0.4.9005",
17+
"predict_survival(time)",
18+
"predict_survival(eval_time)"
19+
)
20+
eval_time <- time
21+
}
1122

12-
if (inherits(object$fit, "try-error")) {
13-
rlang::warn("Model fit failed; cannot make predictions.")
14-
return(NULL)
15-
}
23+
check_spec_pred_type(object, "survival")
1624

17-
new_data <- prepare_data(object, new_data)
25+
if (inherits(object$fit, "try-error")) {
26+
rlang::warn("Model fit failed; cannot make predictions.")
27+
return(NULL)
28+
}
1829

19-
# preprocess data
20-
if (!is.null(object$spec$method$pred$survival$pre))
21-
new_data <- object$spec$method$pred$survival$pre(new_data, object)
30+
new_data <- prepare_data(object, new_data)
2231

23-
# Pass some extra arguments to be used in post-processor
24-
object$spec$method$pred$survival$args$time <- time
25-
pred_call <- make_pred_call(object$spec$method$pred$survival)
32+
# preprocess data
33+
if (!is.null(object$spec$method$pred$survival$pre))
34+
new_data <- object$spec$method$pred$survival$pre(new_data, object)
2635

27-
res <- eval_tidy(pred_call)
36+
# Pass some extra arguments to be used in post-processor
37+
object$spec$method$pred$survival$args$eval_time <- eval_time
38+
pred_call <- make_pred_call(object$spec$method$pred$survival)
2839

29-
# post-process the predictions
30-
if(!is.null(object$spec$method$pred$survival$post)) {
31-
res <- object$spec$method$pred$survival$post(res, object)
32-
}
40+
res <- eval_tidy(pred_call)
3341

34-
res
42+
# post-process the predictions
43+
if(!is.null(object$spec$method$pred$survival$post)) {
44+
res <- object$spec$method$pred$survival$post(res, object)
3545
}
3646

47+
res
48+
}
49+
3750
#' @export
3851
#' @keywords internal
3952
#' @rdname other_predict

man/other_predict.Rd

Lines changed: 11 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/predict.model_fit.Rd

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)