Skip to content

Commit

Permalink
Update documentation #21
Browse files Browse the repository at this point in the history
  • Loading branch information
agosiewska committed Oct 24, 2018
1 parent b29abca commit 7c7cb50
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 50 deletions.
28 changes: 14 additions & 14 deletions R/plot_ceteris_paribus.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#' @description Function plot for ceteris_paribus object visualise estimated survival curve of mean probabilities in chosen time points. Black lines on each plot correspond to survival curve for our new observation specified in the \code{ceteris_paribus} function.
#'
#' @param x object of class "surv_ceteris_paribus_explainer"
#' @param ... arguments to be passed to methods, such as graphical parameters
#' @param ... arguments to be passed to methods, such as graphical parameters for function \code{\link[ggplot2]{geom_step}}.
#' @param selected_variable name of variable we want to draw ceteris paribus plot
#' @param scale_type type of scale of colors, either "discrete" or "gradient"
#' @param scale_col vector containing values of low and high ends of the gradient, when "gradient" type of scale was chosen
Expand All @@ -22,59 +22,59 @@
#' return(prob)
#' }
#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)
#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],
#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],
#' y = Surv(pbcTest$years, pbcTest$status), predict_function = predict_times)
#' cp_cph <- ceteris_paribus(surve_cph, pbcTest[1,-c(1,5)])
#' plot(cp_cph)
#' }
#' @method plot surv_ceteris_paribus_explainer
#' @export

plot.surv_ceteris_paribus_explainer <- function(x, ..., selected_variable = NULL, scale_type = "factor",
plot.surv_ceteris_paribus_explainer <- function(x, ..., selected_variable = NULL, scale_type = "factor",
scale_col = NULL, ncol = 1) {

if(!is.null(selected_variable) && !(selected_variable %in% factor(x$vname))){
stop(paste0("Selected variable ", selected_variable, "not present in surv_ceteris_paribus object."))
}

y_hat <- new_x <- time <- time_2 <- y_hat_2 <- NULL
new_observation_legend <- create_legend(x=x)
seq_length <- attributes(x)$grid_points

all_responses <- x

all_predictions <- create_predictions(x)


all_responses <- merge(all_responses, new_observation_legend, by="vname")
if(!is.null(selected_variable)){
all_responses <- all_responses[which(all_responses$vname == selected_variable),]
legend <- unique(all_responses$val)
add_theme <- labs(col = legend)
add_theme <- labs(col = legend)
facet <- NULL
title <- ggtitle(paste("Ceteris paribus plot for variable", selected_variable,"."))
}else{
add_theme <- theme(legend.position = "none")
title <- ggtitle(paste("Ceteris paribus plot for", unique(x$label),"model."))
facet <- facet_wrap(~val, ncol = ncol)
}

#######################
df <- all_responses[,c("vname","new_x")]
df <- unique(df)
df$legend <- 1:nrow(df)
all_responses <- merge(all_responses, df, by=c("vname", "new_x"))

############################
scale <- create_scale(all_responses, scale_type, scale_col, selected_variable)

ggplot(all_responses, aes(x = time, y = y_hat, col = factor(legend))) +
geom_step(...) +
geom_step(data = all_predictions, aes(x = time_2, y = y_hat_2,...), col="black", lty = 2, size = 1) +
scale_y_continuous(breaks = seq(0,1,0.1),
limits = c(0,1),
labels = paste(seq(0,100,10),"%"),
name = "survival probability") +
name = "survival probability") +
facet +
theme_mi2() +
add_theme +
Expand Down Expand Up @@ -125,4 +125,4 @@ create_scale <- function(all_responses, scale_type, scale_col, selected_variable
scale <- NULL
}
return(scale)
}
}
4 changes: 2 additions & 2 deletions R/plot_model_performance.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#' @description Function plot for surv_model_performance object.
#'
#' @param x object of class "surv_model_performance"
#' @param ... optional, additional object of class "surv_model_performance_explainer"
#' @param ... optional, additional objects of class "surv_model_performance_explainer"
#'
#' @import ggplot2
#'
Expand All @@ -18,7 +18,7 @@
#' return(prob)
#' }
#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)
#'surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],
#'surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],
#' y = Surv(pbcTest$years, pbcTest$status), predict_function = predict_times)
#' mp_cph <- model_performance(surve_cph, data = pbcTest)
#' plot(mp_cph)
Expand Down
20 changes: 10 additions & 10 deletions R/plot_prediction_breakdown.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#' @description Function plot for surv_breakdown object visualise estimated survival curve of mean probabilities in chosen time points.
#'
#' @param x an object of class "surv_prediction_breakdown_explainer"
#' @param ... optional, additional object of class "surv_prediction_breakdown_explainer"
#' @param ... optional, additional objects of class "surv_prediction_breakdown_explainer"
#' @param numerate logical; indicating whether we want to number curves
#' @param lines logical; indicating whether we want to add lines on chosen time point or probability
#' @param lines_type a type of line; see http://sape.inf.usi.ch/quick-reference/ggplot2/linetype
Expand All @@ -23,7 +23,7 @@
#' return(prob)
#' }
#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)
#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],
#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],
#' y = Surv(pbcTest$years, pbcTest$status), predict_function = predict_times)
#' broken_prediction <- prediction_breakdown(surve_cph, pbc[1,-c(1,2)])
#' plot(broken_prediction)
Expand All @@ -33,7 +33,7 @@
#' @importFrom scales seq_gradient_pal
#' @export

