From 60a4299d3dc9e5cc8c577ade6f7e00b46e984696 Mon Sep 17 00:00:00 2001 From: moralapablo Date: Mon, 16 Dec 2024 17:32:10 +0100 Subject: [PATCH 1/4] Added heatmap function --- R/nn2poly_methods.R | 99 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/R/nn2poly_methods.R b/R/nn2poly_methods.R index fc922d5..ef37616 100644 --- a/R/nn2poly_methods.R +++ b/R/nn2poly_methods.R @@ -353,3 +353,102 @@ plot.nn2poly <- function(x, ..., n=NULL) { return(plot_all) } + +#' Heatmap for Second-Order Polynomial Terms +#' +#' This function generates a heatmap to visualize second-order terms in the polynomial representation of a neural network. +#' It displays squared terms along the diagonal and pairwise interactions off-diagonal, with colors indicating the magnitude and sign of the coefficients. +#' +#' @param x A `nn2poly` object, as returned by the `nn2poly` algorithm. +#' @param ... Additional arguments (unused). +#' @param max_order Integer, defaults to 2. This function currently supports only second-order terms. +#' +#' @return A ggplot object showing the heatmap of second-order terms. +#' +#' @details +#' Coefficients are displayed as a gradient from red (negative) to green (positive), with white indicating zero. +#' The diagonal contains squared terms (e.g., \eqn{x_1^2}), while off-diagonal entries represent pairwise interactions (e.g., \eqn{x_1x_2}). 
+#' +#' @examples +#' # Example: Single output polynomial with 20 variables +#' set.seed(42) +#' weights_layer_1 <- matrix(rnorm(21 * 10), nrow = 21, ncol = 10) +#' weights_layer_2 <- matrix(rnorm(11 * 1), nrow = 11, ncol = 1) +#' +#' nn_object <- list("softplus" = weights_layer_1, "linear" = weights_layer_2) +#' +#' # Generate the polynomial representation +#' final_poly <- nn2poly(nn_object, max_order = 2) +#' +#' # Plot the heatmap for second-order terms +#' heatmap.nn2poly(final_poly) +#' +#' +heatmap.nn2poly <- function(x, ..., max_order = 2) { + if (!requireNamespace("ggplot2", quietly = TRUE)) { + stop("package 'ggplot2' is required for this functionality", call. = FALSE) + } + + if (max_order != 2) { + stop("This function currently supports only second-order terms (max_order = 2).", call. = FALSE) + } + + # Ensure the object is an nn2poly object + if (length(class(x)) > 1) { + return(NextMethod()) + } + + # Extract polynomial coefficients and labels + coefficients <- x$values + labels <- x$labels + + if (is.null(coefficients) || is.null(labels)) { + stop("Invalid nn2poly object: missing coefficients or labels.", call. 
= FALSE) + } + + if (is.vector(coefficients)) { + coefficients <- matrix(coefficients, ncol = 1) + } + + # Prepare data for second-order terms + second_order_indices <- which(sapply(labels, function(label) length(label) <= max_order)) + second_order_labels <- labels[second_order_indices] + second_order_values <- coefficients[second_order_indices, , drop = FALSE] + + # Create a matrix for the heatmap + max_vars <- max(unlist(second_order_labels)) + heatmap_matrix <- matrix(0, nrow = max_vars, ncol = max_vars, dimnames = list(paste0("x", 1:max_vars), paste0("x", 1:max_vars))) + + for (i in seq_along(second_order_labels)) { + label <- second_order_labels[[i]] + value <- second_order_values[i, 1] + + if (length(label) == 1) { + # Single variable squared + heatmap_matrix[label, label] <- value + } else if (length(label) == 2) { + # Interaction terms + heatmap_matrix[label[1], label[2]] <- value + heatmap_matrix[label[2], label[1]] <- value + } + } + + # Convert matrix to long format for ggplot2 + heatmap_df <- as.data.frame(as.table(heatmap_matrix)) + colnames(heatmap_df) <- c("Var1", "Var2", "Value") + + # Generate the heatmap plot + heatmap_plot <- ggplot2::ggplot(heatmap_df, ggplot2::aes(x = Var1, y = Var2, fill = Value)) + + ggplot2::geom_tile(color = "white") + + ggplot2::scale_fill_gradient2(low = "#F8766D", mid = "white", high = "#00BA38", midpoint = 0, + name = "Coefficient") + + ggplot2::theme_minimal() + + ggplot2::labs(x = "Variables", y = "Variables", title = "Second-Order Terms Heatmap") + + ggplot2::theme( + axis.text.x = ggplot2::element_text(angle = 45, hjust = 1), + axis.text.y = ggplot2::element_text(size = 10) + ) + + return(heatmap_plot) +} + From cd24f3a0d72da8888acdb07d3b50652b2473ae1a Mon Sep 17 00:00:00 2001 From: moralapablo Date: Wed, 7 May 2025 23:55:43 +0200 Subject: [PATCH 2/4] multiple new experimental plots --- DESCRIPTION | 4 +- R/eval_poly.R | 69 +- R/nn2poly_methods.R | 1384 ++++++++++++++--- man/plot.nn2poly.Rd | 33 +- 
vignettes/source/_nn2poly-01-introduction.Rmd | 121 +- 5 files changed, 1361 insertions(+), 250 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 727b2d7..8ad794f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,7 +29,7 @@ Imports: Suggests: keras, tensorflow, reticulate, luz, torch, - cowplot, ggplot2, patchwork, + cowplot, ggplot2, patchwork, ggbeeswarm, testthat (>= 3.0.0), vdiffr, knitr, rmarkdown LinkingTo: @@ -37,6 +37,6 @@ LinkingTo: RcppArmadillo VignetteBuilder: knitr Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.2 Config/testthat/edition: 3 URL: https://ibidat.github.io/nn2poly/ diff --git a/R/eval_poly.R b/R/eval_poly.R index b2e9a83..0efe0ce 100644 --- a/R/eval_poly.R +++ b/R/eval_poly.R @@ -53,7 +53,13 @@ eval_poly <- function(poly, newdata, monomials = FALSE) { aux <- preprocess_poly(poly) poly <- aux$poly - intercept_position <- aux$intercept_position + original_intercept_pos_for_reordering <- aux$intercept_position + + # Check if the *current first term* (after potential reordering by preprocess_poly) is the intercept + first_term_is_intercept <- FALSE + if (length(poly$labels) > 0 && length(poly$labels[[1]]) == 1 && poly$labels[[1]][1] == 0) { + first_term_is_intercept <- TRUE + } n_sample <- nrow(newdata) n_polynomials <- ncol(poly$values) @@ -70,16 +76,15 @@ eval_poly <- function(poly, newdata, monomials = FALSE) { # Select the desired polynomial values (column of poly$values) values_k <- poly$values[,k] - # If poly has no intercept if intercept_position is NULL - if (is.null(intercept_position)){ + if (first_term_is_intercept){ + # Initialize the vector with the intercept value repeated as needed. + response[,1,k] <- rep(values_k[1], nrow(newdata)) + start_loop <- 2 + } else { # Intercept (label = 0) should always be the first element of labels at this # point of the function (labels reordered previously in preprocess_poly). # initialize the vector with 0s repeated as needed. 
start_loop <- 1 - } else { - # Initialize the vector with the intercept value repeated as needed. - response[,1,k] <- rep(values_k[1], nrow(newdata)) - start_loop <- 2 } # Loop over all terms (labels) except the intercept @@ -100,7 +105,7 @@ eval_poly <- function(poly, newdata, monomials = FALSE) { # In case the intercept has been moved, we reorder it to its original # position so it preserves the original notation of the user. response[,,k] <- reorder_intercept_in_monomials(response[,,k], - intercept_position, + original_intercept_pos_for_reordering, n_sample) } @@ -175,39 +180,45 @@ preprocess_newdata <- function(newdata){ #' #' @noRd preprocess_poly <- function(poly){ - - # If values is a single vector, transform into matrix if (!is.matrix(poly$values)){ poly$values <- as.matrix(poly$values) } + intercept_position_original <- NULL # To return the original position - # In case there is no intercept, set a NULL value - intercept_position <- NULL - - # If there is intercept and it is not the first element, reorder the - # polynomial labels and values - if (c(0) %in% poly$labels){ + # Find if intercept c(0) exists + idx_intercept_in_list <- which(sapply(poly$labels, function(x) length(x)==1 && x[1]==0)) - intercept_position <- which(sapply(poly$labels, function(x) c(0) %in% x)) + if (length(idx_intercept_in_list) > 0) { # Intercept c(0) exists + intercept_position_original <- idx_intercept_in_list[1] # Take the first one if multiple - if (intercept_position != 1){ + if (intercept_position_original != 1) { # If it's not already first + # Store the intercept label and value + intercept_label_vec <- poly$labels[[intercept_position_original]] # Should be c(0) + intercept_value_row <- poly$values[intercept_position_original, , drop = FALSE] - # Store the value - intercept_value <- poly$values[intercept_position,] - - # Remove label and value - poly$labels <- poly$labels[-intercept_position] - poly$values <- poly$values[-intercept_position, , drop = FALSE] - - # Add label 
and value back at start of list - poly$labels <- append(poly$labels, c(0), after=0) - poly$values <- unname(rbind(intercept_value, poly$values)) + # Remove it from original position + poly$labels <- poly$labels[-intercept_position_original] + poly$values <- poly$values[-intercept_position_original, , drop = FALSE] + # Prepend it + poly$labels <- c(list(intercept_label_vec), poly$labels) # Correctly prepend list element + poly$values <- rbind(intercept_value_row, poly$values) } } + # At this point, if intercept c(0) existed, it's now the first element. + # If it didn't exist, poly is unchanged. + # The variable to return indicating original pos (for reorder_intercept_in_monomials) + # should reflect the original position, not just if it was moved. output <- list() - output$intercept_position <- intercept_position + # For reorder_intercept_in_monomials, we need to know if an intercept *was* present and its *original* slot + # The current `intercept_position` in `eval_poly` seems to be used to determine if an intercept *is now first*. + # Let's stick to your variable names. `intercept_position` will be used by `eval_poly` to check if the first term is an intercept. + # And by `reorder_intercept_in_monomials` to know where it *originally* was. + + current_first_is_intercept <- length(poly$labels) > 0 && length(poly$labels[[1]]) == 1 && poly$labels[[1]][1] == 0 + + output$intercept_position <- if(length(idx_intercept_in_list) > 0) idx_intercept_in_list[1] else NULL output$poly <- poly return(output) } diff --git a/R/nn2poly_methods.R b/R/nn2poly_methods.R index ef37616..fa18387 100644 --- a/R/nn2poly_methods.R +++ b/R/nn2poly_methods.R @@ -172,283 +172,1247 @@ predict.nn2poly <- function(object, #' Plot method for \code{nn2poly} objects. #' -#' A function that takes a polynomial (or several ones) as given by the -#' \pkg{nn2poly} algorithm, and then plots their absolute magnitude as barplots -#' to be able to compare the most important coefficients. 
+#' Provides various plots for \code{nn2poly} objects. #' -#' @param x A \code{nn2poly} object, as returned by the \pkg{nn2poly} algorithm. -#' @param ... Ignored. -#' @param n An integer denoting the number of coefficients to be plotted, -#' after ordering them by absolute magnitude. +#' @param x A \code{nn2poly} object. +#' @param type A string for plot type: "bar" (coefficient magnitudes), +#' "heatmap" (second-order terms), "local_contributions" (feature +#' contributions for a single observation), "beeswarm" (summary of term +#' contributions across observations), "interaction_surface", or "interaction_network". +#' @param ... Additional arguments passed to specific plot types. #' -#' @return A plot showing the \code{n} most important coefficients. +#' @param n For `type = "bar"`, the number of top coefficients to plot. #' -#' @details -#' The plot method represents only the polynomials at the final layer, even if -#' `x` is generated using `nn2poly()` with `keep_layers=TRUE`. -#' -#' @examples -#' # --- Single polynomial output --- -#' # Build a NN structure with random weights, with 2 (+ bias) inputs, -#' # 4 (+bias) neurons in the first hidden layer with "tanh" activation -#' # function, 4 (+bias) neurons in the second hidden layer with "softplus", -#' # and 2 "linear" output units -#' -#' weights_layer_1 <- matrix(rnorm(12), nrow = 3, ncol = 4) -#' weights_layer_2 <- matrix(rnorm(20), nrow = 5, ncol = 4) -#' weights_layer_3 <- matrix(rnorm(5), nrow = 5, ncol = 1) -#' -#' # Set it as a list with activation functions as names -#' nn_object = list("tanh" = weights_layer_1, -#' "softplus" = weights_layer_2, -#' "linear" = weights_layer_3) -#' -#' # Obtain the polynomial representation (order = 3) of that neural network -#' final_poly <- nn2poly(nn_object, max_order = 3) -#' -#' # Plot all the coefficients, one plot per output unit -#' plot(final_poly) -#' -#' # Plot only the 5 most important coeffcients (by absolute magnitude) -#' # one plot per output unit -#' plot(final_poly, n = 5) +#' @param 
newdata_monomials For `type = "local_contributions"` or `"beeswarm"`, +#' the output of `predict(x, newdata, monomials = TRUE)`. This should be +#' for a single observation for "local_contributions" (or specify `observation_index`) +#' and for multiple observations for "beeswarm". +#' @param observation_index For `type = "local_contributions"`, the row index +#' from `newdata_monomials` to plot (default: 1). +#' @param poly_output_index For `type = "local_contributions"` or `"beeswarm"`, +#' if `x` produces multiple polynomial outputs, which one to plot (default: 1). +#' @param variable_names For `type = "local_contributions"` or `"beeswarm"`, +#' an optional character vector of original feature names to make term labels +#' more readable (e.g., "x1" instead of "1"). +#' @param max_order_to_display For `type = "local_contributions"`, the maximum +#' term order to show in the stacked bars (default: 3). #' -#' # --- Multiple output polynomials --- -#' # Build a NN structure with random weights, with 2 (+ bias) inputs, -#' # 4 (+bias) neurons in the first hidden layer with "tanh" activation -#' # function, 4 (+bias) neurons in the second hidden layer with "softplus", -#' # and 2 "linear" output units -#' -#' weights_layer_1 <- matrix(rnorm(12), nrow = 3, ncol = 4) -#' weights_layer_2 <- matrix(rnorm(20), nrow = 5, ncol = 4) -#' weights_layer_3 <- matrix(rnorm(10), nrow = 5, ncol = 2) +#' @param original_feature_data For `type = "beeswarm"`, a matrix or data frame +#' of the original predictor values for all observations in `newdata_monomials`. +#' Required for coloring points. +#' @param top_n_terms For `type = "beeswarm"`, an optional integer to display only +#' the top N most important terms (based on mean absolute monomial value). +#' @param min_order For `type = "bar"`, the minimum order of terms to include. +#' 0 (default) includes intercept (order 0 for this purpose) and all terms. +#' 1 excludes intercept, showing terms of polynomial degree 1+. 
+#' 2 excludes intercept and linear terms, showing terms of polynomial degree 2+. +#' @param feature_pair For `type = "interaction_surface"`, a numeric or character +#' vector of length 2 specifying the pair of features to plot. +#' @param grid_resolution For `type = "interaction_surface"`, the number of points +#' per dimension for the grid. +#' @param interaction_order_network For `type = "interaction_network"`, the effective +#' order of interactions to visualize (e.g., 2 for pairwise feature interactions). +#' @param metric_network For `type = "interaction_network"`, the metric for edge +#' weights ("coefficient_abs" or "mean_monomial_abs"). +#' @param top_n_interactions For `type = "interaction_network"`, number of top +#' (projected) pairwise interactions to display by weight. +#' @param layout_network For `type = "interaction_network"`, the layout algorithm +#' for ggraph (e.g., "nicely", "fr"). #' -#' # Set it as a list with activation functions as names -#' nn_object = list("tanh" = weights_layer_1, -#' "softplus" = weights_layer_2, -#' "linear" = weights_layer_3) +#' @return A ggplot object. 
+#' @export +plot.nn2poly <- function(x, type = "bar", ..., + # Args for bar plot + n = NULL, + min_order = 0, # New arg for bar plot + # Args for local_contributions & beeswarm plots + newdata_monomials = NULL, + poly_output_index = 1, + variable_names = NULL, + # Args for local_contributions plot + observation_index = 1, + max_order_to_display = 3, + # Args for beeswarm plot + original_feature_data = NULL, # Still needed for beeswarm + top_n_terms = NULL, # For beeswarm y-axis + # Args for interaction_surface plot + feature_pair = NULL, + grid_resolution = 20, + # Args for interaction_network plot + interaction_order_network = 2, + metric_network = "coefficient_abs", # Defaulted to coefficient_abs + top_n_interactions = NULL, + layout_network = "nicely" +) { + + if (length(class(x)) > 1) { + return(NextMethod()) + } + if (!requireNamespace("ggplot2", quietly = TRUE)) { + stop("Package 'ggplot2' is required for this functionality.", call. = FALSE) + } + + plot_object <- NULL + + # --- Common pre-processing for coefficient-based plots ("bar", "heatmap") --- + if (type %in% c("bar", "heatmap")) { + poly_for_plot <- NULL + if (is.null(x$values) && !is.null(x[[length(x)]]) && !is.null(x[[length(x)]][["output"]])) { + poly_for_plot <- x[[length(x)]][["output"]] + } else if (!is.null(x$values)) { + poly_for_plot <- x + } else { + stop("Input 'x' is not a recognized nn2poly object for plot type '", type, "'.", call. = FALSE) + } + if (is.vector(poly_for_plot$values)) { + poly_for_plot$values <- matrix(poly_for_plot$values, ncol = 1) + } + if (is.null(poly_for_plot$labels)) { + stop("Input 'x' is missing polynomial labels for plot type '", type, "'.", call. 
= FALSE) + } + + if (type == "bar") { + if (!is.null(list(...)$n) && is.null(n)) n <- list(...)$n + plot_object <- plot_bar(poly_for_plot, n = n, variable_names = variable_names, min_order = min_order) + } else if (type == "heatmap") { + plot_object <- plot_heatmap(poly_for_plot, variable_names = variable_names) + } + } + # --- Local Contributions Plot --- + else if (type == "local_contributions") { + if (is.null(newdata_monomials)) stop("For 'local_contributions', 'newdata_monomials' must be provided.", call. = FALSE) + if (is.null(x$labels)) stop("Input 'x' (nn2poly object) is missing labels for 'local_contributions'.", call. = FALSE) + + pred_dims <- dim(newdata_monomials) + if (is.null(pred_dims) || length(pred_dims) < 2 || length(pred_dims) > 3) stop("'newdata_monomials' must be a 2D or 3D array.", call. = FALSE) + num_obs_pred <- pred_dims[1]; num_terms_pred <- pred_dims[2]; num_poly_outputs_pred <- if (length(pred_dims) == 3) pred_dims[3] else 1 + if (observation_index < 1 || observation_index > num_obs_pred) stop(paste0("'observation_index' out of bounds."), call. = FALSE) + if (poly_output_index < 1 || poly_output_index > num_poly_outputs_pred) stop(paste0("'poly_output_index' out of bounds."), call. = FALSE) + if (num_terms_pred != length(x$labels)) stop("Term count mismatch between 'newdata_monomials' and 'x$labels'.", call. = FALSE) + + monomial_values_slice <- if (num_poly_outputs_pred == 1 && length(pred_dims) == 2) { + newdata_monomials[observation_index, ] + } else { + newdata_monomials[observation_index, , poly_output_index] + } + plot_object <- plot_local_contributions_internal( + poly_obj = x, monomial_values_for_obs = monomial_values_slice, + variable_names = variable_names, max_order_to_display = max_order_to_display + ) + } + # --- Beeswarm Plot --- + else if (type == "beeswarm") { + if (is.null(newdata_monomials)) stop("For 'beeswarm', 'newdata_monomials' must be provided.", call. 
= FALSE) + if (is.null(original_feature_data)) stop("For 'beeswarm', 'original_feature_data' must be provided for coloring.", call. = FALSE) + if (is.null(x$labels)) stop("Input 'x' (nn2poly object) is missing labels for 'beeswarm'.", call. = FALSE) + + pred_dims <- dim(newdata_monomials) + if (length(pred_dims) < 2 || length(pred_dims) > 3) stop("'newdata_monomials' must be a 2D or 3D array.", call. = FALSE) + num_obs_pred <- pred_dims[1]; num_terms_pred <- pred_dims[2]; num_poly_outputs_pred <- if (length(pred_dims) == 3) pred_dims[3] else 1 + + # Ensure original_feature_data is a matrix for consistent indexing + if (!is.matrix(original_feature_data)) { + original_feature_data_mat <- as.matrix(original_feature_data) + if(!is.numeric(original_feature_data_mat) && !is.logical(original_feature_data_mat)) { # Allow logicals to become 0/1 + stop("'original_feature_data' cannot be coerced to a numeric/logical matrix.", call. = FALSE) + } + } else { + original_feature_data_mat <- original_feature_data + } + + if (nrow(original_feature_data_mat) != num_obs_pred) stop("Row count mismatch: 'original_feature_data' and 'newdata_monomials'.", call. = FALSE) + if (poly_output_index < 1 || poly_output_index > num_poly_outputs_pred) stop(paste0("'poly_output_index' out of bounds."), call. = FALSE) + if (num_terms_pred != length(x$labels)) stop("Term count mismatch: 'newdata_monomials' and 'x$labels'.", call. 
= FALSE) + + monomial_values_for_plot <- if (num_poly_outputs_pred == 1 && length(pred_dims) == 2) { + newdata_monomials + } else { + newdata_monomials[, , poly_output_index] + } + + plot_object <- plot_beeswarm_internal( + poly_obj = x, all_monomial_values_for_output = monomial_values_for_plot, + original_feature_data = original_feature_data_mat, # ensure matrix form is passed + variable_names = variable_names, top_n_terms = top_n_terms + ) + } + + # --- Interaction Surface Plot --- + else if (type == "interaction_surface") { + if (is.null(feature_pair)) stop("For 'interaction_surface', 'feature_pair' must be provided.", call. = FALSE) + if (is.null(original_feature_data)) stop("For 'interaction_surface', 'original_feature_data' must be provided.", call. = FALSE) # original_feature_data for surface + if (is.null(x$labels) || is.null(x$values)) stop("Input 'x' is not a valid nn2poly object for 'interaction_surface'.", call. = FALSE) + + current_original_feature_data <- original_feature_data # Use the one passed for this plot type + plot_object <- plot_interaction_surface_internal( + poly_obj = x, feature_pair = feature_pair, + original_feature_data = current_original_feature_data, + grid_resolution = grid_resolution, variable_names = variable_names, + poly_output_index = poly_output_index + ) + } + # --- Interaction Network Plot --- + else if (type == "interaction_network") { + if (is.null(x$labels) || is.null(x$values)) stop("Input 'x' is not a valid nn2poly object for 'interaction_network'.", call. = FALSE) + + current_newdata_monomials <- newdata_monomials + if (metric_network == "mean_monomial_abs" && is.null(current_newdata_monomials)) { + stop("If metric_network is 'mean_monomial_abs', 'newdata_monomials' must be provided for 'interaction_network'.", call. 
= FALSE) + } + + plot_object <- plot_interaction_network_internal( + poly_obj = x, interaction_order = interaction_order_network, + metric = metric_network, newdata_monomials = current_newdata_monomials, + top_n_interactions = top_n_interactions, variable_names = variable_names, + poly_output_index = poly_output_index, layout = layout_network + ) + } else { + stop(paste0("Unknown plot type: '", type, "'. Available types are 'bar', 'heatmap', 'local_contributions', 'beeswarm', 'interaction_surface', 'interaction_network'."), call. = FALSE) + } + return(plot_object) +} + + +#' Format a term label for display #' -#' # Obtain the polynomial representation (order = 3) of that neural network -#' final_poly <- nn2poly(nn_object, max_order = 3) +#' @param term_label_vec A numeric vector representing the term, e.g., c(1), c(1,2), c(1,1). +#' @param variable_names Optional character vector of feature names. +#' @param use_product_format_for_named Logical. If TRUE and variable_names are provided, +#' use product format (e.g., "NameA*NameB"). Otherwise (or if variable_names is NULL), +#' uses comma-separated numeric indices (e.g., "1,2"). +#' @param use_product_format_for_numeric Logical. If TRUE and variable_names are NULL, +#' use product format for numeric indices (e.g., "1*2"). Defaults to FALSE (uses "1,2"). +#' @return A string representation of the term. 
+#' @noRd +format_term_label_display <- function(term_label_vec, + variable_names = NULL, + use_product_format_for_named = FALSE, + use_product_format_for_numeric = FALSE) { + if (length(term_label_vec) == 1 && term_label_vec[1] == 0) { + return("0") + } + if (is.null(term_label_vec) || length(term_label_vec) == 0) return("NA_term") + + parts <- character(length(term_label_vec)) + use_product_format_this_term <- FALSE + + if (!is.null(variable_names)) { + for (i in seq_along(term_label_vec)) { + idx <- term_label_vec[i] + if (idx > 0 && idx <= length(variable_names)) { + parts[i] <- variable_names[idx] + } else if (idx > 0) { # Index out of bounds for names, or 0 + parts[i] <- as.character(idx) # Fallback to numeric + } else { + parts[i] <- as.character(idx) # e.g. if a 0 slips through + } + } + if (use_product_format_for_named) { + use_product_format_this_term <- TRUE + } + } else { # No variable_names provided + parts <- as.character(term_label_vec) + if (use_product_format_for_numeric) { + use_product_format_this_term <- TRUE + } + } + + if (use_product_format_this_term) { + # For product format, it's common to sort parts for canonical representation, e.g., x1*x2 not x2*x1 + # However, c(1,1,2) should be x1*x1*x2, not x1*x2*x1. So simple sort is not enough. + # For now, just join them as they appear in the label. + # A more sophisticated approach might count and use powers, e.g., x1^2*x2 + return(paste(parts, collapse = "*")) + } else { + return(paste(parts, collapse = ",")) + } +} + +#' Internal function for plot.nn2poly bar type. #' -#' # Plot all the coefficients, one plot per output unit -#' plot(final_poly) +#' @param poly_obj A nn2poly object (or the filtered subset for plotting). +#' @param n Number of top coefficients to plot per polynomial output. +#' @param variable_names Optional character vector of feature names. +#' @param min_order Minimum order of terms to include (0=all including intercept, 1=degree 1+, etc.). +#' @return A ggplot object. 
+#' @noRd +plot_bar <- function(poly_obj, n = NULL, variable_names = NULL, min_order = 0) { + + # --- 1. Filter terms based on min_order --- + labels_to_keep <- list() + original_indices_to_keep <- integer(0) + + # Determine if poly_obj$values is a matrix (multiple outputs) or vector (single output) + # poly_obj$values should be terms x polynomial_outputs + is_multi_output <- !is.null(dim(poly_obj$values)) && ncol(poly_obj$values) > 1 + num_outputs_in_obj <- if (is.null(dim(poly_obj$values))) 1 else ncol(poly_obj$values) + + for (i in seq_along(poly_obj$labels)) { + term_lab <- poly_obj$labels[[i]] + + current_term_effective_order <- 0 # Default for intercept + if (length(term_lab) == 1 && term_lab[1] == 0) { + current_term_effective_order <- 0 # Intercept is order 0 + } else if (length(term_lab) > 0 && all(term_lab > 0)) { # All positive indices (variables) + current_term_effective_order <- length(term_lab) # Polynomial degree of the term + } else { + # Mixed term or unexpected label structure, assign high order to filter out unless min_order is very low + # Or, decide how to handle terms like c(0,1) if they were possible. + # For now, assume labels are either c(0) or vectors of positive integers. + # If not, this might need adjustment based on what format_term_label_display expects. 
+ current_term_effective_order <- length(term_lab) # Fallback, usually non-zero if not intercept + } + + if (current_term_effective_order >= min_order) { + labels_to_keep[[length(labels_to_keep) + 1]] <- term_lab + original_indices_to_keep <- c(original_indices_to_keep, i) + } + } + + if (length(labels_to_keep) == 0) { + return(ggplot2::ggplot() + + ggplot2::annotate("text", x = 0.5, y = 0.5, label = paste0("No terms found with order >= ", min_order, ".")) + + ggplot2::theme_minimal() + + ggplot2::ggtitle(paste0("Coefficients (Order >= ", min_order, ")"))) + } + + poly_obj_filtered <- list() + poly_obj_filtered$labels <- labels_to_keep + if (is_multi_output) { + poly_obj_filtered$values <- poly_obj$values[original_indices_to_keep, , drop = FALSE] + } else { # Single output case (poly_obj$values might have been a vector or 1-col matrix originally) + values_subset <- if(is.null(dim(poly_obj$values))) poly_obj$values[original_indices_to_keep] else poly_obj$values[original_indices_to_keep, 1] + poly_obj_filtered$values <- matrix(values_subset, ncol = 1) # Ensure it's a matrix + } + + # --- 2. Prepare data frame for plotting --- + # If no number of top coefficients (n) is provided, use all filtered coefficients. 
+ if (is.null(n) || n > nrow(poly_obj_filtered$values)) { + n <- nrow(poly_obj_filtered$values) + } + if (n == 0) { # Case where n becomes 0 after filtering if nrow(poly_obj_filtered$values) was 0 + return(ggplot2::ggplot() + + ggplot2::annotate("text", x = 0.5, y = 0.5, label = paste0("No terms to plot after filtering (n=0).")) + + ggplot2::theme_minimal() + + ggplot2::ggtitle(paste0("Coefficients (Order >= ", min_order, ")"))) + } + + + # poly_obj_filtered$values is now terms x polynomial_outputs + coefficients_matrix <- poly_obj_filtered$values + all_filtered_labels <- poly_obj_filtered$labels + n_polys <- ncol(coefficients_matrix) # Number of polynomial outputs to plot + + all_df_list <- vector("list", n_polys) + + for (r in 1:n_polys) { + poly_coeffs_for_output_r <- coefficients_matrix[, r] + + # Sort by absolute magnitude to get top n coefficients for this polynomial output + # Handle cases where n is larger than available coefficients (already done by n re-assignment above for nrow) + # Ensure consistent tie-breaking by original index if abs values are same + order_indices <- order(abs(poly_coeffs_for_output_r), seq_along(poly_coeffs_for_output_r), decreasing = TRUE) + + top_n_actual_indices <- order_indices[1:n] # Indices within poly_coeffs_for_output_r + + top_n_values_abs <- abs(poly_coeffs_for_output_r[top_n_actual_indices]) + top_n_signs <- sign(poly_coeffs_for_output_r[top_n_actual_indices]) + + # Handle exact zero coefficients' sign (map to positive) + top_n_signs[top_n_signs == 0] <- 1 + + list_labels_for_top_n <- all_filtered_labels[top_n_actual_indices] + + string_labels_for_top_n <- sapply(list_labels_for_top_n, function(lab) { + format_term_label_display(lab, variable_names, + use_product_format_for_named = FALSE, + use_product_format_for_numeric = FALSE) + }) + + df_poly_r <- data.frame( + term_name = string_labels_for_top_n, + sign = as.factor(top_n_signs), + value = top_n_values_abs, + poly_output_id = r, # To identify polynomial if faceting + 
stringsAsFactors = FALSE + ) + + # For reordering bars within facets (important if coord_flip is used) + # Order by value descending, so when flipped, largest is at top. + df_poly_r$term_display_order <- factor(df_poly_r$term_name, + levels = df_poly_r$term_name[order(df_poly_r$value, decreasing = TRUE)]) + + all_df_list[[r]] <- df_poly_r + } + + all_plot_df <- do.call(rbind, all_df_list) + + if (nrow(all_plot_df) == 0) { + return(ggplot2::ggplot() + + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "No coefficients to plot after processing.") + + ggplot2::theme_minimal() + + ggplot2::ggtitle(paste0("Coefficients (Order >= ", min_order, ")"))) + } + + # Ensure poly_output_id is a factor for faceting title with meaningful labels + all_plot_df$poly_facet_label <- factor(paste("Output Poly.", all_plot_df$poly_output_id)) + + # --- 3. Generate Plot --- + plot_title_text <- if (min_order == 0) "Most Important Coefficients (All Terms)" + else if (min_order == 1) "Most Important Coefficients (Order >= 1)" + else paste0("Most Important Interaction Coefficients (Order >= ", min_order, ")") + + p <- ggplot2::ggplot(all_plot_df, + ggplot2::aes(x = .data$term_display_order, # Use the reordered factor + y = .data$value, + fill = .data$sign)) + + ggplot2::geom_bar(stat = "identity", colour = "black", alpha = 1, width=0.7) + + + ggplot2::scale_fill_manual( + name = "Sign", + values = c("1" = "#00BA38", "-1" = "#F8766D"), # Your green/red colors + labels = c("1" = "+", "-1" = "-"), + drop = FALSE # Show all legend items even if one sign not present + ) + + ggplot2::labs(title = plot_title_text, + x = "Coefficient (absolute value)", + y = "Polynomial Term") + + + p <- p + ggplot2::theme_minimal() + + ggplot2::theme( + axis.text.x = ggplot2::element_text(size=10), + axis.text.y = ggplot2::element_text(size=10), + legend.position = "top", + legend.direction = "horizontal", + plot.title = ggplot2::element_text(hjust = 0.5, face="bold"), + panel.grid.major.y = 
ggplot2::element_blank(), + panel.grid.minor.y = ggplot2::element_blank(), + strip.background = ggplot2::element_rect(fill="grey90", linetype="blank"), + strip.text = ggplot2::element_text(face="bold") + ) + + + if (n_polys > 1) { + # For coord_flip, scales = "free_y" means different terms can appear per facet + # scales = "free_x" means y-axis (now value) can have different scales + p <- p + ggplot2::facet_wrap(~poly_facet_label, scales = "free_x") + } + + return(p) +} + + +#' Internal function for plot.nn2poly heatmap type. #' -#' # Plot only the 5 most important coeffcients (by absolute magnitude) -#' # one plot per output unit -#' plot(final_poly, n = 5) +#' @inheritParams plot.nn2poly #' -#' @export -plot.nn2poly <- function(x, ..., n=NULL) { - if (length(class(x)) > 1) - return(NextMethod()) +#' @return +#' @noRd +plot_heatmap <- function(poly_obj, variable_names = NULL) { # Changed x to poly_obj - if (!requireNamespace("ggplot2", quietly = TRUE)) { - stop("package 'ggplot2' is required for this functionality", call. = FALSE) + coefficients_matrix <- poly_obj$values # Matrix: terms x n_polynomial_outputs + all_labels_list <- poly_obj$labels # List of label vectors for each term + + # Filter for second-order term labels (length == 2, and positive variable indices) + is_second_order <- sapply(all_labels_list, function(lab) { + !is.null(lab) && length(lab) == 2 && all(lab > 0) && is.numeric(lab) + }) + + if (!any(is_second_order)) { + return(ggplot2::ggplot() + + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "No second-order terms with positive variable indices found.") + + ggplot2::theme_minimal() + + ggplot2::ggtitle("Second-Order Coefficients Heatmap")) } - if (!requireNamespace("patchwork", quietly = TRUE)) { - stop("package 'patchwork' is required for this functionality", call. 
= FALSE) + second_order_labels_subset <- all_labels_list[is_second_order] + second_order_values_matrix <- coefficients_matrix[is_second_order, , drop = FALSE] + + # Determine var_names for axes + present_vars <- unique(unlist(second_order_labels_subset)) + if (length(present_vars) == 0) { + return(ggplot2::ggplot() + + ggplot2::annotate("text", x=0.5, y=0.5, label="No variables found in second-order terms.") + + ggplot2::theme_minimal() + + ggplot2::ggtitle("Second-Order Coefficients Heatmap")) } + max_var_idx <- max(present_vars) - # a special case is needed for the case in which the polynomial was generated - # with `keep_layers = TRUE` + if(max_var_idx == 0){ # Should not happen if all(lab > 0) filter is effective + return(ggplot2::ggplot() + + ggplot2::annotate("text", x=0.5, y=0.5, label="Max variable index is 0.") + + ggplot2::theme_minimal() + + ggplot2::ggtitle("Second-Order Coefficients Heatmap")) + } - if (is.null(x$values)) { - x <- x[[length(x)]][["output"]] + axis_var_labels <- character(max_var_idx) + if (!is.null(variable_names)) { + for (i in 1:max_var_idx) { + if (i <= length(variable_names)) { + axis_var_labels[i] <- variable_names[i] + } else { + axis_var_labels[i] <- as.character(i) # Fallback + } + } + } else { + axis_var_labels <- as.character(1:max_var_idx) # Default: "1", "2", ... 
} - # Check if x$values is a vector and transform it into a column matrix - if (is.vector(x$values)){ - x$values <- matrix(x$values, ncol = 1) + n_polys <- ncol(second_order_values_matrix) + + all_heatmap_df_list <- vector("list", n_polys) + + for (p_idx in 1:n_polys) { + current_poly_coeffs <- second_order_values_matrix[, p_idx] + + heatmap_matrix <- matrix(0, nrow = max_var_idx, ncol = max_var_idx, + dimnames = list(axis_var_labels, axis_var_labels)) + + for (i in seq_along(second_order_labels_subset)) { + label <- second_order_labels_subset[[i]] # label is c(var1_idx, var2_idx) + value <- current_poly_coeffs[i] + + # Ensure indices are within bounds (should be, by construction of max_var_idx) + if (all(label <= max_var_idx)) { + if (label[1] == label[2]) { # Squared term: x_i^2 + heatmap_matrix[label[1], label[1]] <- value + } else { # Interaction term: x_i*x_j + heatmap_matrix[label[1], label[2]] <- value + heatmap_matrix[label[2], label[1]] <- value # Symmetric + } + } + } + + heatmap_df_poly <- as.data.frame(as.table(heatmap_matrix)) + colnames(heatmap_df_poly) <- c("Var1", "Var2", "Value") + # Ensure Var1 and Var2 are factors with levels in the correct order + heatmap_df_poly$Var1 <- factor(heatmap_df_poly$Var1, levels = axis_var_labels) + heatmap_df_poly$Var2 <- factor(heatmap_df_poly$Var2, levels = rev(axis_var_labels)) + + heatmap_df_poly$poly_index <- p_idx + all_heatmap_df_list[[p_idx]] <- heatmap_df_poly } - if (is.null(n)) { - n <- dim(x$values)[1] + combined_heatmap_df <- do.call(rbind, all_heatmap_df_list) + combined_heatmap_df$poly_facet_label <- factor(paste("Output Polynomial", combined_heatmap_df$poly_index)) + + # Generate the heatmap plot + final_heatmap_plot <- ggplot2::ggplot(combined_heatmap_df, ggplot2::aes(x = .data$Var1, y = .data$Var2, fill = .data$Value)) + + ggplot2::geom_tile(color = "white", na.rm = TRUE) + # na.rm in case some interactions are missing + ggplot2::scale_fill_gradient2(low = "#F8766D", mid = "white", high = "#00BA38", 
# User's colors + midpoint = 0, name = "Coefficient", na.value = "grey90") + + ggplot2::theme_minimal(base_size = 10) + + ggplot2::labs(x = "Variable", y = "Variable") + + ggplot2::theme( + axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, vjust = 1), + axis.text.y = ggplot2::element_text(), # Default angle + panel.grid = ggplot2::element_blank(), + legend.position = "bottom", + plot.title = ggplot2::element_text(hjust = 0.5, size = 14), + strip.background = ggplot2::element_rect(fill="grey90", linetype="blank"), + strip.text = ggplot2::element_text(face="bold") + ) + + ggplot2::coord_fixed() # Ensures tiles are square + + if (n_polys > 1) { + final_heatmap_plot <- final_heatmap_plot + ggplot2::facet_wrap(~poly_facet_label) + } else { + final_heatmap_plot <- final_heatmap_plot + ggplot2::ggtitle("Second-Order Coefficients Heatmap") } - # Transpose values to be polynomials as rows instead of columns - # Needed to work as in previous nn2poly output format - M <- t(x$values) - all_labels <- x$labels - n_polys <- nrow(M) + return(final_heatmap_plot) +} - all_df <- data.frame() +# Helper to get max variable index (p) from a list of labels +get_max_var_index_from_labels <- function(labels_list) { + all_vars <- unlist(labels_list) + all_vars <- all_vars[all_vars > 0] # Exclude 0 (intercept) and any non-positive indices + if (length(all_vars) == 0) return(0) + return(max(all_vars, na.rm = TRUE)) +} - for (r in 1:n_polys) { - Mr <- M[r, ] - aux_total <- sort(abs(Mr), decreasing = TRUE, index.return = TRUE) - aux_values <- aux_total$x[1:n] - aux_index <- aux_total$ix[1:n] +#' Internal function for plot.nn2poly local_contributions type. 
+#' @noRd +plot_local_contributions_internal <- function(poly_obj, + monomial_values_for_obs, # Vector for one obs, one output + variable_names = NULL, + max_order_to_display = 3) { + + # --- Input Validation --- + if (!is.list(poly_obj) || is.null(poly_obj$labels)) { + stop("'poly_obj' must be a valid nn2poly object with labels.") + } + if (!is.numeric(monomial_values_for_obs) || !is.vector(monomial_values_for_obs)) { + stop("'monomial_values_for_obs' must be a numeric vector.") + } + if (length(poly_obj$labels) != length(monomial_values_for_obs)) { + stop("Length of 'poly_obj$labels' and 'monomial_values_for_obs' must match.") + } - # Obtain labels of chosen coefficients: - list_labels <- all_labels[aux_index] + # --- Calculate Contributions --- + contributions_list <- list() + for (i in seq_along(poly_obj$labels)) { + term_label <- poly_obj$labels[[i]] # e.g., c(1), c(1,2), c(1,1,2) + term_value_for_obs <- monomial_values_for_obs[i] - string_labels <- rep("0", n) - for (i in 1:n) { - # Create the label as a string of the form "l_1 l_2 ... 
l_t" - string_labels[i] <- paste(as.character(list_labels[[i]]), collapse = ",") + # Skip if monomial value is effectively zero + if (abs(term_value_for_obs) < .Machine$double.eps^0.75) { # More robust check for zero + next } - aux_sign <- sign(Mr)[aux_index] + # Skip intercept term (label c(0)) from feature attribution + if (length(term_label) == 1 && term_label[1] == 0) { + next + } - df <- data.frame( - name = string_labels, - sign = as.factor(aux_sign), - value = aux_values, - type = r - ) + current_term_order <- length(term_label) # Order of the term + + # Skip if term order is 0 (shouldn't happen if intercept is skipped) or exceeds display limit + if (current_term_order == 0 || current_term_order > max_order_to_display) { + next + } + + # Filter for actual variable indices (positive integers) within the term label + vars_in_term_label <- term_label[term_label > 0] - all_df <- rbind(all_df, df) + if (length(vars_in_term_label) == 0) { # No actual variables in this term + next + } + + # Total number of variable occurrences in the term (e.g., for c(1,1,2), this is 3) + total_var_occurrences_in_term <- length(vars_in_term_label) + + # Count occurrences of each unique variable in this specific term's label + # For c(1,1,2): table gives "1" -> 2, "2" -> 1 + var_counts_in_term <- table(vars_in_term_label) + + # Distribute the term's value proportionally to variable counts + for (var_idx_char in names(var_counts_in_term)) { + var_idx <- as.integer(var_idx_char) + count_this_var <- var_counts_in_term[[var_idx_char]] # How many times this var_idx appears in vars_in_term_label + + # Proportion for this variable in this term + proportion_for_this_var <- count_this_var / total_var_occurrences_in_term + + attributed_value <- term_value_for_obs * proportion_for_this_var + + contributions_list[[length(contributions_list) + 1]] <- + data.frame( + variable_idx = var_idx, + term_order_num = current_term_order, # Overall order of the term + contribution = attributed_value, + 
stringsAsFactors = FALSE + ) + } } - # If a coefficient is exactly 0, assign it to positive - if (any(all_df$sign == 0)){ - all_df$sign[which(all_df$sign==0)] = 1 + + if (length(contributions_list) == 0) { + return(ggplot2::ggplot() + + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "No feature contributions to display for this observation (or selected orders).") + + ggplot2::theme_minimal() + ggplot2::ggtitle("Local Feature Contributions")) } + plot_df <- do.call(rbind, contributions_list) + + # Aggregate contributions: sum up all contributions for each variable_idx and term_order_num + # E.g., if var 1 gets contribution from term c(1) [1st order] and from term c(1,1,2) [3rd order] + plot_df_agg <- aggregate(contribution ~ variable_idx + term_order_num, data = plot_df, FUN = sum) + # Filter out effectively zero aggregated contributions + plot_df_agg <- plot_df_agg[abs(plot_df_agg$contribution) > .Machine$double.eps^0.75, ] - # Define different scale for multiple or single sign cases. 
- if (all(levels(all_df$sign) == c("-1", "1"))){ - scale_values <- c("#F8766D", "#00BA38") - scale_labels <- c("-", "+") - } else if (levels(all_df$sign) == c("1")) { - scale_values <- c("#00BA38") - scale_labels <- c("+") - } else if (levels(all_df$sign) == c("-1")) { - scale_values <- c("#F8766D") - scale_labels <- c("-") + if (nrow(plot_df_agg) == 0) { + return(ggplot2::ggplot() + + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "All aggregated feature contributions are negligible.") + + ggplot2::theme_minimal() + ggplot2::ggtitle("Local Feature Contributions")) } - # inspired by tidytext::reorder_within - new_x <- do.call(paste, c(list(all_df$name, sep = "___"), list(all_df$type))) - reorder_aux <- stats::reorder(new_x, all_df$value, FUN = mean, decreasing = TRUE) + # --- Prepare for Plotting --- + max_present_order <- min(max(plot_df_agg$term_order_num, na.rm = TRUE), max_order_to_display) - # inspired by tidytext::scale_x_reordered and tidtytext::reorder_func - reorder_func <- function(x, sep = "___") { - reg <- paste0(sep, ".+$") - gsub(reg, "", x) + if(is.infinite(max_present_order) || max_present_order < 1) { + return(ggplot2::ggplot() + + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "No contributions to display for the selected orders after aggregation.") + + ggplot2::theme_minimal() + ggplot2::ggtitle("Local Feature Contributions")) } + order_suffix <- function(k) { + if (k %% 10 == 1 && k %% 100 != 11) return("st") + if (k %% 10 == 2 && k %% 100 != 12) return("nd") + if (k %% 10 == 3 && k %% 100 != 13) return("rd") + return("th") + } + order_labels_vec <- sapply(1:max_present_order, function(o) paste0(o, order_suffix(o), " order")) + + plot_df_agg$term_order_str <- factor( + plot_df_agg$term_order_num, + levels = 1:max_present_order, + labels = order_labels_vec[1:max_present_order] # Ensure correct length of labels + ) - plot_all <- ggplot2::ggplot(all_df, - ggplot2::aes(x = reorder_aux, - y = .data$value, - fill = .data$sign)) + - 
ggplot2::geom_bar(stat = "identity", colour = "black", alpha = 1) + - ggplot2::scale_x_discrete(labels = reorder_func) + # Variable names for x-axis + p_model <- get_max_var_index_from_labels(poly_obj$labels) + p_data <- if(nrow(plot_df_agg) > 0) max(plot_df_agg$variable_idx, na.rm = TRUE) else 0 + p <- max(c(0, p_model, p_data), na.rm = TRUE) + if (p == 0 && nrow(plot_df_agg) > 0) { p <- max(plot_df_agg$variable_idx, na.rm = TRUE) } + if (p == 0) { + return(ggplot2::ggplot() + ggplot2::annotate("text", x=0.5, y=0.5, label="No variables found in contributions.") + + ggplot2::theme_minimal() + ggplot2::ggtitle("Local Feature Contributions")) + } - if (n_polys >1){ - plot_all <- plot_all + ggplot2::facet_wrap(~type, scales = "free_x") + current_axis_var_labels <- character(p) # Full list of potential labels + if (!is.null(variable_names)) { + for (i in 1:p) { + if (i <= length(variable_names)) { + current_axis_var_labels[i] <- variable_names[i] + } else { + current_axis_var_labels[i] <- as.character(i) # Fallback + } + } + } else { + current_axis_var_labels <- as.character(1:p) # Default: "1", "2", ... 
} - plot_all <- plot_all + - cowplot::theme_half_open() + - ggplot2::labs(y = "Coefficient (absolute) values", x = "Variables or interactions") + - ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90, vjust = 0.5, hjust = 1)) + - ggplot2::scale_fill_manual(values = scale_values, labels = scale_labels) + - ggplot2::theme(legend.direction = "horizontal") + - ggplot2::labs(fill = "Sign") + unique_var_indices_in_data <- sort(unique(plot_df_agg$variable_idx)) + plot_df_agg$variable_label <- factor(plot_df_agg$variable_idx, + levels = unique_var_indices_in_data, + labels = current_axis_var_labels[unique_var_indices_in_data]) + # --- Generate Plot --- + final_plot <- ggplot2::ggplot(plot_df_agg, + ggplot2::aes(x = .data$variable_label, + y = .data$contribution, + fill = .data$term_order_str)) + + ggplot2::geom_col(position = "stack", width = 0.7, na.rm = TRUE) + # na.rm for safety + ggplot2::geom_hline(yintercept = 0, linetype = "solid", color = "black") + + # ggplot2::scale_fill_manual(values = active_colors, name = "Term Order", drop = FALSE) + # drop=FALSE ensures legend consistency + ggplot2::labs(title = "Local Feature Contributions", + x = "Feature", + y = "Contribution to Prediction") + + ggplot2::theme_minimal(base_size = 11) + + ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, vjust = 1), + legend.position = "top", + plot.title = ggplot2::element_text(hjust = 0.5)) - return(plot_all) + return(final_plot) } -#' Heatmap for Second-Order Polynomial Terms -#' -#' This function generates a heatmap to visualize second-order terms in the polynomial representation of a neural network. -#' It displays squared terms along the diagonal and pairwise interactions off-diagonal, with colors indicating the magnitude and sign of the coefficients. -#' -#' @param x A `nn2poly` object, as returned by the `nn2poly` algorithm. -#' @param ... Additional arguments (unused). -#' @param max_order Integer, defaults to 2. 
This function currently supports only second-order terms. -#' -#' @return A ggplot object showing the heatmap of second-order terms. -#' -#' @details -#' Coefficients are displayed as a gradient from red (negative) to green (positive), with white indicating zero. -#' The diagonal contains squared terms (e.g., \eqn{x_1^2}), while off-diagonal entries represent pairwise interactions (e.g., \eqn{x_1x_2}). -#' -#' @examples -#' # Example: Single output polynomial with 20 variables -#' set.seed(42) -#' weights_layer_1 <- matrix(rnorm(21 * 10), nrow = 21, ncol = 10) -#' weights_layer_2 <- matrix(rnorm(11 * 1), nrow = 11, ncol = 1) -#' -#' nn_object <- list("softplus" = weights_layer_1, "linear" = weights_layer_2) -#' -#' # Generate the polynomial representation -#' final_poly <- nn2poly(nn_object, max_order = 2) -#' -#' # Plot the heatmap for second-order terms -#' heatmap.nn2poly(final_poly) -#' -#' -heatmap.nn2poly <- function(x, ..., max_order = 2) { - if (!requireNamespace("ggplot2", quietly = TRUE)) { - stop("package 'ggplot2' is required for this functionality", call. = FALSE) +#' Internal function for plot.nn2poly beeswarm type. +#' @noRd +plot_beeswarm_internal <- function(poly_obj, + all_monomial_values_for_output, # 2D matrix: observations x terms + original_feature_data, # Full matrix/data frame of original predictors + variable_names = NULL, + top_n_terms = NULL) { + + if (!requireNamespace("ggbeeswarm", quietly = TRUE)) { + stop("Package 'ggbeeswarm' is required for the 'beeswarm' plot type.\n", + "Please install it using: install.packages('ggbeeswarm')", call. = FALSE) } - if (max_order != 2) { - stop("This function currently supports only second-order terms (max_order = 2).", call. 
= FALSE) + # --- Input Validation & Basic Setup --- + if (!is.list(poly_obj) || is.null(poly_obj$labels)) { + stop("'poly_obj' must be a valid nn2poly object with labels.") + } + if (!is.matrix(all_monomial_values_for_output) || ncol(all_monomial_values_for_output) != length(poly_obj$labels)) { + stop("'all_monomial_values_for_output' must be a matrix with columns matching poly_obj$labels.") } + num_obs_mono <- nrow(all_monomial_values_for_output) - # Ensure the object is an nn2poly object - if (length(class(x)) > 1) { - return(NextMethod()) + # Ensure original_feature_data is a matrix + if (!is.matrix(original_feature_data)) { + original_feature_data <- as.matrix(original_feature_data) + if (!is.numeric(original_feature_data)) { + stop("'original_feature_data' could not be coerced to a numeric matrix.") + } + } + num_obs_orig <- nrow(original_feature_data) + num_features_orig <- ncol(original_feature_data) + + if (num_obs_mono != num_obs_orig) { + stop("Number of observations in 'all_monomial_values_for_output' (", num_obs_mono, + ") must match 'original_feature_data' (", num_obs_orig, ").") } - # Extract polynomial coefficients and labels - coefficients <- x$values - labels <- x$labels + term_indices_to_plot <- 1:length(poly_obj$labels) + term_labels_raw <- poly_obj$labels - if (is.null(coefficients) || is.null(labels)) { - stop("Invalid nn2poly object: missing coefficients or labels.", call. = FALSE) + # --- Exclude Intercept --- + intercept_idx <- which(sapply(term_labels_raw, function(lab) length(lab) == 1 && lab[1] == 0)) + if (length(intercept_idx) > 0) { + term_indices_to_plot <- setdiff(term_indices_to_plot, intercept_idx) + if (length(term_indices_to_plot) == 0) { + return(ggplot2::ggplot() + + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "Only an intercept term found. 
No beeswarm plot to generate.") + + ggplot2::theme_minimal() + ggplot2::ggtitle("Term Contribution Beeswarm Plot")) + } + } + current_labels_raw_no_intercept <- term_labels_raw[term_indices_to_plot] + current_monomial_values_no_intercept <- all_monomial_values_for_output[, term_indices_to_plot, drop = FALSE] + + if (ncol(current_monomial_values_no_intercept) == 0) { + return(ggplot2::ggplot() + ggplot2::annotate("text",x=0.5,y=0.5,label="No non-intercept terms to plot.") + + ggplot2::theme_minimal() + ggplot2::ggtitle("Term Contribution Beeswarm Plot")) } - if (is.vector(coefficients)) { - coefficients <- matrix(coefficients, ncol = 1) + # --- Term Importance and Selection --- + mean_abs_monomial_vals_all <- colMeans(abs(current_monomial_values_no_intercept), na.rm = TRUE) + + if (!is.null(top_n_terms) && top_n_terms > 0 && top_n_terms < ncol(current_monomial_values_no_intercept)) { + term_order_indices <- order(mean_abs_monomial_vals_all, decreasing = TRUE) + selected_indices_in_no_intercept <- term_order_indices[1:top_n_terms] + + final_labels_to_plot <- current_labels_raw_no_intercept[selected_indices_in_no_intercept] + final_monomial_values_to_plot <- current_monomial_values_no_intercept[, selected_indices_in_no_intercept, drop = FALSE] + mean_abs_for_selected_terms <- mean_abs_monomial_vals_all[selected_indices_in_no_intercept] + } else { + final_labels_to_plot <- current_labels_raw_no_intercept + final_monomial_values_to_plot <- current_monomial_values_no_intercept + mean_abs_for_selected_terms <- mean_abs_monomial_vals_all } - # Prepare data for second-order terms - second_order_indices <- which(sapply(labels, function(label) length(label) <= max_order)) - second_order_labels <- labels[second_order_indices] - second_order_values <- coefficients[second_order_indices, , drop = FALSE] + y_axis_order_indices <- order(mean_abs_for_selected_terms, decreasing = FALSE) + + # --- Prepare Data for Plotting (Long Format) --- + plot_df_list <- list() + 
term_strings_for_plot <- character(length(final_labels_to_plot)) - # Create a matrix for the heatmap - max_vars <- max(unlist(second_order_labels)) - heatmap_matrix <- matrix(0, nrow = max_vars, ncol = max_vars, dimnames = list(paste0("x", 1:max_vars), paste0("x", 1:max_vars))) + for (j_idx in seq_along(final_labels_to_plot)) { + term_lab_vector <- final_labels_to_plot[[j_idx]] # e.g. c(1), c(1,2) - for (i in seq_along(second_order_labels)) { - label <- second_order_labels[[i]] - value <- second_order_values[i, 1] + term_str <- format_term_label_display(term_lab_vector, variable_names, use_product_format_for_named = TRUE) + term_strings_for_plot[j_idx] <- term_str - if (length(label) == 1) { - # Single variable squared - heatmap_matrix[label, label] <- value - } else if (length(label) == 2) { - # Interaction terms - heatmap_matrix[label[1], label[2]] <- value - heatmap_matrix[label[2], label[1]] <- value + # Calculate the product of feature values for coloring + # P(X) part of M = c * P(X) + feature_product_for_coloring <- rep(1, num_obs_mono) # Start with 1 for product + valid_term <- TRUE + if (length(term_lab_vector) > 0) { # Should always be true if not intercept + for (var_idx_in_term in term_lab_vector) { + if (var_idx_in_term > 0 && var_idx_in_term <= num_features_orig) { + feature_product_for_coloring <- feature_product_for_coloring * original_feature_data[, var_idx_in_term] + } else if (var_idx_in_term == 0) { # Should not happen, intercept excluded + # This term involves intercept for coloring? + # For simplicity, if a 0 is in term_lab_vector (should not be), product is 0 or 1 depending on interpretation. + # Here, we are assuming term_lab_vector only contains positive variable indices. + } else { + # Variable index in term is out of bounds for original_feature_data + warning(paste0("Variable index ", var_idx_in_term, " in term '", term_str, + "' is out of bounds for 'original_feature_data' (max col: ", num_features_orig, + "). 
Coloring for this term might be inaccurate."), call. = FALSE) + feature_product_for_coloring <- rep(NA, num_obs_mono) # Make coloring NA + valid_term <- FALSE + break + } + } + } else { # Empty term label (should not happen) + feature_product_for_coloring <- rep(NA, num_obs_mono) + valid_term <- FALSE } + + + plot_df_list[[j_idx]] <- data.frame( + term_label_str = term_str, + monomial_value = final_monomial_values_to_plot[, j_idx], + coloring_value = if(valid_term) feature_product_for_coloring else NA_real_, # Use the calculated product + stringsAsFactors = FALSE + ) } - # Convert matrix to long format for ggplot2 - heatmap_df <- as.data.frame(as.table(heatmap_matrix)) - colnames(heatmap_df) <- c("Var1", "Var2", "Value") + plot_df_long <- do.call(rbind, plot_df_list) - # Generate the heatmap plot - heatmap_plot <- ggplot2::ggplot(heatmap_df, ggplot2::aes(x = Var1, y = Var2, fill = Value)) + - ggplot2::geom_tile(color = "white") + - ggplot2::scale_fill_gradient2(low = "#F8766D", mid = "white", high = "#00BA38", midpoint = 0, - name = "Coefficient") + + if (nrow(plot_df_long) == 0) { + return(ggplot2::ggplot() + + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "No data to plot after filtering terms.") + + ggplot2::theme_minimal() + ggplot2::ggtitle("Term Contribution Beeswarm Plot")) + } + + plot_df_long$term_label_str <- factor(plot_df_long$term_label_str, + levels = term_strings_for_plot[y_axis_order_indices]) + + # Legend title for coloring + color_legend_title <- "Feature Product Value" + + + # --- Generate Plot --- + beeswarm_plot <- ggplot2::ggplot(plot_df_long, + ggplot2::aes(x = .data$monomial_value, + y = .data$term_label_str, + color = .data$coloring_value)) + # Changed to coloring_value + ggplot2::geom_vline(xintercept = 0, linetype = "dashed", color = "grey50") + + ggbeeswarm::geom_quasirandom(alpha = 0.7, size = 1.5, shape = 16, groupOnX = FALSE, na.rm = TRUE) + + ggplot2::scale_color_gradient(low = "blue", high = "red", name = color_legend_title, 
na.value = "grey70") + + ggplot2::labs(title = "Term Contribution Beeswarm Plot", + x = "Monomial Term Value (Contribution)", + y = "Polynomial Term") + + ggplot2::theme_minimal(base_size = 11) + + ggplot2::theme(legend.position = "right", + plot.title = ggplot2::element_text(hjust = 0.5)) + + return(beeswarm_plot) +} + + + +# Helper to identify and sum relevant terms for the surface plot +# eval_selected_terms_on_surface <- function(poly_obj, # Should be single output (values is a vector or 1-col matrix) +# newdata_row, # Single row of data for all p features +# feature_pair_indices) { # Numeric indices of the two features +# value_sum <- 0 +# +# for (k_term in seq_along(poly_obj$labels)) { +# term_label <- poly_obj$labels[[k_term]] +# term_coeff <- poly_obj$values[k_term] # Already single output +# +# # Identify variables in the current term label (excluding 0 for intercept) +# vars_in_current_label <- term_label[term_label > 0] +# +# is_relevant <- FALSE +# if (length(term_label) == 1 && term_label[1] == 0) { # Intercept term +# is_relevant <- TRUE +# } else if (length(vars_in_current_label) > 0 && # Has some variables +# all(vars_in_current_label %in% feature_pair_indices)) { # All variables in term are from the chosen pair +# is_relevant <- TRUE +# } +# +# if (is_relevant) { +# # Evaluate this single term +# # Create a temporary mini-polynomial for this one term +# mini_poly <- list(labels = list(term_label), values = matrix(term_coeff, ncol = 1)) +# value_sum <- value_sum + eval_poly(poly = mini_poly, newdata = newdata_row) +# } +# } +# return(value_sum) +# } + + +eval_selected_terms_on_surface <- function(poly_obj, + newdata_row, + feature_pair_indices) { + value_sum <- 0 + + # Crucial: What is the structure of poly_obj here? + # poly_obj is poly_obj_single_output + # poly_obj$labels is a list of all original labels + # poly_obj$values is a matrix with N_terms rows and 1 column. 
+ + for (k_term in seq_along(poly_obj$labels)) { + term_label <- poly_obj$labels[[k_term]] + term_coeff <- poly_obj$values[k_term, 1] # Ensure it's a scalar from the single column + + vars_in_current_label <- term_label[term_label > 0] + is_relevant <- FALSE + if (length(term_label) == 1 && term_label[1] == 0) { + is_relevant <- TRUE + } else if (length(vars_in_current_label) > 0 && + all(vars_in_current_label %in% feature_pair_indices)) { + is_relevant <- TRUE + } + + if (is_relevant) { + mini_poly <- list(labels = list(term_label), values = matrix(term_coeff, ncol = 1)) + + # This parts gets errors many times, should review it. + current_term_value <- tryCatch({ + eval_poly(poly = mini_poly, newdata = newdata_row) + }, error = function(e) { + return(NA_real_) # Allow loop to continue + }) + + if (!is.na(current_term_value)) { + value_sum <- value_sum + current_term_value + } + } + } + return(value_sum) +} + + +#' Internal function for plot.nn2poly interaction_surface type +#' @noRd +plot_interaction_surface_internal <- function(poly_obj, + feature_pair, # Vector of two indices or names + original_feature_data, # For ranges and means + grid_resolution = 20, + variable_names = NULL, + poly_output_index = 1) { + + if (!is.matrix(original_feature_data)) original_feature_data <- as.matrix(original_feature_data) + num_total_features <- ncol(original_feature_data) + + feat_idx1 <- NULL; feat_idx2 <- NULL + feat_name1 <- ""; feat_name2 <- "" + + if (is.character(feature_pair) && length(feature_pair) == 2) { + default_names_for_lookup <- paste0("x", 1:num_total_features) # if variable_names is NULL + current_var_names <- if (!is.null(variable_names)) variable_names else default_names_for_lookup + if (length(current_var_names) < num_total_features && is.null(variable_names)) { + # This case should not happen if default_names_for_lookup is used correctly. + # It might happen if user provides partial variable_names. 
+ # We should ensure current_var_names has length num_total_features for matching. + # For simplicity, assume variable_names, if provided, is complete. + } + + feat_idx1 <- match(feature_pair[1], current_var_names) + feat_idx2 <- match(feature_pair[2], current_var_names) + if (is.na(feat_idx1) || is.na(feat_idx2)) stop("One or both features in 'feature_pair' not found.", call. = FALSE) + feat_name1 <- feature_pair[1] + feat_name2 <- feature_pair[2] + } else if (is.numeric(feature_pair) && length(feature_pair) == 2) { + feat_idx1 <- as.integer(feature_pair[1]) + feat_idx2 <- as.integer(feature_pair[2]) + if (feat_idx1 < 1 || feat_idx1 > num_total_features || feat_idx2 < 1 || feat_idx2 > num_total_features) { + stop("Numeric 'feature_pair' indices are out of bounds.", call. = FALSE) + } + feat_name1 <- if (!is.null(variable_names) && feat_idx1 <= length(variable_names)) variable_names[feat_idx1] else as.character(feat_idx1) + feat_name2 <- if (!is.null(variable_names) && feat_idx2 <= length(variable_names)) variable_names[feat_idx2] else as.character(feat_idx2) + } else { + stop("'feature_pair' must be a character or numeric vector of length 2.", call. = FALSE) + } + + if (feat_idx1 == feat_idx2) stop("Features in 'feature_pair' must be different.", call. 
= FALSE) + feature_pair_indices <- c(feat_idx1, feat_idx2) + + range1 <- range(original_feature_data[, feat_idx1], na.rm = TRUE) + range2 <- range(original_feature_data[, feat_idx2], na.rm = TRUE) + grid1 <- seq(range1[1], range1[2], length.out = grid_resolution) + grid2 <- seq(range2[1], range2[2], length.out = grid_resolution) + + surface_data_df <- expand.grid(Feat1_Val_Plot = grid1, Feat2_Val_Plot = grid2) + surface_data_df$CombinedEffect <- NA_real_ + + base_newdata_row <- matrix(NA, nrow = 1, ncol = num_total_features) + other_indices <- setdiff(1:num_total_features, feature_pair_indices) + if (length(other_indices) > 0) { + for (k_idx in other_indices) { + base_newdata_row[1, k_idx] <- mean(original_feature_data[, k_idx], na.rm = TRUE) + } + } + + # Prepare a version of poly_obj with only the selected output's coefficients + poly_obj_single_output <- poly_obj + poly_obj_single_output$values <- matrix(poly_obj$values[, poly_output_index], ncol = 1) + + + for (k_row in 1:nrow(surface_data_df)) { + current_eval_row <- base_newdata_row + current_eval_row[1, feat_idx1] <- surface_data_df$Feat1_Val_Plot[k_row] + current_eval_row[1, feat_idx2] <- surface_data_df$Feat2_Val_Plot[k_row] + + # Add colnames if original data had them, for eval_poly robustness + if(!is.null(colnames(original_feature_data))) { + colnames(current_eval_row) <- colnames(original_feature_data) + } else if (!is.null(variable_names) && length(variable_names) == num_total_features) { + colnames(current_eval_row) <- variable_names # Use provided names if no colnames on data + } + + + surface_data_df$CombinedEffect[k_row] <- eval_selected_terms_on_surface( + poly_obj = poly_obj_single_output, + newdata_row = current_eval_row, + feature_pair_indices = feature_pair_indices + ) + } + + p <- ggplot2::ggplot(surface_data_df, ggplot2::aes(x = .data$Feat1_Val_Plot, y = .data$Feat2_Val_Plot, fill = .data$CombinedEffect)) + + ggplot2::geom_tile() + + ggplot2::scale_fill_viridis_c(name = "Summed Effect") + + 
ggplot2::labs(title = paste("Interaction Surface:", feat_name1, "&", feat_name2), + x = feat_name1, y = feat_name2) + ggplot2::theme_minimal() + - ggplot2::labs(x = "Variables", y = "Variables", title = "Second-Order Terms Heatmap") + - ggplot2::theme( - axis.text.x = ggplot2::element_text(angle = 45, hjust = 1), - axis.text.y = ggplot2::element_text(size = 10) + ggplot2::coord_equal() # Often good for surfaces + + return(p) +} + + +#' Internal function for plot.nn2poly interaction_network type +#' @noRd +plot_interaction_network_internal <- function(poly_obj, + interaction_order = 2, # Target order to display + metric = "coefficient_abs", # "coefficient_abs", "mean_monomial_abs" + newdata_monomials = NULL, + top_n_interactions = NULL, + variable_names = NULL, + poly_output_index = 1, + layout = "nicely") { + + if (!requireNamespace("igraph", quietly = TRUE) || !requireNamespace("ggraph", quietly = TRUE)) { + stop("Packages 'igraph' and 'ggraph' are required for 'interaction_network' plot type.", call. = FALSE) + } + if (metric == "mean_monomial_abs" && is.null(newdata_monomials)) { + stop("If metric is 'mean_monomial_abs', 'newdata_monomials' must be provided.", call. 
= FALSE) + } + + edges_collector <- list() + all_nodes_involved <- integer(0) + + # Use coefficients from the selected output polynomial + current_poly_coeffs <- poly_obj$values[, poly_output_index, drop = TRUE] + + # Prepare monomial slice if needed + monomial_slice_for_metric <- NULL + if (metric == "mean_monomial_abs") { + monomial_slice_for_metric <- if (length(dim(newdata_monomials)) == 3) { + newdata_monomials[,,poly_output_index, drop = FALSE] + } else { + newdata_monomials + } + } + + + for (i_term in seq_along(poly_obj$labels)) { + term_label <- poly_obj$labels[[i_term]] + vars_in_label <- unique(term_label[term_label > 0]) # Unique positive variables in this term + + # Effective order of the interaction (count of unique variables) + # This is different from length(term_label) if there are squared terms like c(1,1) + effective_interaction_order <- length(vars_in_label) + + if (effective_interaction_order == interaction_order && interaction_order >= 2) { + term_coeff_val <- current_poly_coeffs[i_term] + term_metric_val <- NA + + if (metric == "coefficient_abs") { + term_metric_val <- abs(term_coeff_val) + } else if (metric == "mean_monomial_abs") { + term_metric_val <- mean(abs(monomial_slice_for_metric[, i_term]), na.rm = TRUE) + } + + if (!is.na(term_metric_val) && term_metric_val > 1e-9) { # If metric is non-zero + # For any N-th order interaction, create pairwise edges between all involved unique vars + if (length(vars_in_label) >= 2) { + pairs <- utils::combn(vars_in_label, 2, simplify = FALSE) + # Weight attribution: for an N-way interaction, attribute its strength to all constituent pairs. + # Simple approach: each pair gets a portion of the strength. + # More direct: each pair simply *is* part of this N-way interaction. + # Let's use the term_metric_val for each projected edge for now. + # The `sign` will be the sign of the original N-way term's coefficient. 
+ + for (p in pairs) { + edges_collector[[length(edges_collector) + 1]] <- data.frame( + from = min(p), # Ensure canonical order + to = max(p), + weight = term_metric_val, # Could divide by choose(length(vars_in_label), 2) + sign = sign(term_coeff_val), + original_term_label_str = format_term_label_display(term_label, variable_names, use_product_format_for_named = TRUE) # For tooltip/info + ) + all_nodes_involved <- c(all_nodes_involved, p) + } + } + } + } + } + + if (length(edges_collector) == 0) { + return(ggplot2::ggplot() + ggplot2::annotate("text", x=0.5,y=0.5,label=paste0("No ", interaction_order, "-order interactions found or all have zero weight.")) + ggplot2::theme_minimal()) + } + + edges_df <- do.call(rbind, edges_collector) + + # Aggregate edges: if the same pair (from, to) results from multiple N-way terms, sum their weights. + # This means a pair frequently involved in strong N-way interactions gets a higher cumulative weight. + # Keep the sign of the component with the largest absolute weight for that pair. 
+ + # To correctly get sign after aggregation: + edges_df_agg <- do.call(rbind, by(edges_df, list(edges_df$from, edges_df$to), function(sub_df) { + max_abs_weight_idx <- which.max(abs(sub_df$weight * sub_df$sign)) # Use signed weight to find max impact + data.frame( + from = sub_df$from[1], + to = sub_df$to[1], + weight = sum(sub_df$weight, na.rm = TRUE), # Sum of absolute contributions (strength) + sign = sub_df$sign[max_abs_weight_idx[1]], # Sign of the strongest contributor + # Could also list original_term_label_str contributing + stringsAsFactors = FALSE ) + })) + rownames(edges_df_agg) <- NULL + + + if (!is.null(top_n_interactions) && top_n_interactions > 0 && nrow(edges_df_agg) > top_n_interactions) { + edges_df_agg <- edges_df_agg[order(edges_df_agg$weight, decreasing = TRUE), ][1:top_n_interactions, ] + } + + if (nrow(edges_df_agg) == 0) { + return(ggplot2::ggplot() + ggplot2::annotate("text", x=0.5,y=0.5,label="No interactions left after filtering.") + ggplot2::theme_minimal()) + } + + all_nodes_involved_final <- unique(c(edges_df_agg$from, edges_df_agg$to)) + if (length(all_nodes_involved_final) == 0) { + return(ggplot2::ggplot() + ggplot2::annotate("text", x=0.5,y=0.5,label="No nodes to plot.") + ggplot2::theme_minimal()) + } - return(heatmap_plot) + node_df <- data.frame(id = sort(all_nodes_involved_final)) + node_df$name <- sapply(node_df$id, function(id_val) { + if (!is.null(variable_names) && id_val > 0 && id_val <= length(variable_names)) { + variable_names[id_val] + } else { + as.character(id_val) # Default to numeric if no name + } + }) + + # igraph requires vertex IDs to be 1..N if supplying a vertex data frame. + # Map our potentially sparse numeric IDs (node_df$id) to a dense 1..N range. 
+ id_to_igraph_map <- stats::setNames(seq_along(node_df$id), node_df$id) + edges_df_agg$from_igraph <- id_to_igraph_map[as.character(edges_df_agg$from)] + edges_df_agg$to_igraph <- id_to_igraph_map[as.character(edges_df_agg$to)] + + # Create a clean vertex data frame for igraph, ensuring no 'name' column if + # from/to are numeric indices, or ensure from/to match the 'name' column if they are characters. + # Since from_igraph/to_igraph are 1..N dense indices, vertex_df should just have attributes. + # igraph will map row 1 of vertex_df to id 1, row 2 to id 2, etc. + + vertex_df_for_igraph <- data.frame(name = node_df$name, stringsAsFactors = FALSE) + # Any other node attributes can be added here. + + # Create the graph + # Start with an empty graph + graph <- igraph::make_empty_graph(n = nrow(node_df), directed = FALSE) + + # Add vertex attributes + # igraph::V(graph) gives vertex sequence; names are 1, 2, ..., nrow(node_df) by default + igraph::V(graph)$name_attr <- node_df$name # Use a different name for the attribute internally + # Add original sparse IDs if needed for other purposes, though not directly for plotting name + igraph::V(graph)$original_id <- node_df$id + + # Add edges + # Convert from/to_igraph to matrix for add_edges + edge_list_matrix <- as.matrix(edges_df_agg[, c("from_igraph", "to_igraph")]) + + if (nrow(edge_list_matrix) > 0) { + graph <- igraph::add_edges(graph, t(edge_list_matrix)) # add_edges expects 2-row matrix + + # Add edge attributes + igraph::E(graph)$weight <- edges_df_agg$weight + igraph::E(graph)$sign <- as.factor(edges_df_agg$sign) # Ensure it's a factor + # igraph::E(graph)$original_term_label_str <- edges_df_agg$original_term_label_str # If needed + } + + + # Check for isolated nodes (nodes in node_df but not in any edge after filtering) + # They are already in the graph due to make_empty_graph(n = nrow(node_df)) + + # ggraph plotting + # geom_node_text will use V(graph)$name_attr IF we map it to 'name' in aes or if it's the 
default ggraph looks for. + # To be safe, explicitly map it in ggraph's aes. + + gg_plot <- ggraph::ggraph(graph, layout = layout) + + ggraph::geom_edge_fan(ggplot2::aes(edge_width = .data$weight, edge_color = .data$sign), + alpha = 0.6, arrow = NULL, end_cap = ggraph::circle(3, 'mm')) + + ggraph::scale_edge_width_continuous(range = c(0.5, 4), name = "Strength") + + ggraph::scale_edge_color_manual(values = c("-1" = "firebrick", "1" = "steelblue", "0" = "grey50"), + name = "Sign of Coeff.", drop = FALSE) + + ggraph::geom_node_point(size = 7, color = "skyblue", alpha = 0.8) + + # Explicitly tell ggraph to use 'name_attr' for the label aesthetic + ggraph::geom_node_text(ggplot2::aes(label = .data$name_attr), repel = TRUE, size = 3.5) + + ggraph::theme_graph(base_family = 'sans', plot_margin = ggplot2::margin(1,1,1,1)) + + ggplot2::labs(title = paste(interaction_order, "-Order Interaction Network", sep="")) + + return(gg_plot) } + + + + diff --git a/man/plot.nn2poly.Rd b/man/plot.nn2poly.Rd index 134ab5e..73d122c 100644 --- a/man/plot.nn2poly.Rd +++ b/man/plot.nn2poly.Rd @@ -4,27 +4,38 @@ \alias{plot.nn2poly} \title{Plot method for \code{nn2poly} objects.} \usage{ -\method{plot}{nn2poly}(x, ..., n = NULL) +\method{plot}{nn2poly}(x, type = "bar", ..., n = NULL) } \arguments{ \item{x}{A \code{nn2poly} object, as returned by the \pkg{nn2poly} algorithm.} +\item{type}{A string containing the type of plot chosen. Currently available +plots are "bar" (default) for coefficient bar plots and "heatmap" for a +heatmap of second-order terms.} + \item{...}{Ignored.} -\item{n}{An integer denoting the number of coefficients to be plotted, -after ordering them by absolute magnitude.} +\item{n}{An integer denoting the number of coefficients to be plotted in the +"bar" case, after ordering them by absolute magnitude.} } \value{ -A plot showing the \code{n} most important coefficients. +A ggplot object, depending on the \code{type} chosen. 
} \description{ A function that takes a polynomial (or several ones) as given by the -\pkg{nn2poly} algorithm, and then plots their absolute magnitude as barplots -to be able to compare the most important coefficients. +\pkg{nn2poly} algorithm, and then provides a plot depending on the chosen +type to be able to compare the most important coefficients. } \details{ The plot method represents only the polynomials at the final layer, even if -\code{x} is generated using \code{nn2poly()} with \code{keep_layers=TRUE}. +\code{x} is generated using \code{nn2poly()} with \code{keep_layers=TRUE}. The plot types +are as follows: +\itemize{ +\item The bar plot represents coefficients ordered by absolute value. +\item The heatmap displays squared terms on the diagonal and pairwise +interactions off-diagonal, with colors indicating the magnitude and sign +of coefficients. +} } \examples{ # --- Single polynomial output --- @@ -52,6 +63,12 @@ plot(final_poly) # one plot per output unit plot(final_poly, n = 5) +# Plot a heatmap for second order terms. 
+plot(final_poly, type = "heatmap") + + + + # --- Multiple output polynomials --- # Build a NN structure with random weights, with 2 (+ bias) inputs, # 4 (+bias) neurons in the first hidden layer with "tanh" activation @@ -73,7 +90,7 @@ final_poly <- nn2poly(nn_object, max_order = 3) # Plot all the coefficients, one plot per output unit plot(final_poly) -# Plot only the 5 most important coeffcients (by absolute magnitude) +# Plot only the 5 most important coefficients (by absolute magnitude) # one plot per output unit plot(final_poly, n = 5) diff --git a/vignettes/source/_nn2poly-01-introduction.Rmd b/vignettes/source/_nn2poly-01-introduction.Rmd index 2ef9121..c6719d1 100644 --- a/vignettes/source/_nn2poly-01-introduction.Rmd +++ b/vignettes/source/_nn2poly-01-introduction.Rmd @@ -295,12 +295,131 @@ We can also plot the $n$ most important coefficients in absolute value to compar In this case we can see how the two most important obtained coefficients are `2,3` and `1`, precisely the two terms appearing in the original polynomial $4x_1 - 3 x_2x_3$. However, other interactions of order 3 appear to be also relevant, which is caused by the Taylor expansions not being controlled as we have not imposed constraints on the neural network weights training. ```{r reg-n-important} -plot(final_poly, n=8) +plot(final_poly, n=8, type = "bar") ``` +```{r} +plot(final_poly,type = "heatmap") +``` + + +```{r} +# Assuming 'final_poly' and 'test_x' are from your vignette +# library(nn2poly) # and other necessary libraries like ggplot2 + +# 1. Get monomial predictions for all test data +prediction_monomials <- predict(object = final_poly, + newdata = test_x, + monomials = TRUE) +dim(prediction_monomials) # Should be [125, 20, 1] in the vignette example + +# 2. 
Generate the local contributions plot for the first observation +plot(final_poly, + type = "local_contributions", + newdata_monomials = prediction_monomials, + observation_index = 1, + poly_output_index = 1, # Only one output in this example + max_order_to_display = 3) # Show up to 3rd order contributions + +``` + + +```{r} +# Assuming 'final_poly', 'test_x' are from your vignette +prediction_monomials <- predict(object = final_poly, + newdata = test_x, + monomials = TRUE) +dim(prediction_monomials) # Should be [125, 20, 1] in the vignette example + +# --- Beeswarm Plot --- + +# Color by the third original feature +plot(final_poly, + type = "beeswarm", + newdata_monomials = prediction_monomials, + original_feature_data = test_x, + color_by_feature = 3, # Color by the third column of test_x + poly_output_index = 1, + top_n_terms = 10) + +``` + + + +```{r} + +# Assuming 'final_poly' and 'test_x' are available from your vignette +# And the plot_interaction_surface_internal function is defined and accessible. 
+ +# Example 1.1: Plotting the interaction surface for features 2 and 3 using numeric indices +plot(final_poly, + type = "interaction_surface", + feature_pair = c(2, 3), # Specify the pair of features by their indices + original_feature_data = test_x, # Provide the original data for ranges and means + grid_resolution = 30, # Increase resolution for a smoother surface (optional) + variable_names = c("X1", "X2", "X3")) # Provide variable names for axis labels (optional) + +# Example 1.2: Plotting the interaction surface for features 1 and 2 using feature names (if test_x had column names or variable_names are provided and match) +# Let's assume variable_names were used to define the column names conceptually +# If test_x has columns named "X1", "X2", "X3" +plot(final_poly, + type = "interaction_surface", + feature_pair = c("X1", "X2"), # Specify the pair of features by their names + original_feature_data = test_x, + grid_resolution = 30) + +# If test_x does NOT have column names, you need to provide variable_names +plot(final_poly, + type = "interaction_surface", + feature_pair = c("X1", "X2"), # Specify the pair of features by their names + original_feature_data = test_x, + grid_resolution = 30, + variable_names = c("X1", "X2", "X3")) # These names are used for matching feature_pair + +``` + + +```{r} +# Assuming 'final_poly', 'test_x', and 'prediction_monomials' are available +# And the plot_interaction_network_internal function is defined and accessible. 
+ +# Example 2.1: Plotting the 2nd-order interaction network using coefficient magnitudes +plot(final_poly, + type = "interaction_network", + interaction_order_network = 2, # Focus on 2nd-order interactions (pairwise projections) + metric_network = "coefficient_abs", # Use the absolute value of coefficients for strength + variable_names = c("X1", "X2", "X3")) # Provide variable names for nodes + +# Example 2.2: Plotting the 2nd-order interaction network using mean absolute monomial values +plot(final_poly, + type = "interaction_network", + interaction_order_network = 2, + metric_network = "mean_monomial_abs", # Use average monomial value across data for strength + newdata_monomials = prediction_monomials, # Required for mean_monomial_abs metric + variable_names = c("X1", "X2", "X3")) + +# Example 2.3: Plotting the 3rd-order interaction network (if max_order >= 3 was used in nn2poly) +# This will project 3rd order terms like x1*x2*x3 onto pairs (x1,x2), (x1,x3), (x2,x3) +plot(final_poly, + type = "interaction_network", + interaction_order_network = 3, # Look for 3rd-order terms + metric_network = "coefficient_abs", + variable_names = c("X1", "X2", "X3")) + +# Example 2.4: Plotting the top N 2nd-order interactions +plot(final_poly, + type = "interaction_network", + interaction_order_network = 2, + metric_network = "coefficient_abs", + top_n_interactions = 5, # Display only the 5 strongest projected pairwise interactions + variable_names = c("X1", "X2", "X3")) +``` + + Another convenient plot to show how the algorithm is affected by each layer can be obtained with `nn2poly:::plot_taylor_and_activation_potentials()`, where the activation potentials at each neuron are computed and presented over the Taylor expansion approximation of the activation function at each layer. + In this case, as we have not used constraints in the NN training, the activation potentials are not strictly centered around zero. 
From 330ad2f50608b727abd59f4c5e9a546b69d964cb Mon Sep 17 00:00:00 2001 From: moralapablo Date: Sat, 10 May 2025 23:38:25 +0200 Subject: [PATCH 3/4] Fixed error with single intercept polynomials --- R/eval_poly.R | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/R/eval_poly.R b/R/eval_poly.R index 0efe0ce..015fa0c 100644 --- a/R/eval_poly.R +++ b/R/eval_poly.R @@ -51,6 +51,23 @@ eval_poly <- function(poly, newdata, monomials = FALSE) { newdata <- preprocess_newdata(newdata) + # Check if there are labels with bigger numbers than the number of columns + if (length(poly$labels) > 0) { # Only check if there are labels + # Find the maximum variable index referenced in the polynomial labels + # Exclude 0 (intercept) from consideration as a variable index + all_vars_in_labels <- unlist(lapply(poly$labels, function(lab) lab[lab > 0])) + if (length(all_vars_in_labels) > 0) { + max_var_index_poly <- max(all_vars_in_labels, na.rm = TRUE) + + # Check against number of columns in newdata + if (max_var_index_poly > 0 && ncol(newdata) < max_var_index_poly) { + stop(paste0("Polynomial requires at least ", max_var_index_poly, + " variable(s), but newdata only has ", ncol(newdata), " column(s)."), + call. = FALSE) + } + } + } + aux <- preprocess_poly(poly) poly <- aux$poly original_intercept_pos_for_reordering <- aux$intercept_position @@ -88,18 +105,18 @@ eval_poly <- function(poly, newdata, monomials = FALSE) { } # Loop over all terms (labels) except the intercept - for (j in start_loop:length(values_k)) { - - label_j <- poly$labels[[j]] - - var_prod <- multiply_variables(label_j, newdata) - + if (start_loop <= length(poly$labels)) { + for (j in start_loop:length(poly$labels)) { - # Here instead of adding response over the loop as in the normal - # eval_poly, store it in the appropriate position. 
- response[,j,k] = values_k[j] * var_prod + label_j <- poly$labels[[j]] + coefficient_val <- values_k[j] + var_prod <- multiply_variables(label_j, newdata) + # Here instead of adding response over the loop as in the normal + # eval_poly, store it in the appropriate position. + response[,j,k] = coefficient_val * var_prod + } } # In case the intercept has been moved, we reorder it to its original From 3425147a2e6821abf82b0055824b9d68f423747f Mon Sep 17 00:00:00 2001 From: moralapablo Date: Sat, 10 May 2025 23:45:49 +0200 Subject: [PATCH 4/4] Unified styl in plots and some minor fixes --- R/nn2poly_methods.R | 98 +++++++------------ vignettes/source/_nn2poly-01-introduction.Rmd | 16 +-- 2 files changed, 41 insertions(+), 73 deletions(-) diff --git a/R/nn2poly_methods.R b/R/nn2poly_methods.R index fa18387..a8f8ef3 100644 --- a/R/nn2poly_methods.R +++ b/R/nn2poly_methods.R @@ -272,7 +272,6 @@ plot.nn2poly <- function(x, type = "bar", ..., } if (type == "bar") { - if (!is.null(list(...)$n) && is.null(n)) n <- list(...)$n plot_object <- plot_bar(poly_for_plot, n = n, variable_names = variable_names, min_order = min_order) } else if (type == "heatmap") { plot_object <- plot_heatmap(poly_for_plot, variable_names = variable_names) @@ -472,7 +471,7 @@ plot_bar <- function(poly_obj, n = NULL, variable_names = NULL, min_order = 0) { if (length(labels_to_keep) == 0) { return(ggplot2::ggplot() + ggplot2::annotate("text", x = 0.5, y = 0.5, label = paste0("No terms found with order >= ", min_order, ".")) + - ggplot2::theme_minimal() + + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle(paste0("Coefficients (Order >= ", min_order, ")"))) } @@ -493,7 +492,7 @@ plot_bar <- function(poly_obj, n = NULL, variable_names = NULL, min_order = 0) { if (n == 0) { # Case where n becomes 0 after filtering if nrow(poly_obj_filtered$values) was 0 return(ggplot2::ggplot() + ggplot2::annotate("text", x = 0.5, y = 0.5, label = paste0("No terms to plot after filtering (n=0).")) + - 
ggplot2::theme_minimal() + + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle(paste0("Coefficients (Order >= ", min_order, ")"))) } @@ -550,7 +549,7 @@ plot_bar <- function(poly_obj, n = NULL, variable_names = NULL, min_order = 0) { if (nrow(all_plot_df) == 0) { return(ggplot2::ggplot() + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "No coefficients to plot after processing.") + - ggplot2::theme_minimal() + + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle(paste0("Coefficients (Order >= ", min_order, ")"))) } @@ -579,15 +578,13 @@ plot_bar <- function(poly_obj, n = NULL, variable_names = NULL, min_order = 0) { y = "Polynomial Term") - p <- p + ggplot2::theme_minimal() + + p <- p + ggplot2::theme_minimal(base_size = 10) + ggplot2::theme( axis.text.x = ggplot2::element_text(size=10), axis.text.y = ggplot2::element_text(size=10), legend.position = "top", legend.direction = "horizontal", - plot.title = ggplot2::element_text(hjust = 0.5, face="bold"), - panel.grid.major.y = ggplot2::element_blank(), - panel.grid.minor.y = ggplot2::element_blank(), + plot.title = ggplot2::element_text(hjust = 0.5, size = 14), strip.background = ggplot2::element_rect(fill="grey90", linetype="blank"), strip.text = ggplot2::element_text(face="bold") ) @@ -622,7 +619,7 @@ plot_heatmap <- function(poly_obj, variable_names = NULL) { # Changed x to poly_ if (!any(is_second_order)) { return(ggplot2::ggplot() + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "No second-order terms with positive variable indices found.") + - ggplot2::theme_minimal() + + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle("Second-Order Coefficients Heatmap")) } @@ -634,7 +631,7 @@ plot_heatmap <- function(poly_obj, variable_names = NULL) { # Changed x to poly_ if (length(present_vars) == 0) { return(ggplot2::ggplot() + ggplot2::annotate("text", x=0.5, y=0.5, label="No variables found in second-order terms.") + - ggplot2::theme_minimal() + + ggplot2::theme_minimal(base_size = 10) 
+ ggplot2::ggtitle("Second-Order Coefficients Heatmap")) } max_var_idx <- max(present_vars) @@ -642,7 +639,7 @@ plot_heatmap <- function(poly_obj, variable_names = NULL) { # Changed x to poly_ if(max_var_idx == 0){ # Should not happen if all(lab > 0) filter is effective return(ggplot2::ggplot() + ggplot2::annotate("text", x=0.5, y=0.5, label="Max variable index is 0.") + - ggplot2::theme_minimal() + + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle("Second-Order Coefficients Heatmap")) } @@ -705,10 +702,9 @@ plot_heatmap <- function(poly_obj, variable_names = NULL) { # Changed x to poly_ ggplot2::theme_minimal(base_size = 10) + ggplot2::labs(x = "Variable", y = "Variable") + ggplot2::theme( - axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, vjust = 1), - axis.text.y = ggplot2::element_text(), # Default angle - panel.grid = ggplot2::element_blank(), - legend.position = "bottom", + axis.text.x = ggplot2::element_text(), + axis.text.y = ggplot2::element_text(), + legend.position = "top", legend.direction = "horizontal", plot.title = ggplot2::element_text(hjust = 0.5, size = 14), strip.background = ggplot2::element_rect(fill="grey90", linetype="blank"), strip.text = ggplot2::element_text(face="bold") @@ -810,7 +806,7 @@ plot_local_contributions_internal <- function(poly_obj, if (length(contributions_list) == 0) { return(ggplot2::ggplot() + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "No feature contributions to display for this observation (or selected orders).") + - ggplot2::theme_minimal() + ggplot2::ggtitle("Local Feature Contributions")) + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle("Local Feature Contributions")) } plot_df <- do.call(rbind, contributions_list) @@ -825,7 +821,7 @@ plot_local_contributions_internal <- function(poly_obj, if (nrow(plot_df_agg) == 0) { return(ggplot2::ggplot() + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "All aggregated feature contributions are negligible.") + - 
ggplot2::theme_minimal() + ggplot2::ggtitle("Local Feature Contributions")) + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle("Local Feature Contributions")) } # --- Prepare for Plotting --- @@ -834,7 +830,7 @@ plot_local_contributions_internal <- function(poly_obj, if(is.infinite(max_present_order) || max_present_order < 1) { return(ggplot2::ggplot() + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "No contributions to display for the selected orders after aggregation.") + - ggplot2::theme_minimal() + ggplot2::ggtitle("Local Feature Contributions")) + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle("Local Feature Contributions")) } order_suffix <- function(k) { @@ -858,7 +854,7 @@ plot_local_contributions_internal <- function(poly_obj, if (p == 0 && nrow(plot_df_agg) > 0) { p <- max(plot_df_agg$variable_idx, na.rm = TRUE) } if (p == 0) { return(ggplot2::ggplot() + ggplot2::annotate("text", x=0.5, y=0.5, label="No variables found in contributions.") + - ggplot2::theme_minimal() + ggplot2::ggtitle("Local Feature Contributions")) + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle("Local Feature Contributions")) } current_axis_var_labels <- character(p) # Full list of potential labels @@ -886,14 +882,13 @@ plot_local_contributions_internal <- function(poly_obj, fill = .data$term_order_str)) + ggplot2::geom_col(position = "stack", width = 0.7, na.rm = TRUE) + # na.rm for safety ggplot2::geom_hline(yintercept = 0, linetype = "solid", color = "black") + - # ggplot2::scale_fill_manual(values = active_colors, name = "Term Order", drop = FALSE) + # drop=FALSE ensures legend consistency + ggplot2::scale_fill_brewer(palette = "Set1", name = "Term Order", drop = FALSE) + ggplot2::labs(title = "Local Feature Contributions", x = "Feature", y = "Contribution to Prediction") + - ggplot2::theme_minimal(base_size = 11) + - ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, vjust = 1), - legend.position = "top", - plot.title = 
ggplot2::element_text(hjust = 0.5)) + ggplot2::theme_minimal(base_size = 10) + + ggplot2::theme(legend.position = "top", legend.direction = "horizontal", + plot.title = ggplot2::element_text(hjust = 0.5, size = 14)) return(final_plot) } @@ -945,7 +940,7 @@ plot_beeswarm_internal <- function(poly_obj, if (length(term_indices_to_plot) == 0) { return(ggplot2::ggplot() + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "Only an intercept term found. No beeswarm plot to generate.") + - ggplot2::theme_minimal() + ggplot2::ggtitle("Term Contribution Beeswarm Plot")) + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle("Term Contribution Beeswarm Plot")) } } current_labels_raw_no_intercept <- term_labels_raw[term_indices_to_plot] @@ -953,7 +948,7 @@ plot_beeswarm_internal <- function(poly_obj, if (ncol(current_monomial_values_no_intercept) == 0) { return(ggplot2::ggplot() + ggplot2::annotate("text",x=0.5,y=0.5,label="No non-intercept terms to plot.") + - ggplot2::theme_minimal() + ggplot2::ggtitle("Term Contribution Beeswarm Plot")) + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle("Term Contribution Beeswarm Plot")) } # --- Term Importance and Selection --- @@ -1025,7 +1020,7 @@ plot_beeswarm_internal <- function(poly_obj, if (nrow(plot_df_long) == 0) { return(ggplot2::ggplot() + ggplot2::annotate("text", x = 0.5, y = 0.5, label = "No data to plot after filtering terms.") + - ggplot2::theme_minimal() + ggplot2::ggtitle("Term Contribution Beeswarm Plot")) + ggplot2::theme_minimal(base_size = 10) + ggplot2::ggtitle("Term Contribution Beeswarm Plot")) } plot_df_long$term_label_str <- factor(plot_df_long$term_label_str, @@ -1039,16 +1034,16 @@ plot_beeswarm_internal <- function(poly_obj, beeswarm_plot <- ggplot2::ggplot(plot_df_long, ggplot2::aes(x = .data$monomial_value, y = .data$term_label_str, - color = .data$coloring_value)) + # Changed to coloring_value + colour = .data$coloring_value)) + ggplot2::geom_vline(xintercept = 0, linetype = "dashed", color 
= "grey50") + - ggbeeswarm::geom_quasirandom(alpha = 0.7, size = 1.5, shape = 16, groupOnX = FALSE, na.rm = TRUE) + - ggplot2::scale_color_gradient(low = "blue", high = "red", name = color_legend_title, na.value = "grey70") + + ggbeeswarm::geom_quasirandom(alpha = 1, size = 1.5, shape = 16, groupOnX = FALSE, na.rm = TRUE) + + ggplot2::scale_colour_gradient2(low = "#F8766D", mid = "gray80", high = "#00BA38", midpoint = 0, name = color_legend_title, na.value = "grey70") + ggplot2::labs(title = "Term Contribution Beeswarm Plot", x = "Monomial Term Value (Contribution)", y = "Polynomial Term") + - ggplot2::theme_minimal(base_size = 11) + + ggplot2::theme_minimal(base_size = 10) + ggplot2::theme(legend.position = "right", - plot.title = ggplot2::element_text(hjust = 0.5)) + plot.title = ggplot2::element_text(hjust = 0.5, size = 14)) return(beeswarm_plot) } @@ -1056,37 +1051,6 @@ plot_beeswarm_internal <- function(poly_obj, # Helper to identify and sum relevant terms for the surface plot -# eval_selected_terms_on_surface <- function(poly_obj, # Should be single output (values is a vector or 1-col matrix) -# newdata_row, # Single row of data for all p features -# feature_pair_indices) { # Numeric indices of the two features -# value_sum <- 0 -# -# for (k_term in seq_along(poly_obj$labels)) { -# term_label <- poly_obj$labels[[k_term]] -# term_coeff <- poly_obj$values[k_term] # Already single output -# -# # Identify variables in the current term label (excluding 0 for intercept) -# vars_in_current_label <- term_label[term_label > 0] -# -# is_relevant <- FALSE -# if (length(term_label) == 1 && term_label[1] == 0) { # Intercept term -# is_relevant <- TRUE -# } else if (length(vars_in_current_label) > 0 && # Has some variables -# all(vars_in_current_label %in% feature_pair_indices)) { # All variables in term are from the chosen pair -# is_relevant <- TRUE -# } -# -# if (is_relevant) { -# # Evaluate this single term -# # Create a temporary mini-polynomial for this one term -# 
mini_poly <- list(labels = list(term_label), values = matrix(term_coeff, ncol = 1)) -# value_sum <- value_sum + eval_poly(poly = mini_poly, newdata = newdata_row) -# } -# } -# return(value_sum) -# } - - eval_selected_terms_on_surface <- function(poly_obj, newdata_row, feature_pair_indices) { @@ -1220,7 +1184,10 @@ plot_interaction_surface_internal <- function(poly_obj, ggplot2::scale_fill_viridis_c(name = "Summed Effect") + ggplot2::labs(title = paste("Interaction Surface:", feat_name1, "&", feat_name2), x = feat_name1, y = feat_name2) + - ggplot2::theme_minimal() + + ggplot2::theme_minimal(base_size = 10) + + ggplot2::theme(plot.title = ggplot2::element_text(hjust = 0.5, size = 14), + legend.position = "top", + legend.direction = "horizontal" ) + ggplot2::coord_equal() # Often good for surfaces return(p) @@ -1401,12 +1368,13 @@ plot_interaction_network_internal <- function(poly_obj, ggraph::geom_edge_fan(ggplot2::aes(edge_width = .data$weight, edge_color = .data$sign), alpha = 0.6, arrow = NULL, end_cap = ggraph::circle(3, 'mm')) + ggraph::scale_edge_width_continuous(range = c(0.5, 4), name = "Strength") + - ggraph::scale_edge_color_manual(values = c("-1" = "firebrick", "1" = "steelblue", "0" = "grey50"), + ggraph::scale_edge_color_manual(values = c("-1" = "#F8766D", "1" = "#00BA38", "0" = "grey70"), name = "Sign of Coeff.", drop = FALSE) + ggraph::geom_node_point(size = 7, color = "skyblue", alpha = 0.8) + # Explicitly tell ggraph to use 'name_attr' for the label aesthetic ggraph::geom_node_text(ggplot2::aes(label = .data$name_attr), repel = TRUE, size = 3.5) + ggraph::theme_graph(base_family = 'sans', plot_margin = ggplot2::margin(1,1,1,1)) + + ggplot2::theme(plot.title = ggplot2::element_text(hjust = 0.5, size = 14)) + ggplot2::labs(title = paste(interaction_order, "-Order Interaction Network", sep="")) return(gg_plot) diff --git a/vignettes/source/_nn2poly-01-introduction.Rmd b/vignettes/source/_nn2poly-01-introduction.Rmd index c6719d1..cb26021 100644 --- 
a/vignettes/source/_nn2poly-01-introduction.Rmd +++ b/vignettes/source/_nn2poly-01-introduction.Rmd @@ -361,10 +361,10 @@ plot(final_poly, # Example 1.2: Plotting the interaction surface for features 1 and 2 using feature names (if test_x had column names or variable_names are provided and match) # Let's assume variable_names were used to define the column names conceptually -# If test_x has columns named "X1", "X2", "X3" +# If test_x has columns named "V1", "V2", "V3" plot(final_poly, type = "interaction_surface", - feature_pair = c("X1", "X2"), # Specify the pair of features by their names + feature_pair = c("V1", "V2"), # Specify the pair of features by their names original_feature_data = test_x, grid_resolution = 30) @@ -391,12 +391,12 @@ plot(final_poly, variable_names = c("X1", "X2", "X3")) # Provide variable names for nodes # Example 2.2: Plotting the 2nd-order interaction network using mean absolute monomial values -plot(final_poly, - type = "interaction_network", - interaction_order_network = 2, - metric_network = "mean_monomial_abs", # Use average monomial value across data for strength - newdata_monomials = prediction_monomials, # Required for mean_monomial_abs metric - variable_names = c("X1", "X2", "X3")) +# plot(final_poly, +# type = "interaction_network", +# interaction_order_network = 2, +# metric_network = "mean_monomial_abs", # Use average monomial value across data for strength +# newdata_monomials = prediction_monomials, # Required for mean_monomial_abs metric +# variable_names = c("X1", "X2", "X3")) # Example 2.3: Plotting the 3rd-order interaction network (if max_order >= 3 was used in nn2poly) # This will project 3rd order terms like x1*x2*x3 onto pairs (x1,x2), (x1,x3), (x2,x3)