forked from business-science/modeltime
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils-xgboost.R
53 lines (46 loc) · 1.89 KB
/
utils-xgboost.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# XGBOOST UTILITIES ----
#' Fit an xgboost model through `parsnip::xgb_train()`
#'
#' A thin pass-through: every argument is handed to [parsnip::xgb_train()]
#' unchanged and the fitted result is returned as-is.
#'
#' @inheritParams parsnip::xgb_train
#' @param validation A number. If it lies on `[0, 1)`, `validation` is the
#'   random proportion of the data in `x` and `y` that is used for performance
#'   assessment and potential early stopping. If it is 1 or greater, it is the
#'   _number_ of training set samples used for these purposes.
#' @param early_stop An integer or `NULL`. If not `NULL`, it is the number of
#'   training iterations without improvement before stopping. If `validation`
#'   is used, performance is based on the validation set; otherwise the
#'   training set is used.
#' @param ... Additional arguments forwarded to [parsnip::xgb_train()].
#'
#' @return The value returned by [parsnip::xgb_train()].
#'
#' @export
xgboost_impl <- function(x, y,
                         max_depth = 6, nrounds = 15, eta = 0.3,
                         colsample_bytree = 1, min_child_weight = 1,
                         gamma = 0, subsample = 1, validation = 0,
                         early_stop = NULL, ...) {
  # `x` and `y` stay positional; every tuning argument is forwarded by name,
  # followed by whatever extras the caller supplied through `...`.
  parsnip::xgb_train(
    x,
    y,
    max_depth = max_depth,
    nrounds = nrounds,
    eta = eta,
    colsample_bytree = colsample_bytree,
    min_child_weight = min_child_weight,
    gamma = gamma,
    subsample = subsample,
    validation = validation,
    early_stop = early_stop,
    ...
  )
}
#' Predict from a fitted xgboost model
#'
#' Converts `newdata` to an `xgb.DMatrix` when necessary, calls
#' [stats::predict()] on `object`, and post-processes the raw predictions
#' according to the objective stored in `object$params$objective`.
#'
#' @inheritParams stats::predict
#' @param newdata New data to be predicted. Anything that is not already an
#'   `xgb.DMatrix` is passed through `as.matrix()` and then
#'   `xgboost::xgb.DMatrix()` with `missing = NA`.
#'
#' @return For the `"binary:logitraw"` objective, the predictions mapped
#'   through the inverse logit link. For `"multi:softprob"`, a matrix with
#'   `object$params$num_class` columns, filled by row. For any other
#'   objective, or when no objective is stored on `object`, the predictions
#'   exactly as returned by `predict()`.
#'
#' @export
xgboost_predict <- function(object, newdata, ...) {
  if (!inherits(newdata, "xgb.DMatrix")) {
    newdata <- as.matrix(newdata)
    newdata <- xgboost::xgb.DMatrix(data = newdata, missing = NA)
  }

  res <- stats::predict(object, newdata, ...)

  # `switch()` raises an error on a NULL selector, so a model with no
  # recorded objective gets its predictions back untransformed.
  objective <- object$params$objective
  if (is.null(objective)) {
    return(res)
  }

  switch(
    objective,
    "binary:logitraw" = stats::binomial()$linkinv(res),
    "multi:softprob" = matrix(
      res,
      ncol = object$params$num_class,
      byrow = TRUE
    ),
    # "reg:linear", "reg:logistic", "binary:logistic" and every other
    # objective need no post-processing.
    res
  )
}