Add argument for one hot encoding to workflows #53

Merged
merged 22 commits into from
Jul 1, 2020
7 changes: 4 additions & 3 deletions DESCRIPTION
@@ -22,11 +22,12 @@ Imports:
     ellipsis (>= 0.2.0),
     generics,
     glue,
-    hardhat (>= 0.1.2),
-    parsnip (>= 0.0.4),
+    hardhat (>= 0.1.3.9000),
+    parsnip (>= 0.1.1.9000),
     rlang (>= 0.4.1)
 Remotes:
-    tidymodels/parsnip
+    tidymodels/parsnip#332,
+    tidymodels/hardhat
 Suggests:
     covr,
     knitr,
12 changes: 6 additions & 6 deletions NEWS.md
@@ -1,12 +1,12 @@
 # workflows (development version)

 * When using a formula preprocessor with `add_formula()`, workflows now uses
-  model-specific information from parsnip to decide whether or not to expand
-  factors into dummy variables. This should result in more intuitive behavior
-  when working with models that don't require dummy variables. For example,
-  if a parsnip `rand_forest()` model is used with a ranger engine, dummy
-  variables will not be created, because ranger can handle factors directly
-  (#51).
+  model-specific information from parsnip to decide whether to expand
+  factors via dummy encoding (`n - 1` levels), one-hot encoding (`n` levels), or
+  no expansion at all. This should result in more intuitive behavior when
+  working with models that don't require dummy variables. For example, if a
+  parsnip `rand_forest()` model is used with a ranger engine, dummy variables
+  will not be created, because ranger can handle factors directly (#51, #53).

 # workflows 0.1.1

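The three encodings named in the NEWS.md entry can be illustrated with base R alone. This is a minimal sketch, not how workflows or hardhat implement it; the toy data frame `d` and its `color` column are made up for illustration.

```r
# Toy data (hypothetical): one factor predictor with n = 3 levels.
d <- data.frame(color = factor(c("red", "green", "blue", "green")))

# Dummy (treatment) encoding: the first level acts as the reference, so
# n - 1 = 2 indicator columns are produced next to the intercept column.
dummy <- model.matrix(~ color, data = d)

# One-hot encoding: switch contrasts off for `color` so that every level
# gets its own column, then drop the intercept column.
full_contrasts <- contrasts(d$color, contrasts = FALSE)
one_hot <- model.matrix(
  ~ color,
  data = d,
  contrasts.arg = list(color = full_contrasts)
)
one_hot <- one_hot[, colnames(one_hot) != "(Intercept)", drop = FALSE]

# No expansion: the factor column is passed through untouched.
none <- d

colnames(dummy)
colnames(one_hot)
```

With one-hot encoding each row has exactly one indicator set, which is why that matrix has `n` columns rather than `n - 1`.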
33 changes: 24 additions & 9 deletions R/fit.R
@@ -12,6 +12,8 @@
 #' In the future, there will also be _postprocessing_ steps that can be added
 #' after the model has been fit.
 #'
+#' @includeRmd man/rmd/indicators.Rmd details
+#'
 #' @param object A workflow
 #'
 #' @param data A data frame of predictors and outcomes to use when fitting the
@@ -187,18 +189,31 @@ finalize_blueprint_recipe <- function(workflow) {
 }

 finalize_blueprint_formula <- function(workflow) {
-  # Use the model indicators information to construct the blueprint
-  indicators <- pull_workflow_spec_indicators(workflow)
-  blueprint <- hardhat::default_formula_blueprint(indicators = indicators)
+  tbl_encodings <- pull_workflow_spec_encoding_tbl(workflow)
+
+  indicators <- tbl_encodings$predictor_indicators
+  intercept <- tbl_encodings$compute_intercept
+
+  if (!is_string(indicators)) {
+    abort("Internal error: `indicators` encoding from parsnip should be a string.")
+  }
+  if (!is_bool(intercept)) {
+    abort("Internal error: `intercept` encoding from parsnip should be a bool.")
+  }
+
+  # Use model specific information to construct the blueprint
+  blueprint <- hardhat::default_formula_blueprint(
+    indicators = indicators,
+    intercept = intercept
+  )

   formula <- pull_workflow_preprocessor(workflow)

   update_formula(workflow, formula = formula, blueprint = blueprint)
 }

-pull_workflow_spec_indicators <- function(x) {
-  spec <- pull_workflow_spec(x)
-
+pull_workflow_spec_encoding_tbl <- function(workflow) {
+  spec <- pull_workflow_spec(workflow)
   spec_cls <- class(spec)[[1]]

   tbl_encodings <- parsnip::get_encoding(spec_cls)
@@ -207,11 +222,11 @@ pull_workflow_spec_indicators <- function(x) {
   indicator_mode <- tbl_encodings$mode == spec$mode
   indicator_spec <- indicator_engine & indicator_mode

-  indicators <- tbl_encodings$predictor_indicators[indicator_spec]
+  out <- tbl_encodings[indicator_spec, , drop = FALSE]

-  if (length(indicators) != 1L) {
+  if (nrow(out) != 1L) {
     abort("Internal error: Exactly 1 model/engine/mode combination must be located.")
   }

-  indicators
+  out
 }
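The lookup in `pull_workflow_spec_encoding_tbl()` can be sketched in base R, assuming an encoding table with `engine`, `mode`, `predictor_indicators`, and `compute_intercept` columns as used in the diff. The table contents, the `toy_model` / `engine_a` / `engine_b` names, and the `lookup_encoding()` helper are hypothetical; the real table comes from `parsnip::get_encoding()` and may differ.

```r
# Hypothetical encoding table; rows and values are illustrative only.
tbl_encodings <- data.frame(
  model = "toy_model",
  engine = c("engine_a", "engine_a", "engine_b"),
  mode = c("regression", "classification", "regression"),
  predictor_indicators = c("one_hot", "traditional", "none"),
  compute_intercept = c(FALSE, TRUE, FALSE),
  stringsAsFactors = FALSE
)

# Hypothetical helper mirroring the row filter in the diff: keep the rows
# that match both engine and mode, and insist on exactly one match.
lookup_encoding <- function(tbl, engine, mode) {
  keep <- tbl$engine == engine & tbl$mode == mode
  out <- tbl[keep, , drop = FALSE]

  if (nrow(out) != 1L) {
    stop("Exactly 1 engine/mode combination must be located.", call. = FALSE)
  }

  out
}

row <- lookup_encoding(tbl_encodings, engine = "engine_a", mode = "regression")
row$predictor_indicators
row$compute_intercept
```

Returning the whole matching row, rather than a single column as the old `pull_workflow_spec_indicators()` did, lets the caller read both the indicator style and the intercept flag from one lookup.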
13 changes: 11 additions & 2 deletions R/pre-action-formula.R
@@ -17,6 +17,8 @@
 #' To fit a workflow, one of `add_formula()` or `add_recipe()` _must_ be
 #' specified, but not both.
 #'
+#' @includeRmd man/rmd/add-formula.Rmd details
+#'
 #' @param x A workflow
 #'
 #' @param formula A formula specifying the terms of the model. It is advised to
@@ -26,9 +28,16 @@
 #' @param ... Not used.
 #'
 #' @param blueprint A hardhat blueprint used for fine tuning the preprocessing.
+#'
 #' If `NULL`, [hardhat::default_formula_blueprint()] is used and is passed
-#' an `indicators` argument that best aligns with the model present in
-#' the workflow.
+#' arguments that best align with the model present in the workflow.
+#'
+#' Note that preprocessing done here is separate from preprocessing that
+#' might be done by the underlying model. For example, if a blueprint with
+#' `indicators = "none"` is specified, no dummy variables will be created by
+#' hardhat, but if the underlying model requires a formula interface that
+#' internally uses [stats::model.matrix()], factors will still be expanded to
+#' dummy variables by the model.
 #'
 #' @return
 #' `x`, updated with either a new or removed formula preprocessor.
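The note on model-side preprocessing can be seen with a plain `stats::lm()` fit, which builds its design matrix through `stats::model.matrix()`. This is a small base R sketch with made-up data; it does not involve workflows or hardhat, and simply shows that a factor handed over unexpanded is still turned into dummy variables by the model itself.

```r
# Toy data (hypothetical): `group` is left as a factor, as it would be
# after preprocessing that creates no indicator columns.
d <- data.frame(
  y = c(1.0, 2.1, 2.9, 4.2, 5.1, 5.8),
  group = factor(c("a", "b", "c", "a", "b", "c"))
)

# lm() uses a formula interface and expands `group` internally.
fit <- lm(y ~ group, data = d)

# The coefficient names show the dummy variables created by the model.
names(coef(fit))
```

The data frame still holds a single factor column, while the fitted model carries one coefficient per non-reference level.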
4 changes: 4 additions & 0 deletions R/pre-action-recipe.R
@@ -24,8 +24,12 @@
 #' @param ... Not used.
 #'
 #' @param blueprint A hardhat blueprint used for fine tuning the preprocessing.
+#'
 #' If `NULL`, [hardhat::default_recipe_blueprint()] is used.
 #'
+#' Note that preprocessing done here is separate from preprocessing that
+#' might be done automatically by the underlying model.
+#'
 #' @return
 #' `x`, updated with either a new or removed recipe preprocessor.
 #'
176 changes: 174 additions & 2 deletions man/add_formula.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 4 additions & 6 deletions man/add_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/add_recipe.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