plot.surv_prediction_breakdown_explainer <- function(x, ..., numerate = TRUE, lines = TRUE,
plot.surv_prediction_breakdown_explainer <- function(x, ..., numerate = TRUE, lines = TRUE,
lines_type = 1, lines_col = "black",
scale_col = c("#010059","#e0f6fb")){
y <- col <- label <- value <- position <- legend <- NULL
Expand All @@ -55,29 +55,29 @@ plot.surv_prediction_breakdown_explainer <- function(x, ..., numerate = TRUE, li
legend <- theme(legend.position = "none")
}


df <- create_legend_broken(df, x)
#colors
cc <- seq_gradient_pal(scale_col[1],scale_col[2])(seq(0,1,length.out=length(unique(df$legend))))

median_time <- median(unique(df$x))
median <- which.min(abs(unique(df$x) - median_time))
median <- unique(df$x)[median]

if(!is.null(attributes(x)$prob)){
line <- geom_hline(yintercept = attributes(x)$prob, color = lines_col, linetype = lines_type)
}else if (!is.null(attributes(x)$time)){
line <- geom_vline(xintercept = attributes(x)$time, color = lines_col, linetype = lines_type)
}else{
line <- geom_vline(xintercept = median, color = lines_col, linetype = lines_type)
}

if(lines == TRUE){
line <- line
}else{
line <- NULL
}

if(numerate == TRUE){
numbers <- geom_text(data = df[df$x == median,], aes(label = position), color = "black", show.legend = FALSE, hjust = 0, vjust = 0, nudge_x = 0.4)
}else{
Expand All @@ -99,7 +99,7 @@ plot.surv_prediction_breakdown_explainer <- function(x, ..., numerate = TRUE, li
scale_y_continuous(breaks = seq(0,1,0.1),
limits = c(0,1),
labels = paste(seq(0,100,10),"%"),
name = "survival probability") +
name = "survival probability") +
legend


Expand All @@ -113,7 +113,7 @@ create_legend_broken <- function(df, x){
broken_cumm$variable <- as.character(broken_cumm$variable)
broken_cumm <- rbind(broken_cumm, c(round(attributes(x)$Intercept,2), "Intercept"))
broken_cumm <- rbind(broken_cumm, c(round(attributes(x)$Observation,2), "Observation"))

df <- merge(df, broken_cumm, by = "variable")
df$legend <- paste(df$legend, df$contribution)
df$legend <- factor(df$legend, levels = unique(df$legend[order(df$position)]))
Expand Down
4 changes: 2 additions & 2 deletions R/plot_variable_response.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#' @description Function plot for surv_variable_response object shows the expected output condition on a selected variable.
#'
#' @param x an object of class "surv_variable_response"
#' @param ... optional, additional object of class "surv_variable_response_explainer"
#' @param ... optional, additional objects of class "surv_variable_response_explainer"
#' @param split a character, either "model" or "variable"; sets the variable for faceting
#'
#' @import ggplot2
Expand All @@ -20,7 +20,7 @@
#' return(prob)
#' }
#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)
#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],
#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],
#' y = Surv(pbcTest$years, pbcTest$status), predict_function = predict_times)
#' svr_cph <- variable_response(surve_cph, "sex")
#' plot(svr_cph)
Expand Down
24 changes: 12 additions & 12 deletions R/prediction_breakdown.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' @param observation a new observation to explain
#' @param time a time point at which variable contributions are computed. If NULL median time is taken.
#' @param prob a survival probability at which variable contributions are computed
#' @param ... other parameters corresponding to arguments from \code{broken} function from \code{breakDown} package. See https://github.com/pbiecek/breakDown/blob/master/R/break_agnostic.R for details
#' @param ... other parameters corresponding to arguments from \code{\link[breakDown]{broken}} function from \code{breakDown} package. See https://github.com/pbiecek/breakDown/blob/master/R/break_agnostic.R for details
#'
#' @return An object of class surv_prediction_breakdown_explainer
#'
Expand All @@ -25,7 +25,7 @@
#' return(prob)
#' }
#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)
#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],
#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],
#' y = Surv(pbcTest$years, pbcTest$status), predict_function = predict_times)
#' broken_prediction <- prediction_breakdown(surve_cph, pbcTest[1,-c(1,5)])
#' }
Expand All @@ -45,12 +45,12 @@ prediction_breakdown <- function(explainer, observation, time = NULL, prob = NUL
res<- broken(model = explainer$model,
new_observation = observation,
data = explainer$data,
predict.function = new_pred,
predict.function = new_pred,
...)
options(warn = oldw)

class(res) <- "data.frame"

intercept <- res$contribution[res$variable_name=="Intercept"]
observ <- res$contribution[res$variable=="final_prognosis"]

Expand Down Expand Up @@ -92,34 +92,34 @@ prediction_breakdown <- function(explainer, observation, time = NULL, prob = NUL
predict_fun <- function(prob, time, explainer){
if (is.null(prob)) {
if (is.null(time)) time <- median(explainer$times)

new_pred <- function(model, data){
explainer$predict_function(model, data, times = time)
}
} else {
times_sorted <- sort(explainer$times)

find_time <- function(x){
tim <- (x < prob)
index <- c(min(which(tim == TRUE)) -1, min(which(tim == TRUE)))
closest_times <- times_sorted[index]
weighted.mean(closest_times, x[index])
}

new_pred <- function(model, data){
probabilities <- explainer$predict_function(model, data, times = explainer$times)
probabilities <- as.data.frame(probabilities)

res <- apply(probabilities, MARGIN = 1, FUN = find_time)
res <- na.omit(res)
return(res)

}

npred <- new_pred(explainer$model, explainer$data)
message("Number of observations with prob > ", prob, ": ", nrow(explainer$data) - length(npred))
}

return(new_pred)
}

Expand Down Expand Up @@ -157,4 +157,4 @@ calculate_prediction <- function(explainer, tmp_data, times, res, i, variable){
mean_prediction$position <- res[i, "position"]
mean_prediction$value <- res[i, "variable"]
return(mean_prediction)
}
}
4 changes: 2 additions & 2 deletions man/plot.surv_ceteris_paribus_explainer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/plot.surv_model_performance_explainer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/plot.surv_prediction_breakdown_explainer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/plot.surv_variable_response_explainer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/prediction_breakdown.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7c7cb50

Please sign in to comment.