Skip to content

Commit

Permalink
[R-package] moved parameter validations up earlier in function calls (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Jan 25, 2020
1 parent 08fd53c commit d906428
Show file tree
Hide file tree
Showing 14 changed files with 267 additions and 79 deletions.
18 changes: 10 additions & 8 deletions R-package/R/callback.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@ cb.reset.parameters <- function(new_params) {
# Run some checks in the beginning
init <- function(env) {

# Store boosting rounds
nrounds <<- env$end_iteration - env$begin_iteration + 1L

# Check for model environment
if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) }
if (is.null(env$model)) {
stop("Env should have a ", sQuote("model"))
}

# Some parameters are not allowed to be changed,
# since changing them would simply wreak havoc
Expand All @@ -50,6 +49,9 @@ cb.reset.parameters <- function(new_params) {
)
}

# Store boosting rounds
nrounds <<- env$end_iteration - env$begin_iteration + 1L

# Check parameter names
for (n in pnames) {

Expand Down Expand Up @@ -285,14 +287,14 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# Initialization function
init <- function(env) {

# Store evaluation length
eval_len <<- length(env$eval_list)

# Early stopping cannot work without metrics
if (eval_len == 0L) {
if (length(env$eval_list) == 0L) {
stop("For early stopping, valids must have at least one element")
}

# Store evaluation length
eval_len <<- length(env$eval_list)

# Check if verbose or not
if (isTRUE(verbose)) {
cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
Expand Down
16 changes: 11 additions & 5 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -781,15 +781,21 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
}

# Return new booster
if (!is.null(filename) && !file.exists(filename)) stop("lgb.load: file does not exist for supplied filename")
if (!is.null(filename)) return(invisible(Booster$new(modelfile = filename)))
if (!is.null(filename) && !file.exists(filename)) {
stop("lgb.load: file does not exist for supplied filename")
}
if (!is.null(filename)) {
return(invisible(Booster$new(modelfile = filename)))
}

# Load from model_str
if (!is.null(model_str) && !is.character(model_str)) {
stop("lgb.load: model_str should be character")
}
# Return new booster
if (!is.null(model_str)) return(invisible(Booster$new(model_str = model_str)))
if (!is.null(model_str)) {
return(invisible(Booster$new(model_str = model_str)))
}

}

Expand Down Expand Up @@ -831,8 +837,8 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
}

# Check if file name is character
if (!is.character(filename)) {
stop("lgb.save: filename should be a character")
if (!(is.character(filename) && length(filename) == 1L)) {
stop("lgb.save: filename should be a string")
}

# Store booster
Expand Down
22 changes: 8 additions & 14 deletions R-package/R/lgb.Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ Dataset <- R6::R6Class(
info = list(),
...) {

# validate inputs early to avoid unnecessary computation
if (!(is.null(reference) || lgb.check.r6.class(reference, "lgb.Dataset"))) {
stop("lgb.Dataset: If provided, reference must be a ", sQuote("lgb.Dataset"))
}
if (!(is.null(predictor) || lgb.check.r6.class(predictor, "lgb.Predictor"))) {
stop("lgb.Dataset: If provided, predictor must be a ", sQuote("lgb.Predictor"))
}

# Check for additional parameters
additional_params <- list(...)

Expand All @@ -56,20 +64,6 @@ Dataset <- R6::R6Class(

}

# Check for dataset reference
if (!is.null(reference)) {
if (!lgb.check.r6.class(reference, "lgb.Dataset")) {
stop("lgb.Dataset: Can only use ", sQuote("lgb.Dataset"), " as reference")
}
}

# Check for predictor reference
if (!is.null(predictor)) {
if (!lgb.check.r6.class(predictor, "lgb.Predictor")) {
stop("lgb.Dataset: Only can use ", sQuote("lgb.Predictor"), " as predictor")
}
}

# Check for matrix format
if (is.matrix(data)) {
# Check whether matrix is the correct type first ("double")
Expand Down
27 changes: 14 additions & 13 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ CVBooster <- R6::R6Class(
#' @description Cross validation logic used by LightGBM
#' @inheritParams lgb_shared_params
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
#' @param label vector of response values. Should be provided only when data is an R-matrix.
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight vector of weights. If not NULL, it will be set on the dataset
#' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber},
Expand Down Expand Up @@ -95,6 +95,19 @@ lgb.cv <- function(params = list()
, ...
) {

# validate parameters
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}

# If 'data' is not an lgb.Dataset, try to construct one using 'label'
if (!lgb.is.Dataset(data)) {
if (is.null(label)) {
stop("'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'")
}
data <- lgb.Dataset(data, label = label)
}

# Setup temporary variables
params <- append(params, list(...))
params$verbose <- verbose
Expand All @@ -103,10 +116,6 @@ lgb.cv <- function(params = list()
fobj <- NULL
feval <- NULL

if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}

# Check for objective (function or not)
if (is.function(params$objective)) {
fobj <- params$objective
Expand Down Expand Up @@ -141,14 +150,6 @@ lgb.cv <- function(params = list()
end_iteration <- begin_iteration + nrounds - 1L
}

# Check for training dataset type correctness
if (!lgb.is.Dataset(data)) {
if (is.null(label)) {
stop("Labels must be provided for lgb.cv")
}
data <- lgb.Dataset(data, label = label)
}

# Check for weights
if (!is.null(weight)) {
data$setinfo("weight", weight)
Expand Down
2 changes: 1 addition & 1 deletion R-package/R/lgb.importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
lgb.importance <- function(model, percentage = TRUE) {

# Check if model is a lightgbm model
if (!inherits(model, "lgb.Booster")) {
if (!lgb.is.Booster(model)) {
stop("'model' has to be an object of class lgb.Booster")
}

Expand Down
45 changes: 17 additions & 28 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,23 @@ lgb.train <- function(params = list(),
reset_data = FALSE,
...) {

# validate inputs early to avoid unnecessary computation
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}
if (!lgb.is.Dataset(data)) {
stop("lgb.train: data must be an lgb.Dataset instance")
}
if (length(valids) > 0L) {
if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1L)))) {
stop("lgb.train: valids must be a list of lgb.Dataset elements")
}
evnames <- names(valids)
if (is.null(evnames) || !all(nzchar(evnames))) {
stop("lgb.train: each element of valids must have a name")
}
}

# Setup temporary variables
additional_params <- list(...)
params <- append(params, additional_params)
Expand All @@ -74,10 +91,6 @@ lgb.train <- function(params = list(),
fobj <- NULL
feval <- NULL

if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}

# Check for objective (function or not)
if (is.function(params$objective)) {
fobj <- params$objective
Expand Down Expand Up @@ -112,30 +125,6 @@ lgb.train <- function(params = list(),
end_iteration <- begin_iteration + nrounds - 1L
}

# Check for training dataset type correctness
if (!lgb.is.Dataset(data)) {
stop("lgb.train: data only accepts lgb.Dataset object")
}

# Check for validation dataset type correctness
if (length(valids) > 0L) {

# One or more validation dataset

# Check for list as input and type correctness by object
if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1L)))) {
stop("lgb.train: valids must be a list of lgb.Dataset elements")
}

# Attempt to get names
evnames <- names(valids)

# Check for names existence
if (is.null(evnames) || !all(nzchar(evnames))) {
stop("lgb.train: each element of the valids must have a name tag")
}
}

# Update parameters with parsed parameters
data$update_params(params)

Expand Down
14 changes: 10 additions & 4 deletions R-package/R/lightgbm.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#' @name lgb_shared_params
#' @title Shared parameter docs
#' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm}
#' @param callbacks List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training
#' @param callbacks List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
#' may allow you to pass other types of data like \code{matrix} and then separately supply
#' \code{label} as a keyword argument.
#' @param early_stopping_rounds int. Activates early stopping. Requires at least one validation data
#' and one metric. If there's more than one, will check all of them
#' except the training data. Returns the model with (best_iter + early_stopping_rounds).
Expand Down Expand Up @@ -57,11 +60,14 @@ lightgbm <- function(data,
callbacks = list(),
...) {

# Set data to a temporary variable
dtrain <- data
# validate inputs early to avoid unnecessary computation
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}

# Set data to a temporary variable
dtrain <- data

# Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
if (!lgb.is.Dataset(dtrain)) {
dtrain <- lgb.Dataset(data, label = label, weight = weight)
Expand Down
6 changes: 4 additions & 2 deletions R-package/man/lgb.cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion R-package/man/lgb.train.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion R-package/man/lgb_shared_params.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion R-package/man/lightgbm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit d906428

Please sign in to comment.