Skip to content

Commit

Permalink
import predict
Browse files Browse the repository at this point in the history
  • Loading branch information
mllg committed Dec 6, 2020
1 parent 8b00da3 commit 6b00dea
Show file tree
Hide file tree
Showing 22 changed files with 27 additions and 25 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ importFrom(R6,R6Class)
importFrom(mlr3,LearnerClassif)
importFrom(mlr3,LearnerRegr)
importFrom(mlr3,mlr_learners)
importFrom(stats,predict)
importFrom(utils,bibentry)
4 changes: 2 additions & 2 deletions R/LearnerClassifCVGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,13 @@ LearnerClassifCVGlmnet = R6Class("LearnerClassifCVGlmnet",
}

if (self$predict_type == "response") {
response = mlr3misc::invoke(stats::predict, self$model,
response = mlr3misc::invoke(predict, self$model,
newx = newdata, type = "class",
.args = pars)

list(response = drop(response))
} else {
prob = mlr3misc::invoke(stats::predict, self$model,
prob = mlr3misc::invoke(predict, self$model,
newx = newdata, type = "response",
.args = pars)

Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClassifGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ LearnerClassifGlmnet = R6Class("LearnerClassifGlmnet",
}

if (self$predict_type == "response") {
response = mlr3misc::invoke(stats::predict, self$model,
response = mlr3misc::invoke(predict, self$model,
newx = newdata, type = "class",
.args = pars)
list(response = drop(response))
} else {
prob = mlr3misc::invoke(stats::predict, self$model,
prob = mlr3misc::invoke(predict, self$model,
newx = newdata, type = "response",
.args = pars)

Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassifLDA.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ LearnerClassifLDA = R6Class("LearnerClassifLDA",
pars$predict.prior = NULL
}
newdata = task$data(cols = task$feature_names)
p = mlr3misc::invoke(stats::predict, self$model,
p = mlr3misc::invoke(predict, self$model,
newdata = newdata,
.args = self$param_set$get_values(tags = "predict"))

Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassifLogReg.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ LearnerClassifLogReg = R6Class("LearnerClassifLogReg",
.predict = function(task) {
newdata = task$data(cols = task$feature_names)

p = unname(stats::predict(self$model, newdata = newdata, type = "response"))
p = unname(predict(self$model, newdata = newdata, type = "response"))
levs = levels(self$model$data[[task$target_names]])

if (self$predict_type == "response") {
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClassifMultinom.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ LearnerClassifMultinom = R6Class("LearnerClassifMultinom",
levs = task$class_names

if (self$predict_type == "response") {
response = mlr3misc::invoke(stats::predict, self$model, newdata = newdata, type = "class")
response = mlr3misc::invoke(predict, self$model, newdata = newdata, type = "class")
list(response = drop(response))
} else {
prob = mlr3misc::invoke(stats::predict, self$model, newdata = newdata, type = "probs")
prob = mlr3misc::invoke(predict, self$model, newdata = newdata, type = "probs")
if (length(levs) == 2L) {
prob = matrix(c(1 - prob, prob), ncol = 2L, byrow = FALSE)
colnames(prob) = levs
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClassifNaiveBayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ LearnerClassifNaiveBayes = R6Class("LearnerClassifNaiveBayes",
newdata = task$data(cols = task$feature_names)

if (self$predict_type == "response") {
response = mlr3misc::invoke(stats::predict, self$model,
response = mlr3misc::invoke(predict, self$model,
newdata = newdata,
type = "class", .args = pars)
list(response = response)
} else {
prob = mlr3misc::invoke(stats::predict, self$model, newdata = newdata,
prob = mlr3misc::invoke(predict, self$model, newdata = newdata,
type = "raw", .args = pars)
list(prob = prob)
}
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassifQDA.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ LearnerClassifQDA = R6Class("LearnerClassifQDA",
}

newdata = task$data(cols = task$feature_names)
p = mlr3misc::invoke(stats::predict, self$model, newdata = newdata, .args = pars)
p = mlr3misc::invoke(predict, self$model, newdata = newdata, .args = pars)

if (self$predict_type == "response") {
list(response = p$class)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassifRanger.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ LearnerClassifRanger = R6Class("LearnerClassifRanger",
.predict = function(task) {
pars = self$param_set$get_values(tags = "predict")
newdata = task$data(cols = task$feature_names)
p = mlr3misc::invoke(stats::predict, self$model,
p = mlr3misc::invoke(predict, self$model,
data = newdata,
predict.type = "response", .args = pars)

Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassifSVM.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ LearnerClassifSVM = R6Class("LearnerClassifSVM",
pars = self$param_set$get_values(tags = "predict")
newdata = as.matrix(task$data(cols = task$feature_names))
newdata = newdata[, self$state$feature_names, drop = FALSE]
p = mlr3misc::invoke(stats::predict, self$model,
p = mlr3misc::invoke(predict, self$model,
newdata = newdata,
probability = (self$predict_type == "prob"), .args = pars)

Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassifXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",

newdata = data.matrix(task$data(cols = task$feature_names))
newdata = newdata[, model$feature_names, drop = FALSE]
pred = mlr3misc::invoke(stats::predict, model, newdata = newdata, .args = pars)
pred = mlr3misc::invoke(predict, model, newdata = newdata, .args = pars)

if (nlvls == 2L) { # binaryclass
if (pars$objective == "multi:softprob") {
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerRegrCVGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ LearnerRegrCVGlmnet = R6Class("LearnerRegrCVGlmnet",
pars$predict.gamma = NULL
}

response = invoke(stats::predict, self$model, newx = newdata,
response = invoke(predict, self$model, newx = newdata,
type = "response", .args = pars)
list(response = drop(response))
}
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerRegrGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet",
pars$s = self$param_set$default$s
}

response = mlr3misc::invoke(stats::predict, self$model,
response = mlr3misc::invoke(predict, self$model,
newx = newdata,
type = "response", .args = pars)
list(response = drop(response))
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerRegrLM.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ LearnerRegrLM = R6Class("LearnerRegrLM",
newdata = task$data(cols = task$feature_names)

if (self$predict_type == "response") {
response = stats::predict(self$model, newdata = newdata, se.fit = FALSE)
response = predict(self$model, newdata = newdata, se.fit = FALSE)
list(response = response)
} else {
pred = stats::predict(self$model, newdata = newdata, se.fit = TRUE)
pred = predict(self$model, newdata = newdata, se.fit = TRUE)
list(response = pred$fit, se = pred$se.fit)
}
}
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerRegrRanger.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
.predict = function(task) {
pars = self$param_set$get_values(tags = "predict")
newdata = task$data(cols = task$feature_names)
preds = mlr3misc::invoke(stats::predict, self$model,
preds = mlr3misc::invoke(predict, self$model,
data = newdata,
type = self$predict_type, .args = pars)
list(response = preds$predictions, se = preds$se)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerRegrSVM.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ LearnerRegrSVM = R6Class("LearnerRegrSVM",
pars = self$param_set$get_values(tags = "predict")
newdata = as.matrix(task$data(cols = task$feature_names))
newdata = newdata[, self$state$feature_names, drop = FALSE]
response = invoke(stats::predict, self$model, newdata = newdata, type = "response", .args = pars)
response = invoke(predict, self$model, newdata = newdata, type = "response", .args = pars)
list(response = response)
}
)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerRegrXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
model = self$model
newdata = data.matrix(task$data(cols = task$feature_names))
newdata = newdata[, model$feature_names, drop = FALSE]
response = invoke(stats::predict, model, newdata = newdata, .args = pars)
response = invoke(predict, model, newdata = newdata, .args = pars)

list(response = response)
}
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerSurvCVGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ LearnerSurvCVGlmnet = R6Class("LearnerSurvCVGlmnet",
pars$predict.gamma = NULL
}

lp = as.numeric(invoke(stats::predict, self$model, newx = newdata, type = "link", .args = pars))
lp = as.numeric(invoke(predict, self$model, newx = newdata, type = "link", .args = pars))
list(lp = lp, crank = lp)
}
)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerSurvGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ LearnerSurvGlmnet = R6Class("LearnerSurvGlmnet",
pars$s = self$param_set$default$s
}

lp = invoke(stats::predict, self$model, newx = newdata, type = "link", .args = pars)
lp = invoke(predict, self$model, newx = newdata, type = "link", .args = pars)

list(crank = lp, lp = lp)
}
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerSurvRanger.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ LearnerSurvRanger = R6Class("LearnerSurvRanger",

.predict = function(task) {
newdata = task$data(cols = task$feature_names)
fit = stats::predict(object = self$model, data = newdata)
fit = predict(object = self$model, data = newdata)
mlr3proba::.surv_return(times = fit$unique.death.times, surv = fit$survival)
}
)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerSurvXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ LearnerSurvXgboost = R6Class("LearnerSurvXgboost",
model = self$model
newdata = data.matrix(task$data(cols = task$feature_names))
newdata = newdata[, model$feature_names, drop = FALSE]
lp = log(mlr3misc::invoke(stats::predict, model, newdata = newdata, .args = pars))
lp = log(mlr3misc::invoke(predict, model, newdata = newdata, .args = pars))

list(crank = lp, lp = lp)
}
Expand Down
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#' @import mlr3misc
#' @importFrom R6 R6Class
#' @importFrom mlr3 mlr_learners LearnerClassif LearnerRegr
#' @importFrom stats predict
#'
#' @description
#' More learners are implemented in the [mlr3extralearners package](https://github.com/mlr-org/mlr3extralearners).
Expand Down

0 comments on commit 6b00dea

Please sign in to comment.