-
-
Notifications
You must be signed in to change notification settings - Fork 15
/
LearnerClassifNaiveBayes.R
70 lines (63 loc) · 1.92 KB
/
LearnerClassifNaiveBayes.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#' @title Naive Bayes Classification Learner
#'
#' @name mlr_learners_classif.naive_bayes
#'
#' @description
#' Naive Bayes classification.
#' Calls [e1071::naiveBayes()] from package \CRANpkg{e1071}.
#'
#' The training hyperparameter `laplace` is passed to [e1071::naiveBayes()].
#' The prediction hyperparameters `eps` and `threshold` are passed to the
#' `predict()` method of the fitted model. Class labels are obtained with
#' `type = "class"`, posterior probabilities with `type = "raw"`.
#'
#' @templateVar id classif.naive_bayes
#' @template learner
#'
#' @export
#' @template seealso_learner
#' @template example
LearnerClassifNaiveBayes = R6Class("LearnerClassifNaiveBayes",
  inherit = LearnerClassif,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      ps = ps(
        # Range within which (near-)zero probabilities are replaced by
        # `threshold`. Deliberately left unbounded.
        eps = p_dbl(default = 0, tags = "predict"),
        # Additive (Laplace) smoothing constant; non-negative.
        laplace = p_dbl(0, default = 0, tags = "train"),
        # Replacement value for probabilities within the `eps` range.
        # It stands in for a probability, so negative values are rejected
        # here instead of silently reaching e1071 at predict time.
        threshold = p_dbl(0, default = 0.001, tags = "predict")
      )
      super$initialize(
        id = "classif.naive_bayes",
        param_set = ps,
        predict_types = c("response", "prob"),
        properties = c("twoclass", "multiclass"),
        feature_types = c("logical", "integer", "numeric", "factor"),
        packages = c("mlr3learners", "e1071"),
        label = "Naive Bayes",
        man = "mlr3learners::mlr_learners_classif.naive_bayes"
      )
    }
  ),
  private = list(
    # Fits e1071::naiveBayes on the task's features and target; only
    # parameters tagged "train" (i.e. `laplace`) are forwarded.
    .train = function(task) {
      y = task$truth()
      x = task$data(cols = task$feature_names)
      invoke(e1071::naiveBayes,
        x = x, y = y,
        .args = self$param_set$get_values(tags = "train"))
    },
    # Predicts with the fitted model, forwarding the "predict"-tagged
    # parameters (`eps`, `threshold`). Returns list(response = ...) for
    # predict_type "response", otherwise list(prob = ...).
    .predict = function(task) {
      pv = self$param_set$get_values(tags = "predict")
      # Features in the column order used during training.
      newdata = ordered_features(task, self)
      if (self$predict_type == "response") {
        response = invoke(predict, self$model,
          newdata = newdata,
          type = "class", .args = pv)
        list(response = response)
      } else {
        prob = invoke(predict, self$model, newdata = newdata,
          type = "raw", .args = pv)
        list(prob = prob)
      }
    }
  )
)
# Store the learner's R6 generator under its id in the package-level
# `learners` collection (created in aaa.R, hence the collate dependency).
# NOTE(review): presumably consumed at load time to populate the mlr3
# learner dictionary; confirm against aaa.R / zzz.R.
#' @include aaa.R
learners$classif.naive_bayes = LearnerClassifNaiveBayes