Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ A minor update to the package with some bug fixes and minor changes.
- Removed the on attach message which warned of breaking changes in `1.0.0`.
- Renamed the `metric` argument of `summarise_scores()` to `relative_skill_metric`. The old `metric` argument is now deprecated and will be removed in a future version of the package. Please use the new argument instead.
- Updated the documentation for `score()` and related functions to make the soft requirement for a `model` column in the input data more explicit.
- Simplified `plot_pairwise_comparison()`: it now only supports plotting either mean score ratios or p-values, and the hybrid option to show both at the same time has been removed.

## Bug fixes

Expand Down
19 changes: 5 additions & 14 deletions R/pairwise-comparisons.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,12 @@
#' @author Johannes Bracher, \email{johannes.bracher@@kit.edu}
#' @keywords scoring
#' @examples
#' df <- data.frame(
#' model = rep(c("model1", "model2", "model3"), each = 10),
#' date = as.Date("2020-01-01") + rep(1:5, each = 2),
#' location = c(1, 2),
#' interval_score = (abs(rnorm(30))),
#' ae_median = (abs(rnorm(30)))
#' )
#' scores <- score(example_quantile)
#' pairwise <- pairwise_comparison(scores, by = "target_type")
#'
#' res <- pairwise_comparison(df,
#' baseline = "model1"
#' )
#' plot_pairwise_comparison(res)
#'
#' eval <- score(example_quantile)
#' pairwise_comparison(eval, by = c("model"))
#' library(ggplot2)
#' plot_pairwise_comparison(pairwise, type = "mean_scores_ratio") +
#' facet_wrap(~target_type)

