Merged
2 changes: 1 addition & 1 deletion NAMESPACE
@@ -4,7 +4,7 @@ export(compile_keras_grid)
export(create_keras_functional_spec)
export(create_keras_sequential_spec)
export(extract_keras_history)
export(extract_keras_summary)
export(extract_keras_model)
export(extract_valid_grid)
export(generic_functional_fit)
export(generic_sequential_fit)
156 changes: 114 additions & 42 deletions R/compile_keras_grid.R
@@ -1,47 +1,89 @@
#' Compile Keras Models over a Grid of Hyperparameters
#' Compile and Validate Keras Model Architectures
#'
#' @title Compile Keras Models Over a Grid of Hyperparameters
#' @description
#' This function allows you to build and compile multiple Keras models based on a
#' `parsnip` model specification and a grid of hyperparameters, without actually
#' fitting them. It's a valuable tool for validating model architectures and
#' catching potential errors early in the modeling process.
#' Pre-compiles Keras models for each hyperparameter combination in a grid.
#'
#' This function is a powerful debugging tool to use before running a full
#' `tune::tune_grid()`. It allows you to quickly validate multiple model
#' architectures, ensuring they can be successfully built and compiled without
#' the time-consuming process of actually fitting them. It helps catch common
#' errors like incompatible layer shapes or invalid argument values early.
#'
#' @details
#' The function operates by iterating through each row of the provided `grid`.
#' For each combination of hyperparameters, it:
#' \enumerate{
#' \item Constructs the appropriate Keras model (Sequential or Functional) based
#' on the `spec`.
#' \item Compiles the model using the specified optimizer, loss, and metrics.
#' \item Wraps the process in a `try-catch` block to gracefully handle any
#' errors that might occur during model instantiation or compilation (e.g.,
#' due to incompatible layer shapes or invalid argument values).
#' }
#' The output is a `tibble` where each row corresponds to a row in the input
#' `grid`. It includes the original hyperparameters, the compiled Keras model
#' object (or a string with the error message if compilation failed), and a
#' summary of the model's architecture.
#' The function iterates through each row of the provided `grid`. For each
#' hyperparameter combination, it attempts to build and compile the Keras model
#' defined by the `spec`. The process is wrapped in a `try-catch` block to
#' gracefully handle and report any errors that occur during model instantiation
#' or compilation.
#'
#' The output is a tibble that mirrors the input `grid`, with additional columns
#' containing the compiled model object or the error message, making it easy to
#' inspect which architectures are valid.
#'
#' @param spec A `parsnip` model specification created by
#' `create_keras_sequential_spec()` or `create_keras_functional_spec()`.
#' @param grid A `tibble` or `data.frame` containing the grid of hyperparameters
#' to evaluate. Each row represents a unique model architecture to be compiled.
#' @param x A data frame or matrix of predictors. This is used to infer the
#' `input_shape` for the Keras model.
#' @param y A vector of outcomes. This is used to infer the output shape and
#' the default loss function.
#' @param y A vector or factor of outcomes. This is used to infer the output
#' shape and the default loss function for the Keras model.
#'
#' @return A `tibble` with the following columns:
#' \itemize{
#' \item Columns from the input `grid`.
#' \item `compiled_model`: A list-column containing the compiled Keras model
#' objects. If compilation failed for a specific hyperparameter set, this
#' column will contain a character string with the error message.
#' \item `model_summary`: A list-column containing a character string with the
#' output of `keras3::summary_keras_model()` for each successfully compiled
#' model.
#' objects. If compilation failed, the element will be `NULL`.
#' \item `error`: A list-column containing `NA` for successes or a
#' character string with the error message for failures.
#' }
#'
#' @examples
#' \dontrun{
#' if (keras::is_keras_available()) {
#'
#' # 1. Define a kerasnip model specification
#' create_keras_sequential_spec(
#' model_name = "my_mlp",
#' layer_blocks = list(
#' input_block,
#' hidden_block,
#' output_block
#' ),
#' mode = "classification"
#' )
#'
#' mlp_spec <- my_mlp(
#' hidden_units = tune(),
#' compile_loss = "categorical_crossentropy",
#' compile_optimizer = "adam"
#' )
#'
#' # 2. Create a hyperparameter grid
#' # Include an invalid value (-10) to demonstrate error handling
#' param_grid <- tibble::tibble(
#' hidden_units = c(32, 64, -10)
#' )
#'
#' # 3. Prepare dummy data
#' x_train <- matrix(rnorm(100 * 10), ncol = 10)
#' y_train <- factor(sample(0:1, 100, replace = TRUE))
#'
#' # 4. Compile models over the grid
#' compiled_grid <- compile_keras_grid(
#' spec = mlp_spec,
#' grid = param_grid,
#' x = x_train,
#' y = y_train
#' )
#'
#' print(compiled_grid)
#'
#' # 5. Inspect the results
#' # The row with `hidden_units = -10` will show an error.
#' }
#' }
#' @importFrom dplyr bind_rows filter select
#' @importFrom cli cli_h1 cli_alert_danger cli_h2 cli_text cli_bullets cli_code cli_alert_info cli_alert_success
#' @export
@@ -110,19 +152,14 @@ compile_keras_grid <- function(spec, grid, x, y) {
{
model <- do.call(build_fn, args)
# Capture the model summary
summary_char <- utils::capture.output(summary(
model
))
list(
compiled_model = list(model),
model_summary = paste(summary_char, collapse = "\n"),
error = NA_character_
)
},
error = function(e) {
list(
compiled_model = list(NULL),
model_summary = NA_character_,
error = as.character(e$message)
)
}
@@ -136,24 +173,43 @@ compile_keras_grid <- function(spec, grid, x, y) {
dplyr::bind_rows(results)
}

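The per-row error handling described in the documentation above can be illustrated without Keras. The following is a minimal sketch in base R, not the package's implementation: `toy_build()` is a hypothetical stand-in for the real model builder that simply rejects non-positive unit counts, and `compile_grid_sketch()` is an illustrative name.

```r
# Hypothetical stand-in for a model builder: fails on invalid hyperparameters.
toy_build <- function(hidden_units) {
  if (hidden_units <= 0) {
    stop("`hidden_units` must be positive, got ", hidden_units)
  }
  list(units = hidden_units) # placeholder for a compiled model object
}

# Try to build one "model" per grid row, recording either the model or the error.
compile_grid_sketch <- function(grid, build_fn) {
  results <- lapply(seq_len(nrow(grid)), function(i) {
    args <- as.list(grid[i, , drop = FALSE])
    tryCatch(
      list(model = do.call(build_fn, args), error = NA_character_),
      error = function(e) list(model = NULL, error = conditionMessage(e))
    )
  })
  out <- grid
  out$compiled_model <- lapply(results, `[[`, "model")
  out$error <- vapply(results, `[[`, character(1), "error")
  out
}

param_grid <- data.frame(hidden_units = c(32, 64, -10))
compiled <- compile_grid_sketch(param_grid, toy_build)
```

Because every row is wrapped in its own `tryCatch()`, one invalid combination does not abort the loop; its row carries `NULL` in `compiled_model` and the message in `error`.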
#' Extract Valid Grid from Compilation Results
#' Filter a Grid to Only Valid Hyperparameter Sets
#'
#' @title Extract Valid Grid from Compilation Results
#' @description
#' This helper function filters the results from `compile_keras_grid()` to
#' return a new hyperparameter grid containing only the combinations that
#' compiled successfully.
#'
#' @details
#' After running `compile_keras_grid()`, you can use this function to remove
#' problematic hyperparameter combinations before proceeding to the full
#' `tune::tune_grid()`.
#'
#' @param compiled_grid A tibble, the result of a call to `compile_keras_grid()`.
#'
#' @return A tibble containing the subset of the original grid that resulted in
#' a successful model compilation (i.e., where the `error` column is `NA`).
#' The columns for `compiled_model`, `model_summary`, and `error` are removed.
#' a successful model compilation. The `compiled_model` and `error` columns
#' are removed, leaving a clean grid ready for tuning.
#'
#' @examples
#' \dontrun{
#' # Continuing the example from `compile_keras_grid`:
#'
#' # `compiled_grid` contains one row with an error.
#' valid_grid <- extract_valid_grid(compiled_grid)
#'
#' # `valid_grid` now only contains the rows that compiled successfully.
#' print(valid_grid)
#'
#' # This clean grid can now be passed to tune::tune_grid().
#' }
#' @export
extract_valid_grid <- function(compiled_grid) {
if (
!is.data.frame(compiled_grid) ||
!all(
c("error", "compiled_model", "model_summary") %in% names(compiled_grid)
c("error", "compiled_model") %in% names(compiled_grid)
)
) {
stop(
@@ -162,20 +218,36 @@ extract_valid_grid <- function(compiled_grid) {
}
compiled_grid %>%
dplyr::filter(is.na(error)) %>%
dplyr::select(-compiled_model, -model_summary, -error)
dplyr::select(-c(compiled_model, error))
}

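The filtering step in `extract_valid_grid()` amounts to keeping rows whose `error` is `NA` and dropping the bookkeeping columns. A minimal base R sketch of the same idea, assuming a hand-built results table (`valid_grid_sketch()` and the toy data are illustrative, not part of the package):

```r
# Keep rows whose `error` is NA, then drop the bookkeeping columns.
valid_grid_sketch <- function(compiled_grid) {
  required <- c("compiled_model", "error")
  stopifnot(
    is.data.frame(compiled_grid),
    all(required %in% names(compiled_grid))
  )
  keep <- is.na(compiled_grid$error)
  compiled_grid[keep, setdiff(names(compiled_grid), required), drop = FALSE]
}

# Toy results table: the third combination failed to compile.
compiled_grid <- data.frame(
  hidden_units = c(32, 64, -10),
  dropout = c(0.1, 0.2, 0.3)
)
compiled_grid$compiled_model <- list("model_a", "model_b", NULL)
compiled_grid$error <- c(NA, NA, "invalid units")

valid_grid <- valid_grid_sketch(compiled_grid)
```

Using `drop = FALSE` keeps the result a data frame even when the grid has a single hyperparameter column.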
#' Inform about Compilation Errors
#' Display a Summary of Compilation Errors
#'
#' @title Inform About Compilation Errors
#' @description
#' This helper function inspects the results from `compile_keras_grid()` and
#' prints a formatted summary of any compilation errors that occurred.
#' prints a formatted, easy-to-read summary of any compilation errors that
#' occurred.
#'
#' @details
#' This is most useful for interactive debugging of complex tuning grids where
#' some hyperparameter combinations may lead to invalid Keras models.
#'
#' @param compiled_grid A tibble, the result of a call to `compile_keras_grid()`.
#' @param n The maximum number of errors to display.
#' @param n A single integer for the maximum number of distinct errors to
#' display in detail.
#'
#' @return Invisibly returns the input `compiled_grid`. Called for its side
#' effect of printing to the console.
#' effect of printing a summary to the console.
#'
#' @examples
#' \dontrun{
#' # Continuing the example from `compile_keras_grid`:
#'
#' # `compiled_grid` contains one row with an error.
#' # This will print a formatted summary of that error.
#' inform_errors(compiled_grid)
#' }
#' @export
inform_errors <- function(compiled_grid, n = 10) {
if (
@@ -195,7 +267,7 @@ inform_errors <- function(compiled_grid, n = 10) {

for (i in 1:min(nrow(error_grid), n)) {
row <- error_grid[i, ]
params <- row %>% dplyr::select(-compiled_model, -model_summary, -error)
params <- row %>% dplyr::select(-c(compiled_model, error))
cli::cli_h2("Error {i}/{nrow(error_grid)}")
cli::cli_text("Hyperparameters:")
cli::cli_bullets(paste0(names(params), ": ", as.character(params)))
@@ -209,4 +281,4 @@ inform_errors <- function(compiled_grid, n = 10) {
cli::cli_alert_success("All models compiled successfully!")
}
invisible(compiled_grid)
}
}
[air] reported by reviewdog 🐶

Suggested change
}
}

77 changes: 46 additions & 31 deletions R/generic_functional_fit.R
@@ -1,39 +1,29 @@
#' Generic Keras Functional API Model Fitting Implementation
#' Generic Fitting Function for Functional Keras Models
#'
#' @title Internal Fitting Engine for Functional API Models
#' @description
#' This function is the internal engine for fitting models generated by
#' `create_keras_functional_spec()`. It is not intended to be called directly
#' by the user.
#' This function serves as the internal engine for fitting `kerasnip` models that
#' are based on the Keras functional API. It is not intended to be called
#' directly by the user. The function is invoked by `parsnip::fit()` when a
#' `kerasnip` functional model specification is used.
#'
#' @details
#' This function performs the following key steps:
#' The function orchestrates the three main steps of the model fitting process:
#' \enumerate{
#' \item \strong{Argument & Data Preparation:} It resolves arguments passed
#' from `parsnip` (handling `rlang_zap` objects for unspecified arguments)
#' and prepares the `x` and `y` data for Keras. It automatically determines
#' the `input_shape` from `x` and, for classification, the `num_classes`
#' from `y`.
#' \item \strong{Dynamic Model Construction:} It builds the Keras model graph
#' by processing the `layer_blocks` list.
#' \itemize{
#' \item \strong{Connectivity:} The graph is connected by matching the
#' argument names of each block function to the names of previously
#' defined blocks. For example, a block `function(input_a, ...)` will
#' receive the output tensor from the block named `input_a`.
#' \item \strong{Repetition:} It checks for `num_{block_name}` arguments
#' to repeat a block multiple times, creating a chain of identical
#' layers. A block can only be repeated if it has exactly one input
#' tensor from another block.
#' }
#' \item \strong{Model Compilation:} It compiles the final Keras model. The
#' compilation arguments (optimizer, loss, metrics) can be customized by
#' passing arguments prefixed with `compile_` (e.g., `compile_loss = "mae"`).
#' \item \strong{Model Fitting:} It calls `keras3::fit()` to train the model
#' on the prepared data.
#' \item \strong{Build and Compile:} It calls
#' `build_and_compile_functional_model()` to construct the Keras model
#' architecture based on the provided `layer_blocks` and hyperparameters.
#' \item \strong{Process Data:} It preprocesses the input (`x`) and output (`y`)
#' data into the format expected by Keras.
#' \item \strong{Fit Model:} It calls `keras3::fit()` with the compiled model
#' and processed data, passing along any fitting-specific arguments (e.g.,
#' `epochs`, `batch_size`, `callbacks`).
#' }
#'
#' @param x A data frame or matrix of predictors.
#' @param y A vector of outcomes.
#' @param formula A formula specifying the predictor and outcome variables,
#' passed down from the `parsnip::fit()` call.
#' @param data A data frame containing the training data, passed down from the
#' `parsnip::fit()` call.
#' @param layer_blocks A named list of layer block functions. This is passed
#' internally from the `parsnip` model specification.
#' @param ... Additional arguments passed down from the model specification.
@@ -61,14 +51,39 @@
#' \item `lvl`: A character vector of the outcome factor levels (for
#' classification) or `NULL` (for regression).
#' }
#'
#' @examples
#' # This function is not called directly by users.
#' # It is called internally by `parsnip::fit()`.
#' # For example:
#' \dontrun{
#' # create_keras_functional_spec(...) defines my_functional_model
#'
#' spec <- my_functional_model(hidden_units = 128, fit_epochs = 10) |>
#' set_engine("keras")
#'
#' # This call to fit() would invoke generic_functional_fit() internally
#' fitted_model <- fit(spec, y ~ x, data = training_data)
#' }
#' @keywords internal
#' @export
generic_functional_fit <- function(
x,
y,
formula,
data,
layer_blocks,
...
) {
# Separate predictors and outcomes from the processed data frame provided by parsnip
y_names <- all.vars(formula[[2]])
x_names <- all.vars(formula[[3]])

# Handle the `.` case for predictors
if ("." %in% x_names) {
x <- data[, !(names(data) %in% y_names), drop = FALSE]
} else {
x <- data[, x_names, drop = FALSE]
}
y <- data[, y_names, drop = FALSE]
# --- 1. Build and Compile Model ---
model <- build_and_compile_functional_model(x, y, layer_blocks, ...)

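The predictor/outcome split at the top of `generic_functional_fit()` can be tried in isolation. The sketch below copies that logic into a hypothetical helper, `split_xy_sketch()`, and runs it on a toy data frame; it assumes a two-sided formula, since `formula[[3]]` does not exist for a one-sided one.

```r
# Split a data frame into predictors and outcomes based on a two-sided formula.
split_xy_sketch <- function(formula, data) {
  y_names <- all.vars(formula[[2]]) # left-hand side variables
  x_names <- all.vars(formula[[3]]) # right-hand side variables

  if ("." %in% x_names) {
    # `.` means "every column that is not an outcome"
    x <- data[, !(names(data) %in% y_names), drop = FALSE]
  } else {
    x <- data[, x_names, drop = FALSE]
  }
  list(x = x, y = data[, y_names, drop = FALSE])
}

toy <- data.frame(a = 1:3, b = 4:6, outcome = c(0.1, 0.2, 0.3))
dot_split <- split_xy_sketch(outcome ~ ., toy)
named_split <- split_xy_sketch(outcome ~ a, toy)
```

Note that `all.vars()` returns bare variable names, so a term such as `log(a)` on the right-hand side would select the untransformed column `a`.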