|
18 | 18 | package org.apache.spark.mllib.classification
|
19 | 19 |
|
20 | 20 | import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
|
21 |
| -import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels |
| 21 | +import breeze.numerics.{exp => brzExp, log => brzLog} |
22 | 22 |
|
23 | 23 | import org.apache.spark.{SparkException, Logging}
|
24 | 24 | import org.apache.spark.SparkContext._
|
25 | 25 | import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
|
26 | 26 | import org.apache.spark.mllib.regression.LabeledPoint
|
| 27 | +import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels |
27 | 28 | import org.apache.spark.rdd.RDD
|
28 | 29 |
|
29 | 30 |
|
@@ -52,29 +53,14 @@ class NaiveBayesModel private[mllib] (
|
52 | 53 | val theta: Array[Array[Double]],
|
53 | 54 | val model: NaiveBayesModels) extends ClassificationModel with Serializable {
|
54 | 55 |
|
55 |
| - def populateMatrix(arrayIn: Array[Array[Double]], |
56 |
| - matrixIn: BDM[Double], |
57 |
| - transformation: (Double) => Double = (x) => x) = { |
58 |
| - var i = 0 |
59 |
| - while (i < arrayIn.length) { |
60 |
| - var j = 0 |
61 |
| - while (j < arrayIn(i).length) { |
62 |
| - matrixIn(i, j) = transformation(theta(i)(j)) |
63 |
| - j += 1 |
64 |
| - } |
65 |
| - i += 1 |
66 |
| - } |
67 |
| - } |
68 |
| - |
69 | 56 | private val brzPi = new BDV[Double](pi)
|
70 |
| - private val brzTheta = new BDM[Double](theta.length, theta(0).length) |
71 |
| - populateMatrix(theta, brzTheta) |
| 57 | + private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t |
72 | 58 |
|
73 | 59 | private val brzNegTheta: Option[BDM[Double]] = model match {
|
74 | 60 | case NaiveBayesModels.Multinomial => None
|
75 | 61 | case NaiveBayesModels.Bernoulli =>
|
76 |
| - val negTheta = new BDM[Double](theta.length, theta(0).length) |
77 |
| - populateMatrix(theta, negTheta, (x) => math.log(1.0 - math.exp(x))) |
| 62 | + val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) |
| 63 | + //((x) => math.log(1.0 - math.exp(x)) |
78 | 64 | Option(negTheta)
|
79 | 65 | }
|
80 | 66 |
|
@@ -244,7 +230,7 @@ object NaiveBayes {
|
244 | 230 | * @param model The type of NB model to fit from the enumeration NaiveBayesModels, can be
|
245 | 231 | * Multinomial or Bernoulli
|
246 | 232 | */
|
247 |
| - def train(input: RDD[LabeledPoint], lambda: Double, model: NaiveBayesModels): NaiveBayesModel = { |
248 |
| - new NaiveBayes(lambda, model).run(input) |
| 233 | + def train(input: RDD[LabeledPoint], lambda: Double, model: String): NaiveBayesModel = { |
| 234 | + new NaiveBayes(lambda, NaiveBayesModels.withName(model)).run(input) |
249 | 235 | }
|
250 | 236 | }
|
0 commit comments