Skip to content

add bagged neural networks #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,17 @@ Suggests:
AmesHousing,
covr,
modeldata,
nnet,
recipes,
rmarkdown,
spelling,
testthat (>= 3.0.0),
yardstick
Remotes:
tidymodels/parsnip#815
Config/Needs/website: tidyverse/tidytemplate
Config/testthat/edition: 3
Encoding: UTF-8
Language: en-US
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.1.2.9000
RoxygenNote: 7.2.1.9000
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export("%>%")
export(bagger)
export(class_cost)
export(control_bag)
export(nnet_imp_garson)
export(var_imp)
export(var_imp.bagger)
import(dplyr)
Expand Down Expand Up @@ -53,6 +54,7 @@ importFrom(rpart,rpart)
importFrom(rsample,analysis)
importFrom(rsample,assessment)
importFrom(rsample,bootstraps)
importFrom(stats,coef)
importFrom(stats,complete.cases)
importFrom(stats,predict)
importFrom(stats,sd)
Expand Down
4 changes: 3 additions & 1 deletion R/aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
#' @importFrom rpart rpart
#' @importFrom withr with_seed
#' @importFrom dials new_quant_param

#' @importFrom stats coef
#'
# ------------------------------------------------------------------------------

utils::globalVariables(
Expand Down Expand Up @@ -53,4 +54,5 @@ utils::globalVariables(
# This defines model functions in the parsnip model database
make_bag_tree()
make_bag_mars()
make_bag_mlp()
}
137 changes: 137 additions & 0 deletions R/bag_nnet_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# These functions are tested indirectly when the models are used. Since this
# function is executed on package startup, it can't be executed again in tests
# because the definitions are already in the parsnip model database. We'll
# exclude it from coverage stats for this reason.

# nocov start
# NOTE: covr requires the paired form "# nocov start" / "# nocov end" to
# exclude a region; a bare "# nocov" only excludes its own line.

# Register the `bag_mlp` model (bagged single-layer neural networks via the
# nnet engine) in the parsnip model database: engines, dependencies, tunable
# arguments, fit/encoding requirements, and prediction modules for both the
# regression and classification modes. Called once from .onLoad(); returns
# nothing useful (invisibly, the value of the last registration call).
make_bag_mlp <- function() {

  parsnip::set_model_engine("bag_mlp", "classification", "nnet")
  parsnip::set_model_engine("bag_mlp", "regression", "nnet")
  # Both the engine package (nnet) and this package (baguette, which supplies
  # the bagger() fit interface) are runtime dependencies for each mode.
  parsnip::set_dependency("bag_mlp", "nnet", "nnet", mode = "classification")
  parsnip::set_dependency("bag_mlp", "nnet", "nnet", mode = "regression")
  parsnip::set_dependency("bag_mlp", "nnet", "baguette", mode = "classification")
  parsnip::set_dependency("bag_mlp", "nnet", "baguette", mode = "regression")

  # Map parsnip's standardized argument names to nnet::nnet()'s names and
  # attach the dials objects used for tuning.
  parsnip::set_model_arg(
    model = "bag_mlp",
    eng = "nnet",
    parsnip = "hidden_units",
    original = "size",
    func = list(pkg = "dials", fun = "hidden_units"),
    has_submodel = FALSE
  )
  parsnip::set_model_arg(
    model = "bag_mlp",
    eng = "nnet",
    parsnip = "penalty",
    original = "decay",
    func = list(pkg = "dials", fun = "penalty"),
    has_submodel = FALSE
  )

  parsnip::set_model_arg(
    model = "bag_mlp",
    eng = "nnet",
    parsnip = "epochs",
    original = "maxit",
    func = list(pkg = "dials", fun = "epochs"),
    has_submodel = FALSE
  )

  # Fitting goes through baguette::bagger() with base_model = "nnet".
  parsnip::set_fit(
    model = "bag_mlp",
    eng = "nnet",
    mode = "regression",
    value = list(
      interface = "formula",
      protect = c("formula", "data", "weights"),
      func = c(pkg = "baguette", fun = "bagger"),
      defaults = list(base_model = "nnet")
    )
  )

  # nnet requires numeric predictors, so factors are expanded to indicator
  # columns ("traditional") and the intercept is removed before fitting.
  parsnip::set_encoding(
    model = "bag_mlp",
    eng = "nnet",
    mode = "regression",
    options = list(
      predictor_indicators = "traditional",
      compute_intercept = TRUE,
      remove_intercept = TRUE,
      allow_sparse_x = FALSE
    )
  )

  parsnip::set_fit(
    model = "bag_mlp",
    eng = "nnet",
    mode = "classification",
    value = list(
      interface = "formula",
      protect = c("formula", "data", "weights"),
      func = c(pkg = "baguette", fun = "bagger"),
      defaults = list(base_model = "nnet")
    )
  )

  parsnip::set_encoding(
    model = "bag_mlp",
    eng = "nnet",
    mode = "classification",
    options = list(
      predictor_indicators = "traditional",
      compute_intercept = TRUE,
      remove_intercept = TRUE,
      allow_sparse_x = FALSE
    )
  )

  # Predictions dispatch to predict() on the fitted bagger object; for
  # classification, fix_column_names() (defined elsewhere in baguette)
  # post-processes the column names parsnip expects.
  parsnip::set_pred(
    model = "bag_mlp",
    eng = "nnet",
    mode = "regression",
    type = "numeric",
    value = list(
      pre = NULL,
      post = NULL,
      func = c(fun = "predict"),
      args = list(object = quote(object$fit), new_data = quote(new_data))
    )
  )

  parsnip::set_pred(
    model = "bag_mlp",
    eng = "nnet",
    mode = "classification",
    type = "class",
    value = list(
      pre = NULL,
      post = fix_column_names,
      func = c(pkg = NULL, fun = "predict"),
      args =
        list(
          object = quote(object$fit),
          new_data = quote(new_data),
          type = "class"
        )
    )
  )

  parsnip::set_pred(
    model = "bag_mlp",
    eng = "nnet",
    mode = "classification",
    type = "prob",
    value = list(
      pre = NULL,
      post = fix_column_names,
      func = c(pkg = NULL, fun = "predict"),
      args = list(object = quote(object$fit), new_data = quote(new_data), type = "prob")
    )
  )

}

# nocov end
9 changes: 8 additions & 1 deletion R/bagger.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#' @param weights A numeric vector of non-negative case weights. These values are
#' not used during bootstrap resampling.
#' @param base_model A single character value for the model being bagged. Possible
#' values are "CART", "MARS", and "C5.0" (classification only).
#' values are "CART", "MARS", "nnet", and "C5.0" (classification only).
#' @param times A single integer greater than 1 for the maximum number of bootstrap
#' samples/ensemble members (some model fits might fail).
#' @param control A list of options generated by `control_bag()`.
Expand All @@ -36,6 +36,13 @@
#' enable parallelism, use the `future::plan()` function to declare _how_ the
#' computations should be distributed. Note that this will almost certainly
#' multiply the memory requirements required to fit the models.
#'
#' For neural networks, variable importance is calculated using the method
#' of Garson described in Gevrey _et al._ (2003).
#'
#' @references Gevrey, M., Dimopoulos, I., and Lek, S. (2003). Review and
#' comparison of methods to study the contribution of variables in artificial
#' neural network models. Ecological Modelling, 160(3), 249-264.
#' @examples
#' library(recipes)
#' library(dplyr)
Expand Down
3 changes: 2 additions & 1 deletion R/bridge.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ bagger_bridge <- function(processed, weights, base_model, seed, times, control,
base_model,
CART = cart_bagger(rs, control, ...),
C5.0 = c5_bagger(rs, control, ...),
MARS = mars_bagger(rs, control, ...)
MARS = mars_bagger(rs, control, ...),
nnet = nnet_bagger(rs, control, ...)
)
} else {
res <- switch(
Expand Down
2 changes: 1 addition & 1 deletion R/cart.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ make_cart_spec <- function(classif, opt) {
main_args <- NULL
}

# Note: from ?rpart: "arguments to rpartcontrol may also be specified in
# Note: from ?rpart: "arguments to rpart.control may also be specified in
# the call to rpart. They are checked against the list of valid arguments."

cart_spec <-
Expand Down
10 changes: 6 additions & 4 deletions R/model_info.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
baguette_models <- c("CART", "C5.0", "MARS")
baguette_models <- c("CART", "C5.0", "MARS", "nnet")

# We want to default some arguments for different models
model_defaults <-
list(
CART = list(cp = 0, xval = 0, minsplit = 20, maxdepth = 30, model = FALSE),
"model rules" = list(),
"C5.0" = list(minCases = 2),
MARS = list(pmethod = "none", nprune = NULL, degree = 1)
MARS = list(pmethod = "none", nprune = NULL, degree = 1),
nnet = list(decay = 0, size = 10, maxit = 1000, MaxNWts = 10^5)
)

# Enumerate the possible arguments in the fit or control functions that can
# be modified by the user. This could be done programmatically to protect against
# changes but each of the underlying pacakges is pretty mature and there is a
# changes but each of the underlying packages is pretty mature and there is a
# small likelihood of them changing.

model_args <-
Expand All @@ -26,6 +27,7 @@ model_args <-
'CF', 'minCases', 'fuzzyThreshold', 'sample'),
MARS = c('pmethod', 'trace', 'glm', 'degree', 'nprune', 'nfold', 'ncross',
'stratify', 'varmod.method', 'varmod.exponent', 'varmod.conv',
'varmod.clamp', 'varmod.minspan', 'Scale.y')
'varmod.clamp', 'varmod.minspan', 'Scale.y'),
nnet = c("linout", "entropy", "softmax", "censored", "skip", "rang", "maxit")
)

Loading