Skip to content

Commit a675881

Browse files
committed
minimize import of ml
minimize import of ml
1 parent a5bbf16 commit a675881

File tree

5 files changed

+18
-18
lines changed

5 files changed

+18
-18
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ package org.apache.spark.mllib.evaluation
1919

2020
import org.apache.spark.annotation.Since
2121
import org.apache.spark.internal.Logging
22-
import org.apache.spark.ml.stat.SummaryBuilderImpl._
2322
import org.apache.spark.mllib.linalg.Vectors
23+
import org.apache.spark.mllib.stat.Statistics
2424
import org.apache.spark.rdd.RDD
2525
import org.apache.spark.sql.{DataFrame, Row}
2626

@@ -60,16 +60,14 @@ class RegressionMetrics @Since("2.0.0") (
6060
* Use SummarizerBuffer to calculate summary statistics of observations and errors.
6161
*/
6262
private lazy val summary = {
63-
predictionAndObservations.map {
63+
val weightedVectors = predictionAndObservations.map {
6464
case (prediction: Double, observation: Double, weight: Double) =>
6565
(Vectors.dense(observation, observation - prediction, prediction), weight)
6666
case (prediction: Double, observation: Double) =>
6767
(Vectors.dense(observation, observation - prediction, prediction), 1.0)
68-
}.treeAggregate(createSummarizerBuffer("mean", "normL1", "normL2", "variance"))(
69-
seqOp = { case (c, (v, w)) => c.add(v.nonZeroIterator, v.size, w) },
70-
combOp = { case (c1, c2) => c1.merge(c2) },
71-
depth = 2
72-
)
68+
}
69+
Statistics.colStats(weightedVectors,
70+
Seq("mean", "normL1", "normL2", "variance"))
7371
}
7472

7573
private lazy val SSy = math.pow(summary.normL2(0), 2)

mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,11 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
4646
s"source vector size $numFeatures must be no less than k=$k")
4747

4848
val mat = if (numFeatures > 65535) {
49-
val summary = Statistics.colStats(sources, Seq("mean"))
50-
val meanVector = summary.mean.asBreeze
51-
val meanCentredRdd = sources.map { rowVector =>
52-
Vectors.fromBreeze(rowVector.asBreeze - meanVector)
49+
val summary = Statistics.colStats(sources.map((_, 1.0)), Seq("mean"))
50+
val mean = Vectors.fromML(summary.mean)
51+
val meanCentredRdd = sources.map { row =>
52+
BLAS.axpy(-1, mean, row)
53+
row
5354
}
5455
new RowMatrix(meanCentredRdd)
5556
} else {

mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class StandardScaler @Since("1.1.0") (withMean: Boolean, withStd: Boolean) exten
5555
@Since("1.1.0")
5656
def fit(data: RDD[Vector]): StandardScalerModel = {
5757
// TODO: skip computation if both withMean and withStd are false
58-
val summary = Statistics.colStats(data, Seq("mean", "std"))
58+
val summary = Statistics.colStats(data.map((_, 1.0)), Seq("mean", "std"))
5959

6060
new StandardScalerModel(
6161
Vectors.fromML(summary.std),

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ class RowMatrix @Since("1.0.0") (
433433
val n = numCols().toInt
434434
checkNumColumns(n)
435435

436-
val summary = Statistics.colStats(rows, Seq("count", "mean"))
436+
val summary = Statistics.colStats(rows.map((_, 1.0)), Seq("count", "mean"))
437437
val m = summary.count
438438
require(m > 1, s"RowMatrix.computeCovariance called on matrix with only $m rows." +
439439
" Cannot compute the covariance of a RowMatrix with <= 1 row.")
@@ -616,7 +616,7 @@ class RowMatrix @Since("1.0.0") (
616616
10 * math.log(numCols()) / threshold
617617
}
618618

619-
val summary = Statistics.colStats(rows, Seq("normL2"))
619+
val summary = Statistics.colStats(rows.map((_, 1.0)), Seq("normL2"))
620620
columnSimilaritiesDIMSUM(summary.normL2.toArray, gamma)
621621
}
622622

mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,15 @@ object Statistics {
4848
}
4949

5050
/**
51-
* Computes required column-wise summary statistics for the input RDD[Vector].
51+
* Computes required column-wise summary statistics for the input RDD[(Vector, Double)].
5252
*
53-
* @param X an RDD[Vector] for which column-wise summary statistics are to be computed.
53+
* @param X an RDD containing vectors and weights for which column-wise summary statistics
54+
* are to be computed.
5455
* @return [[SummarizerBuffer]] object containing column-wise summary statistics.
5556
*/
56-
private[mllib] def colStats(X: RDD[Vector], requested: Seq[String]) = {
57+
private[mllib] def colStats(X: RDD[(Vector, Double)], requested: Seq[String]) = {
5758
X.treeAggregate(createSummarizerBuffer(requested: _*))(
58-
seqOp = { case (c, v) => c.add(v.nonZeroIterator, v.size, 1.0) },
59+
seqOp = { case (c, (v, w)) => c.add(v.nonZeroIterator, v.size, w) },
5960
combOp = { case (c1, c2) => c1.merge(c2) },
6061
depth = 2
6162
)

0 commit comments

Comments
 (0)