1
1
2
2
test_that(' calculate weight time' , {
3
- skip_if_not_installed(" parsnip" , minimum_version = " 1.0.4.9003" )
3
+ skip_if_not_installed(" parsnip" , minimum_version = " 1.0.4.9006" )
4
+ skip_if_not_installed(" censored" , minimum_version = " 0.1.1.9002" )
4
5
5
- library(survival )
6
+ library(tidymodels )
7
+ library(censored )
6
8
7
9
times <- 1 : 10
8
10
cens <- rep(0 : 1 , times = 5 )
9
11
10
12
surv_obj <- Surv(times , cens )
13
+ n <- length(surv_obj )
11
14
12
- eval_0 <- parsnip ::: graf_weight_time (surv_obj , eval_time = 0 )
13
- eval_05 <- parsnip ::: graf_weight_time (surv_obj , eval_time = 5 , eps = 1 )
14
- eval_11 <- parsnip ::: graf_weight_time (surv_obj , eval_time = 11 , rows = 11 : 20 , eps = 0 )
15
+ eval_0 <- parsnip ::: graf_weight_time_vec (surv_obj , eval_time = rep( 0 , n ) )
16
+ eval_05 <- parsnip ::: graf_weight_time_vec (surv_obj , eval_time = rep( 5 , n ) , eps = 1 )
17
+ eval_11 <- parsnip ::: graf_weight_time_vec (surv_obj , eval_time = rep( 11 , n ) , eps = 0 )
15
18
16
- na_05 <- is.na(eval_05 $ weight_time )
17
- na_11 <- is.na(eval_11 $ weight_time )
19
+ na_05 <- is.na(eval_05 )
20
+ na_11 <- is.na(eval_11 )
18
21
19
- expect_equal(eval_0 $ weight_time , rep(0 , 10 ))
20
- expect_equal(eval_0 $ .row , 1 : 10 )
22
+ expect_equal(eval_0 , rep(0 , 10 ))
21
23
22
24
expect_equal(
23
25
which(na_05 ),
24
26
which(times < = 5 & cens == 0 )
25
27
)
26
28
expect_equal(
27
- eval_05 $ weight_time [! na_05 ],
29
+ eval_05 [! na_05 ],
28
30
ifelse(times [! na_05 ] - 1 < 5 , times [! na_05 ] - 1 , 4 )
29
31
)
30
32
@@ -33,62 +35,60 @@ test_that('calculate weight time', {
33
35
which(cens == 0 )
34
36
)
35
37
expect_equal(
36
- eval_11 $ weight_time [! na_11 ],
37
- ( 1 : 5 ) * 2
38
+ eval_11 [! na_11 ],
39
+ seq( 2 , 10 , by = 2 )
38
40
)
39
- expect_equal(eval_11 $ .row , 11 : 20 )
40
41
41
42
})
42
43
43
44
test_that(' compute Graf weights' , {
44
- skip_if_not_installed(" parsnip" , minimum_version = " 1.0.4.9003" )
45
+ skip_if_not_installed(" parsnip" , minimum_version = " 1.0.4.9006" )
46
+ skip_if_not_installed(" censored" , minimum_version = " 0.1.1.9002" )
45
47
46
- library(parsnip )
47
- library(survival )
48
+ library(tidymodels )
48
49
library(censored )
49
- library(workflows )
50
- library(dplyr )
51
50
52
- times <- 1 : 10
53
- cens <- c( 0 , rep(1 , 9 ) )
51
+ times <- c( 9 , 1 : 9 )
52
+ cens <- rep(0 : 1 , 5 )
54
53
surv_obj <- Surv(times , cens )
54
+ n <- length(surv_obj )
55
55
df <- data.frame (surv = surv_obj , x = - 1 : 8 )
56
56
fit <- survival_reg() %> % fit(surv ~ x , data = df )
57
57
wflow_fit <-
58
58
workflow() %> %
59
59
add_model(survival_reg(), formula = surv ~ x ) %> %
60
60
add_variables(surv , x ) %> %
61
61
fit(data = df )
62
+ mod_fit <- extract_fit_parsnip(wflow_fit )
63
+
64
+ eval_times <- c(5 , 1 : 4 )
65
+
66
+ pred_surv <-
67
+ predict(mod_fit , df , type = " survival" , eval_time = eval_times ) %> %
68
+ bind_cols(
69
+ predict(mod_fit , df , type = " time" ),
70
+ df
71
+ ) %> %
72
+ slice(5 )
73
+
74
+ wt_times <-
75
+ parsnip ::: graf_weight_time_vec(pred_surv $ surv ,
76
+ eval_time = pred_surv $ .pred [[1 ]]$ .eval_time )
77
+ expect_equal(wt_times , c(NA , 0.9999999999 , 1.9999999999 , 2.9999999999 , NA ), tolerance = 0.01 )
78
+
79
+ cens_probs <- predict(fit $ censor_probs , time = wt_times , as_vector = TRUE )
80
+
81
+ wts <- .censoring_weights_graf(fit , pred_surv )
82
+ expect_equal(names(wts ), names(pred_surv ))
83
+ expect_equal(nrow(wts ), nrow(pred_surv ))
84
+ expect_equal(dim(wts $ .pred [[1 ]]), c(length(eval_times ), 5 ))
85
+ expect_equal(wts $ .pred [[1 ]]$ .eval_time , eval_times )
86
+ expect_equal(
87
+ names(wts $ .pred [[1 ]]),
88
+ c(" .eval_time" , " .pred_survival" , " .weight_time" , " .pred_censored" , " .weight_censored" ))
62
89
63
- eval_0 <- parsnip ::: graf_weight_time(surv_obj , eval_time = 0 )
64
- eval_05 <- parsnip ::: graf_weight_time(surv_obj , eval_time = 5 , eps = 1 )
65
- eval_11 <- parsnip ::: graf_weight_time(surv_obj , eval_time = 11 , rows = 11 : 20 , eps = 0 )
66
-
67
- cens_prob_00 <- predict(fit $ censor_probs , time = eval_0 $ weight_time , as_vector = TRUE )
68
- cens_prob_05 <- predict(fit $ censor_probs , time = eval_05 $ weight_time , as_vector = TRUE )
69
- cens_prob_11 <- predict(fit $ censor_probs , time = eval_11 $ weight_time , as_vector = TRUE )
70
-
71
- wts_00 <- .censoring_weights_graf(fit , df , 0 )
72
- wts_05 <- .censoring_weights_graf(fit , df , 5 )
73
- wts_11 <- .censoring_weights_graf(fit , df , 11 )
74
-
75
- wflow_wts_00 <- .censoring_weights_graf(wflow_fit , df , 0 )
76
- wflow_wts_05 <- .censoring_weights_graf(wflow_fit , df , 5 )
77
- wflow_wts_11 <- .censoring_weights_graf(wflow_fit , df , 11 )
78
-
79
- expect_equal(wts_00 $ .weight_cens , 1 / cens_prob_00 )
80
- expect_equal(wts_05 $ .weight_cens , 1 / cens_prob_05 )
81
- expect_equal(wts_11 $ .weight_cens , 1 / cens_prob_11 )
82
-
83
- expect_equal(wflow_wts_00 $ .weight_cens , 1 / cens_prob_00 )
84
- expect_equal(wflow_wts_05 $ .weight_cens , 1 / cens_prob_05 )
85
- expect_equal(wflow_wts_11 $ .weight_cens , 1 / cens_prob_11 )
86
-
87
- expect_true(inherits(wts_00 , " data.frame" ))
88
- expect_equal(names(wts_00 ), c(" .row" , " eval_time" , " .prob_cens" , " .weight_cens" ))
89
- expect_equal(nrow(wts_00 ), nrow(df ))
90
-
91
- expect_snapshot(.censoring_weights_graf(2 , df , 0 ), error = TRUE )
90
+ wts2 <- wts %> % unnest(.pred )
91
+ expect_equal(wts2 $ .weight_censored , 1 / cens_probs )
92
92
93
93
})
94
94
0 commit comments