17
17
18
18
package org .apache .spark .mllib .rdd
19
19
20
- import org .apache .spark .mllib .linalg .Vector
21
20
import org .scalatest .FunSuite
22
-
23
- import org .apache .spark .mllib .linalg .Vectors
24
- import org .apache .spark .mllib .util .MLUtils ._
25
- import VectorRDDFunctionsSuite ._
21
+ import org .apache .spark .mllib .linalg .{Vector , Vectors }
26
22
import org .apache .spark .mllib .util .LocalSparkContext
23
+ import org .apache .spark .mllib .util .MLUtils ._
27
24
28
25
class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
26
+ import VectorRDDFunctionsSuite ._
29
27
30
28
// 3x3 dense test matrix; row i has values (3i+1, 3i+2, 3i+3), so the
// expected row means/norms and column max/min below are easy to verify by hand.
val localData = Array(
  Vectors.dense(1.0, 2.0, 3.0),
  Vectors.dense(4.0, 5.0, 6.0),
  Vectors.dense(7.0, 8.0, 9.0)
)
35
33
36
34
// Expected per-row means of localData.
val rowMeans = Array(2.0, 5.0, 8.0)

// Expected per-row L2 norms of localData:
// sqrt(1+4+9), sqrt(16+25+36), sqrt(49+64+81).
val rowNorm2 = Array(math.sqrt(14.0), math.sqrt(77.0), math.sqrt(194.0))
@@ -44,6 +42,23 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
44
42
// Expected column-wise maximum of localData (equals its last row).
val maxVec = Array(7.0, 8.0, 9.0)

// Expected column-wise minimum of localData (equals its first row).
val minVec = Array(1.0, 2.0, 3.0)
46
44
45
// Input for the shrink tests: the middle row and the third column are
// all zeros, so rowShrink / colShrink should each remove one of them.
val shrinkingData = Array(
  Vectors.dense(1.0, 2.0, 0.0),
  Vectors.dense(0.0, 0.0, 0.0),
  Vectors.dense(7.0, 8.0, 0.0)
)
50
+
51
// Expected result of rowShrink on shrinkingData: the all-zero row is dropped.
val rowShrinkData = Array(
  Vectors.dense(1.0, 2.0, 0.0),
  Vectors.dense(7.0, 8.0, 0.0)
)
55
+
56
// Expected result of colShrink on shrinkingData: the all-zero third column is dropped.
val colShrinkData = Array(
  Vectors.dense(1.0, 2.0),
  Vectors.dense(0.0, 0.0),
  Vectors.dense(7.0, 8.0)
)
61
+
47
62
test(" rowMeans" ) {
48
63
val data = sc.parallelize(localData, 2 )
49
64
assert(equivVector(Vectors .dense(data.rowMeans().collect()), Vectors .dense(rowMeans)), " Row means do not match." )
@@ -91,6 +106,22 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
91
106
" Optional minimum does not match."
92
107
)
93
108
}
109
+
110
test("rowShrink") {
  // rowShrink should drop the all-zero middle row of shrinkingData.
  val data = sc.parallelize(shrinkingData, 2)
  val res = data.rowShrink().collect()
  // Check the row count first: zip silently truncates to the shorter side,
  // so without this a result with extra or missing rows could still pass
  // the element-wise comparison below.
  assert(res.length == rowShrinkData.length, "Row shrink error.")
  rowShrinkData.zip(res).foreach { case (lhs, rhs) =>
    assert(equivVector(lhs, rhs), "Row shrink error.")
  }
}
117
+
118
test("columnShrink") {
  // colShrink should drop the all-zero third column of shrinkingData.
  val data = sc.parallelize(shrinkingData, 2)
  val res = data.colShrink().collect()
  // Check the row count first: zip silently truncates to the shorter side,
  // so without this a result with extra or missing rows could still pass
  // the element-wise comparison below.
  assert(res.length == colShrinkData.length, "Column shrink error.")
  colShrinkData.zip(res).foreach { case (lhs, rhs) =>
    assert(equivVector(lhs, rhs), "Column shrink error.")
  }
}
94
125
}
95
126
96
127
object VectorRDDFunctionsSuite {
0 commit comments