@@ -21,6 +21,7 @@ import org.scalatest.FunSuite
21
21
import org .apache .spark .mllib .linalg .{Vector , Vectors }
22
22
import org .apache .spark .mllib .util .LocalSparkContext
23
23
import org .apache .spark .mllib .util .MLUtils ._
24
+ import scala .collection .mutable .ArrayBuffer
24
25
25
26
class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
26
27
import VectorRDDFunctionsSuite ._
@@ -31,19 +32,47 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
31
32
Vectors .dense(7.0 , 8.0 , 9.0 )
32
33
)
33
34
35
// Sparse fixture of 20-dimensional vectors: one leading row, 10001 filler rows that
// each store a single explicit 0.0 at index 9, and two trailing rows.
val sparseData = ArrayBuffer(Vectors.sparse(20, Seq((0, 1.0), (9, 2.0), (10, 7.0))))
// `0 to 10000` is inclusive on both ends, hence 10001 filler rows.
sparseData ++= (0 to 10000).map(_ => Vectors.sparse(20, Seq((9, 0.0))))
sparseData ++= Seq(
  Vectors.sparse(20, Seq((0, 5.0), (9, 13.0), (16, 2.0))),
  Vectors.sparse(20, Seq((3, 5.0), (9, 13.0), (18, 2.0))))
39
+
34
40
// Runs summarizeStatistics over the dense fixture (localData, 2 partitions, 3 columns)
// and checks mean, variance, count, nnz, max and min against fixed expected values.
// It then runs the same call over the sparse fixture (20 columns) and prints the
// wall-clock time of both calls as measured by time(...).
test(" full-statistics" ) {
35
41
val data = sc.parallelize(localData, 2 )
36
- val VectorRDDStatisticalSummary (mean, variance, cnt, nnz, max, min) = data.summarizeStatistics(3 )
42
+ val ( VectorRDDStatisticalSummary (mean, variance, cnt, nnz, max, min), denseTime) = time( data.summarizeStatistics(3 ) )
37
43
assert(equivVector(mean, Vectors .dense(4.0 , 5.0 , 6.0 )), " Column mean do not match." )
38
44
assert(equivVector(variance, Vectors .dense(6.0 , 6.0 , 6.0 )), " Column variance do not match." )
39
45
assert(cnt === 3 , " Column cnt do not match." )
40
46
assert(equivVector(nnz, Vectors .dense(3.0 , 3.0 , 3.0 )), " Column nnz do not match." )
41
47
assert(equivVector(max, Vectors .dense(7.0 , 8.0 , 9.0 )), " Column max do not match." )
42
48
assert(equivVector(min, Vectors .dense(1.0 , 2.0 , 3.0 )), " Column min do not match." )
49
+
50
+ val dataForSparse = sc.parallelize(sparseData.toSeq, 2 )
51
+ val (VectorRDDStatisticalSummary (sparseMean, sparseVariance, sparseCnt, sparseNnz, sparseMax, sparseMin), sparseTime) = time(dataForSparse.summarizeStatistics(20 ))
52
// NOTE(review): the sparse summary is destructured above but none of its fields are
// asserted. The disabled assertions below reuse the dense expectations (3-element
// vectors, count 3) although this summary is requested for 20 columns; they need
// sparse-specific expected values before they can be enabled.
+ /*
53
+ assert(equivVector(sparseMean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.")
54
+ assert(equivVector(sparseVariance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.")
55
+ assert(sparseCnt === 3, "Column cnt do not match.")
56
+ assert(equivVector(sparseNnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.")
57
+ assert(equivVector(sparseMax, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.")
58
+ assert(equivVector(sparseMin, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.")
59
+ */
60
+
61
+
62
+
63
// Timing output is informational only; nothing in this test asserts on it.
+ println(s " dense time is $denseTime, sparse time is $sparseTime. " )
43
64
}
65
+
44
66
}
45
67
46
68
object VectorRDDFunctionsSuite {
69
/**
 * Evaluates `block` once and measures how long the evaluation takes.
 *
 * @param block by-name expression to run; it is evaluated exactly once
 * @tparam R type of the value produced by `block`
 * @return the value produced by `block` paired with the elapsed time in seconds,
 *         derived from two `System.nanoTime()` readings
 */
def time[R](block: => R): (R, Double) = {
  val startNanos = System.nanoTime()
  val outcome = block
  val elapsedNanos = System.nanoTime() - startNanos
  // Nanoseconds to seconds.
  (outcome, elapsedNanos / 1.0e9)
}
75
+
47
76
/**
 * Approximate equality check for two vectors.
 *
 * Both vectors are converted with `toBreeze`, subtracted, and the 2-norm of the
 * difference is compared against a fixed tolerance of 1e-9.
 *
 * @param lhs first vector
 * @param rhs second vector
 * @return true when the 2-norm of `lhs - rhs` is strictly below 1e-9
 */
def equivVector(lhs: Vector, rhs: Vector): Boolean = {
  val difference = lhs.toBreeze - rhs.toBreeze
  difference.norm(2) < 1e-9
}
0 commit comments