-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_surv_rmtl.r
93 lines (81 loc) · 3.31 KB
/
plot_surv_rmtl.r
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
## Plot the restricted mean time lost (RMTL) as it evolves over values of
## the continuous variable.
##
## Cumulative incidence curves are estimated with curve_cont() (cif=TRUE) at
## time 0 and at every observed event time, for each value in `horizon`.
## cont_surv_auc() turns these curves into RMTL values for each `tau`
## (presumably the area under the CIF up to `tau` - see that helper).
## One line is drawn per value of `tau`; when `group` is supplied the plot
## is faceted by group with ggplot2::facet_wrap().
##
## Arguments (as used in this function):
##   time, status, variable, group - column names in `data`
##   data, model, na.action - passed to the input checks, the data
##       preparation and curve_cont()
##   tau - one or more truncation times for the RMTL
##   horizon - values of `variable` at which the RMTL is evaluated; defaults
##       to 100 evenly spaced values between the observed min and max
##   custom_colors - manual line colours, only used when length(tau) > 1
##   size, linetype, alpha - passed to ggplot2::geom_line() (size is used
##       as `linewidth`)
##   color - line colour, only used when length(tau) == 1
##   xlab, ylab, title, subtitle - plot labels
##   legend.title - title of the colour (tau) legend; when not supplied the
##       legend keeps the title ggplot2 derives from the mapping
##   legend.position, gg_theme - theme settings
##   facet_args - extra arguments for ggplot2::facet_wrap() (group only)
##   ... - further arguments passed to curve_cont()
##
## Returns a ggplot2 object.
#' @importFrom rlang .data
#' @export
plot_surv_rmtl <- function(time, status, variable, group=NULL,
                           data, model, na.action=options()$na.action,
                           tau, horizon=NULL, custom_colors=NULL,
                           size=1, linetype="solid", alpha=1, color="black",
                           xlab=variable, ylab="Restricted Mean Time Lost",
                           title=NULL, subtitle=NULL,
                           legend.title=variable, legend.position="right",
                           gg_theme=ggplot2::theme_bw(),
                           facet_args=list(), ...) {
  # dplyr is needed for bind_rows() when a group variable is used
  requireNamespace("dplyr")

  data <- use_data.frame(data)
  check_inputs_plots(time=time, status=status, variable=variable,
                     data=data, model=model, na.action=na.action,
                     horizon=horizon, fixed_t=NULL, max_t=Inf,
                     discrete=TRUE, panel_border=TRUE, t=1, tau=tau,
                     group=group)
  data <- prepare_inputdata(data=data, time=time, status=status,
                            variable=variable, model=model,
                            group=group, na.action=na.action)

  # [[ ]] returns a plain vector for both data.frames and tibbles
  if (is.null(horizon)) {
    horizon <- seq(min(data[[variable]]), max(data[[variable]]),
                   length.out=100)
  }

  # evaluate the curves at time 0 and at every observed event time
  fixed_t <- c(0, sort(unique(data[[time]][data[[status]] >= 1])))
  plotdata <- curve_cont(data=data,
                         variable=variable,
                         group=group,
                         model=model,
                         horizon=horizon,
                         times=fixed_t,
                         na.action="na.fail",
                         cif=TRUE,
                         event_time=time,
                         event_status=status,
                         ...)

  # calculate RMTL values, separately per group if requested
  if (is.null(group)) {
    out <- cont_surv_auc(plotdata=plotdata, tau=tau)
  } else {
    # keep the factor level order; levels() would be NULL for a
    # character/numeric group column, so fall back to unique() there
    group_levs <- if (is.factor(plotdata$group)) {
      levels(plotdata$group)
    } else {
      unique(plotdata$group)
    }
    out <- lapply(group_levs, function(lev) {
      out_i <- cont_surv_auc(plotdata=plotdata[plotdata$group==lev, ],
                             tau=tau)
      out_i$group <- lev
      out_i
    })
    out <- dplyr::bind_rows(out)
    # facet panels follow group_levs instead of alphabetical order
    out$group <- factor(out$group, levels=group_levs)
  }

  # plot them: one coloured line per tau, or a single fixed-colour line
  multiple_tau <- length(tau) > 1
  if (multiple_tau) {
    # a numeric tau would give a continuous colour scale, which neither
    # splits the data into one line per tau nor works with
    # scale_colour_manual(); a discrete factor does both
    out$tau <- factor(out$tau)
    p <- ggplot2::ggplot(out, ggplot2::aes(x=.data$cont, y=.data$auc,
                                           color=.data$tau)) +
      ggplot2::geom_line(linewidth=size, linetype=linetype, alpha=alpha)
  } else {
    p <- ggplot2::ggplot(out, ggplot2::aes(x=.data$cont, y=.data$auc)) +
      ggplot2::geom_line(linewidth=size, linetype=linetype, alpha=alpha,
                         color=color)
  }

  p <- p +
    ggplot2::labs(x=xlab, y=ylab, title=title, subtitle=subtitle) +
    gg_theme +
    ggplot2::theme(legend.position=legend.position)

  # the only legend is the colour (tau) legend; its default title argument
  # names the continuous variable, so only override the automatic title
  # when the caller explicitly supplied one
  if (!missing(legend.title)) {
    p <- p + ggplot2::labs(color=legend.title)
  }
  if (multiple_tau && !is.null(custom_colors)) {
    p <- p + ggplot2::scale_colour_manual(values=custom_colors)
  }

  # facet plot by factor variable
  if (!is.null(group)) {
    facet_args$facets <- stats::as.formula("~ group")
    p <- p + do.call(ggplot2::facet_wrap, facet_args)
  }

  p
}