-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_surv_matrix.r
159 lines (141 loc) · 6.11 KB
/
plot_surv_matrix.r
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
## Survival heatmap using only a discrete number of tiles,
## but with underlying smooth estimation.
##
## The survival (or, if cif=TRUE, cumulative incidence) surface is first
## estimated by curve_cont() on the fine grid given by `fixed_t` (time points)
## and `horizon` (values of the continuous `variable`). The resulting grid
## points are then binned with cut() into `n_col` time intervals and `n_row`
## intervals of the variable; each tile shows the mean of the estimates of
## the grid points falling inside it.
##
## Notes on selected arguments:
##  - fixed_t / horizon: evaluation grids. When NULL, 100 equally spaced
##    points spanning the observed range in `data` are used. Tiles are
##    widened by the spacing between the first two grid points to close the
##    gaps between neighbours, so equally spaced grids are presumably
##    expected when supplying these manually.
##  - max_t: points of `fixed_t` above this value are dropped before the
##    estimation.
##  - numbers: if TRUE, the tile means (rounded to `number_digits`) are
##    printed on top of each tile.
##  - group: optional grouping variable (presumably a column name of `data`);
##    if given, the plot is faceted with ggplot2::facet_wrap(), to which
##    `facet_args` is passed.
##  - ...: further arguments passed to curve_cont().
## Returns a ggplot2 object.
#' @importFrom rlang .data
#' @export
plot_surv_matrix <- function(time, status, variable, group=NULL, data, model,
                             cif=FALSE, na.action=options()$na.action,
                             horizon=NULL, fixed_t=NULL, max_t=Inf,
                             n_col=10, n_row=10,
                             start_color="red", end_color="blue",
                             alpha=1, xlab="Time", ylab=variable,
                             title=NULL, subtitle=NULL,
                             legend.title="S(t)", legend.position="none",
                             gg_theme=ggplot2::theme_bw(),
                             facet_args=list(),
                             panel_border=FALSE, axis_dist=0,
                             border_color="white", border_size=0.5,
                             numbers=TRUE, number_color="white",
                             number_size=3, number_family="sans",
                             number_fontface="plain", number_digits=2,
                             ...) {
  data <- use_data.frame(data)

  # standard input checks
  check_inputs_plots(time=time, status=status, variable=variable,
                     data=data, model=model, na.action=na.action,
                     horizon=horizon, fixed_t=fixed_t, max_t=max_t,
                     discrete=TRUE, panel_border=panel_border, t=1, tau=1,
                     group=group)
  # further special input checks
  check_inputs_surv_matrix(border_color=border_color, border_size=border_size,
                           numbers=numbers, number_color=number_color,
                           number_size=number_size,
                           number_family=number_family,
                           number_fontface=number_fontface,
                           number_digits=number_digits, fixed_t=fixed_t,
                           horizon=horizon, n_col=n_col, n_row=n_row)

  data <- prepare_inputdata(data=data, time=time, status=status,
                            variable=variable, model=model,
                            group=group, na.action=na.action)

  # default evaluation grids: 100 equally spaced points over the data range
  if (is.null(fixed_t)) {
    fixed_t <- seq(min(data[, time]), max(data[, time]), length.out=100)
  }
  if (is.null(horizon)) {
    horizon <- seq(min(data[, variable]), max(data[, variable]),
                   length.out=100)
  }

  # final input checks for n_col, n_row
  if (n_col > length(fixed_t)) {
    stop("'n_col' must be smaller than length(fixed_t). Decrease n_col or",
         " increase the number of points in time used in the estimation.")
  }
  if (n_row > length(horizon)) {
    stop("'n_row' must be smaller than length(horizon). Decrease n_row or",
         " increase the number of values in horizon used in the estimation.")
  }

  # only show up to max_t
  fixed_t <- fixed_t[fixed_t <= max_t]

  # smooth estimates on the fine grid
  plotdata <- curve_cont(data=data,
                         variable=variable,
                         model=model,
                         group=group,
                         horizon=horizon,
                         times=fixed_t,
                         na.action="na.fail",
                         cif=cif,
                         event_time=time,
                         event_status=status,
                         ...)

  # assign every grid point to one of the n_col x n_row tiles
  plotdata$time_cut <- cut(plotdata$time, n_col)
  plotdata$cont_cut <- cut(plotdata$cont, n_row)
  plotdata$rect_id <- paste(plotdata$time_cut, plotdata$cont_cut)

  # one row per tile (and per group, if any); the result is ungrouped
  tile_vars <- if (is.null(group)) "rect_id" else c("rect_id", "group")
  plotdata <- plotdata %>%
    dplyr::group_by(!!!rlang::syms(tile_vars)) %>%
    dplyr::summarise(xmin=min(.data$time),
                     xmax=max(.data$time),
                     ymin=min(.data$cont),
                     ymax=max(.data$cont),
                     est=mean(.data$est),
                     .groups="drop")

  # close gap between tiles: each tile only spans its own grid points, so
  # extend it by one grid step towards its right / upper neighbour
  # (abs() so that decreasing grids also widen instead of shrink the tiles)
  gap_x <- abs(fixed_t[2] - fixed_t[1])
  gap_y <- abs(horizon[2] - horizon[1])
  plotdata$xmax <- plotdata$xmax + gap_x
  plotdata$ymax <- plotdata$ymax + gap_y

  if (numbers) {
    # text goes into the centre of each tile
    plotdata$x_text <- (plotdata$xmin + plotdata$xmax) / 2
    plotdata$y_text <- (plotdata$ymin + plotdata$ymax) / 2
    # round the numbers
    plotdata$label <- format(round(plotdata$est, number_digits),
                             nsmall=number_digits)
  }

  # switch the default legend title when showing cumulative incidences;
  # identical() keeps this safe for legend.title=NULL or non-string titles
  if (cif && identical(legend.title, "S(t)")) {
    legend.title <- "F(t)"
  }

  # plot it
  p <- ggplot2::ggplot(plotdata, ggplot2::aes(fill=.data$est,
                                              xmin=.data$xmin,
                                              xmax=.data$xmax,
                                              ymin=.data$ymin,
                                              ymax=.data$ymax)) +
    ggplot2::geom_rect(color=border_color, alpha=alpha,
                       linewidth=border_size) +
    ggplot2::labs(x=xlab, y=ylab, title=title, subtitle=subtitle,
                  fill=legend.title) +
    gg_theme +
    ggplot2::theme(legend.position=legend.position) +
    ggplot2::scale_x_continuous(expand=c(axis_dist, axis_dist)) +
    ggplot2::scale_y_continuous(expand=c(axis_dist, axis_dist))

  if (!is.null(start_color) && !is.null(end_color)) {
    p <- p + ggplot2::scale_fill_gradient(low=start_color, high=end_color)
  }
  if (!panel_border) {
    p <- p + ggplot2::theme(panel.border=ggplot2::element_blank())
  }
  if (numbers) {
    p <- p + ggplot2::geom_text(ggplot2::aes(x=.data$x_text, y=.data$y_text,
                                             label=.data$label),
                                color=number_color,
                                size=number_size,
                                family=number_family,
                                fontface=number_fontface)
  }

  # facet plot by factor variable
  if (!is.null(group)) {
    facet_args$facets <- stats::as.formula("~ group")
    facet_obj <- do.call(ggplot2::facet_wrap, facet_args)
    p <- p + facet_obj
  }

  p
}