diff --git a/NAMESPACE b/NAMESPACE index 5f9e1561..93d40257 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -31,4 +31,5 @@ importFrom(R6,R6Class) importFrom(mlr3,LearnerClassif) importFrom(mlr3,LearnerRegr) importFrom(mlr3,mlr_learners) +importFrom(stats,predict) importFrom(utils,bibentry) diff --git a/R/LearnerClassifCVGlmnet.R b/R/LearnerClassifCVGlmnet.R index ba955a99..2a2c4f9e 100644 --- a/R/LearnerClassifCVGlmnet.R +++ b/R/LearnerClassifCVGlmnet.R @@ -126,13 +126,13 @@ LearnerClassifCVGlmnet = R6Class("LearnerClassifCVGlmnet", } if (self$predict_type == "response") { - response = mlr3misc::invoke(stats::predict, self$model, + response = mlr3misc::invoke(predict, self$model, newx = newdata, type = "class", .args = pars) list(response = drop(response)) } else { - prob = mlr3misc::invoke(stats::predict, self$model, + prob = mlr3misc::invoke(predict, self$model, newx = newdata, type = "response", .args = pars) diff --git a/R/LearnerClassifGlmnet.R b/R/LearnerClassifGlmnet.R index f656c1c7..9882f915 100644 --- a/R/LearnerClassifGlmnet.R +++ b/R/LearnerClassifGlmnet.R @@ -125,12 +125,12 @@ LearnerClassifGlmnet = R6Class("LearnerClassifGlmnet", } if (self$predict_type == "response") { - response = mlr3misc::invoke(stats::predict, self$model, + response = mlr3misc::invoke(predict, self$model, newx = newdata, type = "class", .args = pars) list(response = drop(response)) } else { - prob = mlr3misc::invoke(stats::predict, self$model, + prob = mlr3misc::invoke(predict, self$model, newx = newdata, type = "response", .args = pars) diff --git a/R/LearnerClassifLDA.R b/R/LearnerClassifLDA.R index f02d7ff7..e81165a9 100644 --- a/R/LearnerClassifLDA.R +++ b/R/LearnerClassifLDA.R @@ -74,7 +74,7 @@ LearnerClassifLDA = R6Class("LearnerClassifLDA", pars$predict.prior = NULL } newdata = task$data(cols = task$feature_names) - p = mlr3misc::invoke(stats::predict, self$model, + p = mlr3misc::invoke(predict, self$model, newdata = newdata, .args = self$param_set$get_values(tags = "predict")) diff --git a/R/LearnerClassifLogReg.R b/R/LearnerClassifLogReg.R index 0965907f..0a8cb4ed 100644 --- a/R/LearnerClassifLogReg.R +++ b/R/LearnerClassifLogReg.R @@ -71,7 +71,7 @@ LearnerClassifLogReg = R6Class("LearnerClassifLogReg", .predict = function(task) { newdata = task$data(cols = task$feature_names) - p = unname(stats::predict(self$model, newdata = newdata, type = "response")) + p = unname(predict(self$model, newdata = newdata, type = "response")) levs = levels(self$model$data[[task$target_names]]) if (self$predict_type == "response") { diff --git a/R/LearnerClassifMultinom.R b/R/LearnerClassifMultinom.R index 30c1431c..04d8977d 100644 --- a/R/LearnerClassifMultinom.R +++ b/R/LearnerClassifMultinom.R @@ -65,10 +65,10 @@ LearnerClassifMultinom = R6Class("LearnerClassifMultinom", levs = task$class_names if (self$predict_type == "response") { - response = mlr3misc::invoke(stats::predict, self$model, newdata = newdata, type = "class") + response = mlr3misc::invoke(predict, self$model, newdata = newdata, type = "class") list(response = drop(response)) } else { - prob = mlr3misc::invoke(stats::predict, self$model, newdata = newdata, type = "probs") + prob = mlr3misc::invoke(predict, self$model, newdata = newdata, type = "probs") if (length(levs) == 2L) { prob = matrix(c(1 - prob, prob), ncol = 2L, byrow = FALSE) colnames(prob) = levs diff --git a/R/LearnerClassifNaiveBayes.R b/R/LearnerClassifNaiveBayes.R index bd6d0a74..de713625 100644 --- a/R/LearnerClassifNaiveBayes.R +++ b/R/LearnerClassifNaiveBayes.R @@ -52,12 +52,12 @@ LearnerClassifNaiveBayes = R6Class("LearnerClassifNaiveBayes", newdata = task$data(cols = task$feature_names) if (self$predict_type == "response") { - response = mlr3misc::invoke(stats::predict, self$model, + response = mlr3misc::invoke(predict, self$model, newdata = newdata, type = "class", .args = pars) list(response = response) } else { - prob = mlr3misc::invoke(stats::predict, self$model, newdata = newdata, + prob = mlr3misc::invoke(predict, self$model, newdata = newdata, type = "raw", .args = pars) list(prob = prob) } diff --git a/R/LearnerClassifQDA.R b/R/LearnerClassifQDA.R index 225a97d2..d0d022dc 100644 --- a/R/LearnerClassifQDA.R +++ b/R/LearnerClassifQDA.R @@ -73,7 +73,7 @@ LearnerClassifQDA = R6Class("LearnerClassifQDA", } newdata = task$data(cols = task$feature_names) - p = mlr3misc::invoke(stats::predict, self$model, newdata = newdata, .args = pars) + p = mlr3misc::invoke(predict, self$model, newdata = newdata, .args = pars) if (self$predict_type == "response") { list(response = p$class) diff --git a/R/LearnerClassifRanger.R b/R/LearnerClassifRanger.R index 66a2cd70..a5e6e4d8 100644 --- a/R/LearnerClassifRanger.R +++ b/R/LearnerClassifRanger.R @@ -122,7 +122,7 @@ LearnerClassifRanger = R6Class("LearnerClassifRanger", .predict = function(task) { pars = self$param_set$get_values(tags = "predict") newdata = task$data(cols = task$feature_names) - p = mlr3misc::invoke(stats::predict, self$model, + p = mlr3misc::invoke(predict, self$model, data = newdata, predict.type = "response", .args = pars) diff --git a/R/LearnerClassifSVM.R b/R/LearnerClassifSVM.R index 70d7df93..d5445fac 100644 --- a/R/LearnerClassifSVM.R +++ b/R/LearnerClassifSVM.R @@ -80,7 +80,7 @@ LearnerClassifSVM = R6Class("LearnerClassifSVM", pars = self$param_set$get_values(tags = "predict") newdata = as.matrix(task$data(cols = task$feature_names)) newdata = newdata[, self$state$feature_names, drop = FALSE] - p = mlr3misc::invoke(stats::predict, self$model, + p = mlr3misc::invoke(predict, self$model, newdata = newdata, probability = (self$predict_type == "prob"), .args = pars) diff --git a/R/LearnerClassifXgboost.R b/R/LearnerClassifXgboost.R index c24c821a..e2cad777 100644 --- a/R/LearnerClassifXgboost.R +++ b/R/LearnerClassifXgboost.R @@ -223,7 +223,7 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost", newdata = data.matrix(task$data(cols = task$feature_names)) newdata = newdata[, model$feature_names, drop = FALSE] - pred = mlr3misc::invoke(stats::predict, model, newdata = newdata, .args = pars) + pred = mlr3misc::invoke(predict, model, newdata = newdata, .args = pars) if (nlvls == 2L) { # binaryclass if (pars$objective == "multi:softprob") { diff --git a/R/LearnerRegrCVGlmnet.R b/R/LearnerRegrCVGlmnet.R index ea0bc7bd..ab151799 100644 --- a/R/LearnerRegrCVGlmnet.R +++ b/R/LearnerRegrCVGlmnet.R @@ -125,7 +125,7 @@ LearnerRegrCVGlmnet = R6Class("LearnerRegrCVGlmnet", pars$predict.gamma = NULL } - response = invoke(stats::predict, self$model, newx = newdata, + response = invoke(predict, self$model, newx = newdata, type = "response", .args = pars) list(response = drop(response)) } diff --git a/R/LearnerRegrGlmnet.R b/R/LearnerRegrGlmnet.R index 134be6b6..872487bb 100644 --- a/R/LearnerRegrGlmnet.R +++ b/R/LearnerRegrGlmnet.R @@ -129,7 +129,7 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet", pars$s = self$param_set$default$s } - response = mlr3misc::invoke(stats::predict, self$model, + response = mlr3misc::invoke(predict, self$model, newx = newdata, type = "response", .args = pars) list(response = drop(response)) diff --git a/R/LearnerRegrLM.R b/R/LearnerRegrLM.R index 5b53fb00..387439bd 100644 --- a/R/LearnerRegrLM.R +++ b/R/LearnerRegrLM.R @@ -65,10 +65,10 @@ LearnerRegrLM = R6Class("LearnerRegrLM", newdata = task$data(cols = task$feature_names) if (self$predict_type == "response") { - response = stats::predict(self$model, newdata = newdata, se.fit = FALSE) + response = predict(self$model, newdata = newdata, se.fit = FALSE) list(response = response) } else { - pred = stats::predict(self$model, newdata = newdata, se.fit = TRUE) + pred = predict(self$model, newdata = newdata, se.fit = TRUE) list(response = pred$fit, se = pred$se.fit) } } diff --git a/R/LearnerRegrRanger.R b/R/LearnerRegrRanger.R index 369d6143..84c239ef 100644 --- a/R/LearnerRegrRanger.R +++ b/R/LearnerRegrRanger.R @@ -135,7 +135,7 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger", .predict = function(task) { pars = self$param_set$get_values(tags = "predict") newdata = task$data(cols = task$feature_names) - preds = mlr3misc::invoke(stats::predict, self$model, + preds = mlr3misc::invoke(predict, self$model, data = newdata, type = self$predict_type, .args = pars) list(response = preds$predictions, se = preds$se) diff --git a/R/LearnerRegrSVM.R b/R/LearnerRegrSVM.R index 11494e0c..309e5412 100644 --- a/R/LearnerRegrSVM.R +++ b/R/LearnerRegrSVM.R @@ -72,7 +72,7 @@ LearnerRegrSVM = R6Class("LearnerRegrSVM", pars = self$param_set$get_values(tags = "predict") newdata = as.matrix(task$data(cols = task$feature_names)) newdata = newdata[, self$state$feature_names, drop = FALSE] - response = invoke(stats::predict, self$model, newdata = newdata, type = "response", .args = pars) + response = invoke(predict, self$model, newdata = newdata, type = "response", .args = pars) list(response = response) } ) diff --git a/R/LearnerRegrXgboost.R b/R/LearnerRegrXgboost.R index 0df706bd..a13f4a83 100644 --- a/R/LearnerRegrXgboost.R +++ b/R/LearnerRegrXgboost.R @@ -192,7 +192,7 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost", model = self$model newdata = data.matrix(task$data(cols = task$feature_names)) newdata = newdata[, model$feature_names, drop = FALSE] - response = invoke(stats::predict, model, newdata = newdata, .args = pars) + response = invoke(predict, model, newdata = newdata, .args = pars) list(response = response) } diff --git a/R/LearnerSurvCVGlmnet.R b/R/LearnerSurvCVGlmnet.R index 96af478d..263805d0 100644 --- a/R/LearnerSurvCVGlmnet.R +++ b/R/LearnerSurvCVGlmnet.R @@ -124,7 +124,7 @@ LearnerSurvCVGlmnet = R6Class("LearnerSurvCVGlmnet", pars$predict.gamma = NULL } - lp = as.numeric(invoke(stats::predict, self$model, newx = newdata, type = "link", .args = pars)) + lp = as.numeric(invoke(predict, self$model, newx = newdata, type = "link", .args = pars)) list(lp = lp, crank = lp) } ) diff --git a/R/LearnerSurvGlmnet.R b/R/LearnerSurvGlmnet.R index 0091d9a3..7d422757 100644 --- a/R/LearnerSurvGlmnet.R +++ b/R/LearnerSurvGlmnet.R @@ -118,7 +118,7 @@ LearnerSurvGlmnet = R6Class("LearnerSurvGlmnet", pars$s = self$param_set$default$s } - lp = invoke(stats::predict, self$model, newx = newdata, type = "link", .args = pars) + lp = invoke(predict, self$model, newx = newdata, type = "link", .args = pars) list(crank = lp, lp = lp) } diff --git a/R/LearnerSurvRanger.R b/R/LearnerSurvRanger.R index 19201144..dd91e257 100644 --- a/R/LearnerSurvRanger.R +++ b/R/LearnerSurvRanger.R @@ -115,7 +115,7 @@ LearnerSurvRanger = R6Class("LearnerSurvRanger", .predict = function(task) { newdata = task$data(cols = task$feature_names) - fit = stats::predict(object = self$model, data = newdata) + fit = predict(object = self$model, data = newdata) mlr3proba::.surv_return(times = fit$unique.death.times, surv = fit$survival) } ) diff --git a/R/LearnerSurvXgboost.R b/R/LearnerSurvXgboost.R index ef929723..c6149224 100644 --- a/R/LearnerSurvXgboost.R +++ b/R/LearnerSurvXgboost.R @@ -187,7 +187,7 @@ LearnerSurvXgboost = R6Class("LearnerSurvXgboost", model = self$model newdata = data.matrix(task$data(cols = task$feature_names)) newdata = newdata[, model$feature_names, drop = FALSE] - lp = log(mlr3misc::invoke(stats::predict, model, newdata = newdata, .args = pars)) + lp = log(mlr3misc::invoke(predict, model, newdata = newdata, .args = pars)) list(crank = lp, lp = lp) } diff --git a/R/zzz.R b/R/zzz.R index 4ac0028f..52b31d72 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -3,6 +3,7 @@ #' @import mlr3misc #' @importFrom R6 R6Class #' @importFrom mlr3 mlr_learners LearnerClassif LearnerRegr +#' @importFrom stats predict #' #' @description #' More learners are implemented in the [mlr3extralearners package](https://github.com/mlr-org/mlr3extralearners).