save x column names from fit_xy() #1168

Merged (2 commits) on Sep 3, 2024

EmilHvitfeldt
Member

To close #1166.

This bug happened because fit_xy() does not know the name of the outcome(s) y, which is often an unnamed vector. Instead, this PR saves the column names of x and subsets with those when appropriate.

This is NOT an xgboost issue; xgboost just complains more loudly than other engines.
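
The idea described above (remember the x column names at fit time, then subset new data with them at predict time) can be sketched in base R. This is only an illustration under stated assumptions, not parsnip's actual implementation: `toy_fit_xy()` and `toy_predict()` are hypothetical helpers, and a plain least-squares fit stands in for the real engine.

```r
# Hypothetical helpers illustrating the idea; these are NOT parsnip internals.
toy_fit_xy <- function(x, y) {
  x_mat <- as.matrix(x)
  # Least-squares fit with an explicit intercept column (stand-in engine).
  coefs <- stats::lm.fit(cbind(`(Intercept)` = 1, x_mat), y)$coefficients
  # y is often an unnamed vector, so store the predictor names instead.
  list(coefs = coefs, x_names = colnames(x_mat))
}

toy_predict <- function(object, new_data) {
  # Keep only the columns seen at fit time, in the same order, so an
  # extra outcome column such as `mpg` is dropped before predicting.
  x_new <- as.matrix(new_data[, object$x_names, drop = FALSE])
  drop(cbind(1, x_new) %*% object$coefs)
}

toy_fit <- toy_fit_xy(x = mtcars[, -1], y = mtcars[, 1])

# new_data still contains the outcome column `mpg`; it is ignored.
toy_preds <- toy_predict(toy_fit, mtcars)
```

Subsetting by the stored names also guards against new data whose columns arrive in a different order than at fit time.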

library(parsnip)

spec <- boost_tree() %>%
  set_mode("regression") %>%
  set_engine("xgboost")

lm_fit <- fit(spec, mpg ~ ., data = mtcars)

predict(lm_fit, mtcars)
#> # A tibble: 32 × 1
#>    .pred
#>    <dbl>
#>  1  20.9
#>  2  20.9
#>  3  22.6
#>  4  21.0
#>  5  18.4
#>  6  18.1
#>  7  14.2
#>  8  23.7
#>  9  22.4
#> 10  18.9
#> # ℹ 22 more rows

lm_fit <- fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1])

# new_data still contains the outcome column mpg
predict(lm_fit, mtcars)
#> # A tibble: 32 × 1
#>    .pred
#>    <dbl>
#>  1  20.9
#>  2  20.9
#>  3  22.6
#>  4  21.0
#>  5  18.4
#>  6  18.1
#>  7  14.2
#>  8  23.7
#>  9  22.4
#> 10  18.9
#> # ℹ 22 more rows

Created on 2024-08-30 with reprex v2.1.0

Contributor

@simonpcouch simonpcouch left a comment

Super clean!

My only question was whether there were any unexpected hiccups with non-standard roles, but all is well:

library(tidymodels)

data(biomass, package = "modeldata")
biomass_train <- biomass[1:100,]
biomass_test <- biomass[101:200,]

rec <- recipe(HHV ~ ., data = biomass_train) %>%
  update_role(sample, new_role = "id variable") %>%
  step_center(carbon) %>%
  step_dummy(all_nominal_predictors())

spec <- boost_tree() %>%
  set_mode("regression") %>%
  set_engine("xgboost")

wf <- fit(workflow(rec, spec), biomass_train)

predict(wf, biomass_train)
#> # A tibble: 100 × 1
#>    .pred
#>    <dbl>
#>  1  19.5
#>  2  19.2
#>  3  18.2
#>  4  17.9
#>  5  18.4
#>  6  18.4
#>  7  18.5
#>  8  18.2
#>  9  18.7
#> 10  18.8
#> # ℹ 90 more rows

Created on 2024-09-03 with reprex v2.1.1

@EmilHvitfeldt EmilHvitfeldt merged commit eeaf82f into main Sep 3, 2024
10 checks passed
@EmilHvitfeldt EmilHvitfeldt deleted the fix1166 branch September 3, 2024 23:56

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Sep 18, 2024
Development

Successfully merging this pull request may close these issues.

can't predict() with xgboost if fit with fit_xy() if data contains outcome