pairwise_comparison <- function(scores,
by = c("model"),
Expand Down
205 changes: 28 additions & 177 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -690,13 +690,9 @@ plot_quantile_coverage <- function(scores,
#' @param comparison_result A data.frame as produced by
#' [pairwise_comparison()]
#' @param type character vector of length one that is either
#' "mean_scores_ratio", "pval", or "together". This denotes whether to
#' visualise the ratio or the p-value of the pairwise comparison or both.
#' "mean_scores_ratio" or "pval". This denotes whether to
#' visualise the ratio or the p-value of the pairwise comparison.
#' Default is "mean_scores_ratio".
#' @param smaller_is_good logical (default is `TRUE`) that indicates whether
#' smaller or larger values are to be interpreted as 'good' (as you could just
#' invert the mean scores ratio). This option is not supported when type =
#' "pval"
#' @importFrom ggplot2 ggplot aes geom_tile geom_text labs coord_cartesian
#' scale_fill_gradient2 theme_light element_text
#' @importFrom data.table as.data.table setnames rbindlist
Expand All @@ -706,27 +702,18 @@ plot_quantile_coverage <- function(scores,
#' @export
#' @examples
#' library(ggplot2)
#' df <- data.frame(
#' model = rep(c("model1", "model2", "model3"), each = 10),
#' id = rep(1:10),
#' interval_score = abs(rnorm(30, mean = rep(c(1, 1.3, 2), each = 10))),
#' ae_median = (abs(rnorm(30)))
#' )
#'
#' scores <- score(example_quantile)
#' pairwise <- pairwise_comparison(scores, by = "target_type")
#' plot_pairwise_comparison(pairwise) +
#' plot_pairwise_comparison(pairwise, type = "mean_scores_ratio") +
#' facet_wrap(~target_type)

plot_pairwise_comparison <- function(comparison_result,
type = c("mean_scores_ratio", "pval", "together"),
smaller_is_good = TRUE) {
type = c("mean_scores_ratio", "pval")) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`match.arg()` should be used to check the allowed options.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does use match.arg in line 782 - is that different from what you meant?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For type? If it does then no that is exactly what I meant

comparison_result <- data.table::as.data.table(comparison_result)

comparison_result[, model := reorder(model, -relative_skill)]
levels <- levels(comparison_result$model)


get_fill_scale <- function(values, breaks, plot_scales) {
values[is.na(values)] <- 1 # this would be either ratio = 1 or pval = 1
scale <- cut(values,
Expand All @@ -735,176 +722,40 @@ plot_pairwise_comparison <- function(comparison_result,
right = FALSE,
labels = plot_scales
)
# scale[is.na(scale)] <- 0
return(as.numeric(as.character(scale)))
}

type <- match.arg(type)

if (type == "together") {
# obtain only the upper triangle of the comparison
# that is used for showing ratios
# need to change the order if larger is good
if (smaller_is_good) {
unique_comb <- as.data.frame(t(combn(rev(levels), 2)))
} else {
unique_comb <- as.data.frame(t(combn((levels), 2)))
}

colnames(unique_comb) <- c("model", "compare_against")
upper_triangle <- merge(comparison_result, unique_comb)

# change levels for plotting order
upper_triangle[, `:=`(
model = factor(model, levels),
compare_against = factor(compare_against, levels)
)]

# reverse y and x if larger is better
if (!smaller_is_good) {
data.table::setnames(
upper_triangle,
c("model", "compare_against"),
c("compare_against", "model")
)
}
if (type == "mean_scores_ratio") {
comparison_result[, var_of_interest := round(mean_scores_ratio, 2)]

# modify upper triangle ------------------------------------------------------
# add columns where a model is compared with itself. make adj_pval NA
# to plot it as grey later on
equal <- data.table::data.table(
model = levels,
compare_against = levels,
mean_scores_ratio = 1,
pval = NA,
adj_pval = NA
)
upper_triangle_complete <- data.table::rbindlist(list(
upper_triangle,
equal
), fill = TRUE)

# define interest variable
upper_triangle_complete[, var_of_interest := round(mean_scores_ratio, 2)]

# implemnt breaks for colour heatmap
# implement breaks for colour heatmap
breaks <- c(0, 0.1, 0.5, 0.75, 1, 1.33, 2, 10, Inf)
plot_scales <- c(-1, -0.5, -0.25, 0, 0, 0.25, 0.5, 1)
if (!smaller_is_good) {
plot_scales <- rev(plot_scales)
}
upper_triangle_complete[, fill_col := get_fill_scale(
comparison_result[, fill_col := get_fill_scale(
var_of_interest,
breaks, plot_scales
)]

# create mean_scores_ratios in plot
plot <- ggplot(
upper_triangle_complete,
aes(
x = compare_against,
y = model,
fill = fill_col
)
) +
geom_tile(width = 0.98, height = 0.98) +
geom_text(aes(label = var_of_interest),
na.rm = TRUE
) +
scale_fill_gradient2(
low = "steelblue", mid = "grey95",
high = "salmon",
na.value = "lightgrey",
midpoint = 0,
limits = c(-1, 1),
name = NULL
) +
theme_scoringutils() +
theme(
axis.text.x = element_text(
angle = 90, vjust = 1,
hjust = 1, color = "brown4"
),
axis.text.y = element_text(color = "steelblue4"),
legend.position = "none"
) +
labs(
x = "", y = "",
title = "Pairwise comparisons - mean_scores_ratio (upper) and pval (lower)"
) +
coord_cartesian(expand = FALSE)

# add pvalues to plot --------------------------------------------------------
# obtain lower triangle for the pvalues
lower_triangle <- data.table::copy(upper_triangle)
data.table::setnames(
lower_triangle,
c("model", "compare_against"),
c("compare_against", "model")
)

lower_triangle[, var_of_interest := round(adj_pval, 3)]
high_col <- "salmon"
} else if (type == "pval") {
comparison_result[, var_of_interest := round(pval, 3)]
# implement breaks for colour heatmap
breaks <- c(0, 0.01, 0.05, 0.1, 1)
plot_scales <- c(0.8, 0.5, 0.1, 0.000001)
lower_triangle[, fill_col := get_fill_scale(
plot_scales <- c(1, 0.5, 0.1, 0)
comparison_result[, fill_col := get_fill_scale(
var_of_interest,
breaks, plot_scales
)]

fill_rule <- ifelse(
lower_triangle$fill_col == 0.000001, "grey95", "palegreen3"
)
lower_triangle[, var_of_interest := as.character(var_of_interest)]
lower_triangle[, var_of_interest := ifelse(var_of_interest == "0",
"< 0.001", var_of_interest
)]

plot <- plot +
geom_tile(
data = lower_triangle,
aes(alpha = fill_col),
fill = fill_rule,
color = "white",
width = 0.97, height = 0.97
) +
geom_text(
data = lower_triangle,
aes(label = var_of_interest),
na.rm = TRUE
)
} else{
if (type == "mean_scores_ratio") {
comparison_result[, var_of_interest := round(mean_scores_ratio, 2)]

# implemnt breaks for colour heatmap
breaks <- c(0, 0.1, 0.5, 0.75, 1, 1.33, 2, 10, Inf)
plot_scales <- c(-1, -0.5, -0.25, 0, 0, 0.25, 0.5, 1)
comparison_result[, fill_col := get_fill_scale(
var_of_interest,
breaks, plot_scales
)]

high_col <- "salmon"
} else {
if (!smaller_is_good) {
stop("smaller_is_good is the only supported option with type pval")
}
comparison_result[, var_of_interest := round(pval, 3)]
# implemnt breaks for colour heatmap
breaks <- c(0, 0.01, 0.05, 0.1, 1)
plot_scales <- c(1, 0.5, 0.1, 0)
comparison_result[, fill_col := get_fill_scale(
var_of_interest,
breaks, plot_scales
)]

high_col <- "palegreen3"
high_col <- "palegreen3"
comparison_result[, var_of_interest := as.character(var_of_interest)]
comparison_result[, var_of_interest := ifelse(var_of_interest == "0",
"< 0.001", var_of_interest
"< 0.001", var_of_interest
)]
}
}

plot <- ggplot(
comparison_result,
aes(
Expand All @@ -918,7 +769,7 @@ plot_pairwise_comparison <- function(comparison_result,
width = 0.97, height = 0.97
) +
geom_text(aes(label = var_of_interest),
na.rm = TRUE
na.rm = TRUE
) +
scale_fill_gradient2(
low = "steelblue", mid = "grey95",
Expand All @@ -940,21 +791,21 @@ plot_pairwise_comparison <- function(comparison_result,
x = "", y = ""
) +
coord_cartesian(expand = FALSE)
if (type == "mean_scores_ratio") {
plot <- plot +
theme(
axis.text.x = element_text(
angle = 90, vjust = 1,
hjust = 1, color = "brown4"
),
axis.text.y = element_text(color = "steelblue4")
)
}
if (type == "mean_scores_ratio") {
plot <- plot +
theme(
axis.text.x = element_text(
angle = 90, vjust = 1,
hjust = 1, color = "brown4"
),
axis.text.y = element_text(color = "steelblue4")
)
}

return(plot)
}


#' @title PIT Histogram
#'
#' @description
Expand Down
19 changes: 5 additions & 14 deletions man/pairwise_comparison.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 4 additions & 17 deletions man/plot_pairwise_comparison.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading