Skip to content

Commit d21bd75

Browse files
committed
1 parent 35ffae1 commit d21bd75

File tree

2 files changed

+49
-56
lines changed

2 files changed

+49
-56
lines changed

tests/testthat/_snaps/ipcw.md

Lines changed: 0 additions & 7 deletions
This file was deleted.

tests/testthat/test-ipcw.R

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,32 @@
11

22
test_that('calculate weight time', {
3-
skip_if_not_installed("parsnip", minimum_version = "1.0.4.9003")
3+
skip_if_not_installed("parsnip", minimum_version = "1.0.4.9006")
4+
skip_if_not_installed("censored", minimum_version = "0.1.1.9002")
45

5-
library(survival)
6+
library(tidymodels)
7+
library(censored)
68

79
times <- 1:10
810
cens <- rep(0:1, times = 5)
911

1012
surv_obj <- Surv(times, cens)
13+
n <- length(surv_obj)
1114

12-
eval_0 <- parsnip:::graf_weight_time(surv_obj, eval_time = 0)
13-
eval_05 <- parsnip:::graf_weight_time(surv_obj, eval_time = 5, eps = 1)
14-
eval_11 <- parsnip:::graf_weight_time(surv_obj, eval_time = 11, rows = 11:20, eps = 0)
15+
eval_0 <- parsnip:::graf_weight_time_vec(surv_obj, eval_time = rep(0, n))
16+
eval_05 <- parsnip:::graf_weight_time_vec(surv_obj, eval_time = rep(5, n), eps = 1)
17+
eval_11 <- parsnip:::graf_weight_time_vec(surv_obj, eval_time = rep(11, n), eps = 0)
1518

16-
na_05 <- is.na(eval_05$weight_time)
17-
na_11 <- is.na(eval_11$weight_time)
19+
na_05 <- is.na(eval_05)
20+
na_11 <- is.na(eval_11)
1821

19-
expect_equal(eval_0$weight_time, rep(0, 10))
20-
expect_equal(eval_0$.row, 1:10)
22+
expect_equal(eval_0, rep(0, 10))
2123

2224
expect_equal(
2325
which(na_05),
2426
which(times <= 5 & cens == 0)
2527
)
2628
expect_equal(
27-
eval_05$weight_time[!na_05],
29+
eval_05[!na_05],
2830
ifelse(times[!na_05] - 1 < 5, times[!na_05] - 1, 4)
2931
)
3032

@@ -33,62 +35,60 @@ test_that('calculate weight time', {
3335
which(cens == 0)
3436
)
3537
expect_equal(
36-
eval_11$weight_time[!na_11],
37-
(1:5) * 2
38+
eval_11[!na_11],
39+
seq(2, 10, by = 2)
3840
)
39-
expect_equal(eval_11$.row, 11:20)
4041

4142
})
4243

4344
test_that('compute Graf weights', {
44-
skip_if_not_installed("parsnip", minimum_version = "1.0.4.9003")
45+
skip_if_not_installed("parsnip", minimum_version = "1.0.4.9006")
46+
skip_if_not_installed("censored", minimum_version = "0.1.1.9002")
4547

46-
library(parsnip)
47-
library(survival)
48+
library(tidymodels)
4849
library(censored)
49-
library(workflows)
50-
library(dplyr)
5150

52-
times <- 1:10
53-
cens <- c(0, rep(1, 9))
51+
times <- c(9, 1:9)
52+
cens <- rep(0:1, 5)
5453
surv_obj <- Surv(times, cens)
54+
n <- length(surv_obj)
5555
df <- data.frame(surv = surv_obj, x = -1:8)
5656
fit <- survival_reg() %>% fit(surv ~ x, data = df)
5757
wflow_fit <-
5858
workflow() %>%
5959
add_model(survival_reg(), formula = surv ~ x) %>%
6060
add_variables(surv, x) %>%
6161
fit(data = df)
62+
mod_fit <- extract_fit_parsnip(wflow_fit)
63+
64+
eval_times <- c(5, 1:4)
65+
66+
pred_surv <-
67+
predict(mod_fit, df, type = "survival", eval_time = eval_times) %>%
68+
bind_cols(
69+
predict(mod_fit, df, type = "time"),
70+
df
71+
) %>%
72+
slice(5)
73+
74+
wt_times <-
75+
parsnip:::graf_weight_time_vec(pred_surv$surv,
76+
eval_time = pred_surv$.pred[[1]]$.eval_time)
77+
expect_equal(wt_times, c(NA, 0.9999999999, 1.9999999999, 2.9999999999, NA), tolerance = 0.01)
78+
79+
cens_probs <- predict(fit$censor_probs, time = wt_times, as_vector = TRUE)
80+
81+
wts <- .censoring_weights_graf(fit, pred_surv)
82+
expect_equal(names(wts), names(pred_surv))
83+
expect_equal(nrow(wts), nrow(pred_surv))
84+
expect_equal(dim(wts$.pred[[1]]), c(length(eval_times), 5))
85+
expect_equal(wts$.pred[[1]]$.eval_time, eval_times)
86+
expect_equal(
87+
names(wts$.pred[[1]]),
88+
c(".eval_time", ".pred_survival", ".weight_time", ".pred_censored", ".weight_censored"))
6289

63-
eval_0 <- parsnip:::graf_weight_time(surv_obj, eval_time = 0)
64-
eval_05 <- parsnip:::graf_weight_time(surv_obj, eval_time = 5, eps = 1)
65-
eval_11 <- parsnip:::graf_weight_time(surv_obj, eval_time = 11, rows = 11:20, eps = 0)
66-
67-
cens_prob_00 <- predict(fit$censor_probs, time = eval_0$weight_time, as_vector = TRUE)
68-
cens_prob_05 <- predict(fit$censor_probs, time = eval_05$weight_time, as_vector = TRUE)
69-
cens_prob_11 <- predict(fit$censor_probs, time = eval_11$weight_time, as_vector = TRUE)
70-
71-
wts_00 <- .censoring_weights_graf(fit, df, 0)
72-
wts_05 <- .censoring_weights_graf(fit, df, 5)
73-
wts_11 <- .censoring_weights_graf(fit, df, 11)
74-
75-
wflow_wts_00 <- .censoring_weights_graf(wflow_fit, df, 0)
76-
wflow_wts_05 <- .censoring_weights_graf(wflow_fit, df, 5)
77-
wflow_wts_11 <- .censoring_weights_graf(wflow_fit, df, 11)
78-
79-
expect_equal(wts_00$.weight_cens, 1 / cens_prob_00)
80-
expect_equal(wts_05$.weight_cens, 1 / cens_prob_05)
81-
expect_equal(wts_11$.weight_cens, 1 / cens_prob_11)
82-
83-
expect_equal(wflow_wts_00$.weight_cens, 1 / cens_prob_00)
84-
expect_equal(wflow_wts_05$.weight_cens, 1 / cens_prob_05)
85-
expect_equal(wflow_wts_11$.weight_cens, 1 / cens_prob_11)
86-
87-
expect_true(inherits(wts_00, "data.frame"))
88-
expect_equal(names(wts_00), c(".row", "eval_time", ".prob_cens", ".weight_cens"))
89-
expect_equal(nrow(wts_00), nrow(df))
90-
91-
expect_snapshot(.censoring_weights_graf(2, df, 0), error = TRUE)
90+
wts2 <- wts %>% unnest(.pred)
91+
expect_equal(wts2$.weight_censored, 1 / cens_probs)
9292

9393
})
9494

0 commit comments

Comments
 (0)