Skip to content

Commit f6e8e9a

Browse files
committed
add sparse vectors test
1 parent 4cfbadf commit f6e8e9a

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.scalatest.FunSuite
2121
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2222
import org.apache.spark.mllib.util.LocalSparkContext
2323
import org.apache.spark.mllib.util.MLUtils._
24+
import scala.collection.mutable.ArrayBuffer
2425

2526
class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
2627
import VectorRDDFunctionsSuite._
@@ -31,19 +32,47 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
3132
Vectors.dense(7.0, 8.0, 9.0)
3233
)
3334

35+
val sparseData = ArrayBuffer(Vectors.sparse(20, Seq((0, 1.0), (9, 2.0), (10, 7.0))))
36+
for (i <- 0 to 10000) sparseData += Vectors.sparse(20, Seq((9, 0.0)))
37+
sparseData += Vectors.sparse(20, Seq((0, 5.0), (9, 13.0), (16, 2.0)))
38+
sparseData += Vectors.sparse(20, Seq((3, 5.0), (9, 13.0), (18, 2.0)))
39+
3440
test("full-statistics") {
3541
val data = sc.parallelize(localData, 2)
36-
val VectorRDDStatisticalSummary(mean, variance, cnt, nnz, max, min) = data.summarizeStatistics(3)
42+
val (VectorRDDStatisticalSummary(mean, variance, cnt, nnz, max, min), denseTime) = time(data.summarizeStatistics(3))
3743
assert(equivVector(mean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.")
3844
assert(equivVector(variance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.")
3945
assert(cnt === 3, "Column cnt do not match.")
4046
assert(equivVector(nnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.")
4147
assert(equivVector(max, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.")
4248
assert(equivVector(min, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.")
49+
50+
val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
51+
val (VectorRDDStatisticalSummary(sparseMean, sparseVariance, sparseCnt, sparseNnz, sparseMax, sparseMin), sparseTime) = time(dataForSparse.summarizeStatistics(20))
52+
/*
53+
assert(equivVector(sparseMean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.")
54+
assert(equivVector(sparseVariance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.")
55+
assert(sparseCnt === 3, "Column cnt do not match.")
56+
assert(equivVector(sparseNnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.")
57+
assert(equivVector(sparseMax, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.")
58+
assert(equivVector(sparseMin, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.")
59+
*/
60+
61+
62+
63+
println(s"dense time is $denseTime, sparse time is $sparseTime.")
4364
}
65+
4466
}
4567

4668
object VectorRDDFunctionsSuite {
69+
def time[R](block: => R): (R, Double) = {
70+
val t0 = System.nanoTime()
71+
val result = block
72+
val t1 = System.nanoTime()
73+
(result, (t1 - t0).toDouble / 1.0e9)
74+
}
75+
4776
def equivVector(lhs: Vector, rhs: Vector): Boolean = {
4877
(lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-9
4978
}

0 commit comments

Comments
 (0)