Skip to content

Commit ad6c82d

Browse files
committed
add shrink test
1 parent e09d5d2 commit ad6c82d

File tree

1 file changed

+40
-9
lines changed

1 file changed

+40
-9
lines changed

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,19 @@
1717

1818
package org.apache.spark.mllib.rdd
1919

20-
import org.apache.spark.mllib.linalg.Vector
2120
import org.scalatest.FunSuite
22-
23-
import org.apache.spark.mllib.linalg.Vectors
24-
import org.apache.spark.mllib.util.MLUtils._
25-
import VectorRDDFunctionsSuite._
21+
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2622
import org.apache.spark.mllib.util.LocalSparkContext
23+
import org.apache.spark.mllib.util.MLUtils._
2724

2825
class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
26+
import VectorRDDFunctionsSuite._
2927

3028
val localData = Array(
31-
Vectors.dense(1.0, 2.0, 3.0),
32-
Vectors.dense(4.0, 5.0, 6.0),
33-
Vectors.dense(7.0, 8.0, 9.0)
34-
)
29+
Vectors.dense(1.0, 2.0, 3.0),
30+
Vectors.dense(4.0, 5.0, 6.0),
31+
Vectors.dense(7.0, 8.0, 9.0)
32+
)
3533

3634
val rowMeans = Array(2.0, 5.0, 8.0)
3735
val rowNorm2 = Array(math.sqrt(14.0), math.sqrt(77.0), math.sqrt(194.0))
@@ -44,6 +42,23 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
4442
val maxVec = Array(7.0, 8.0, 9.0)
4543
val minVec = Array(1.0, 2.0, 3.0)
4644

45+
val shrinkingData = Array(
46+
Vectors.dense(1.0, 2.0, 0.0),
47+
Vectors.dense(0.0, 0.0, 0.0),
48+
Vectors.dense(7.0, 8.0, 0.0)
49+
)
50+
51+
val rowShrinkData = Array(
52+
Vectors.dense(1.0, 2.0, 0.0),
53+
Vectors.dense(7.0, 8.0, 0.0)
54+
)
55+
56+
val colShrinkData = Array(
57+
Vectors.dense(1.0, 2.0),
58+
Vectors.dense(0.0, 0.0),
59+
Vectors.dense(7.0, 8.0)
60+
)
61+
4762
test("rowMeans") {
4863
val data = sc.parallelize(localData, 2)
4964
assert(equivVector(Vectors.dense(data.rowMeans().collect()), Vectors.dense(rowMeans)), "Row means do not match.")
@@ -91,6 +106,22 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
91106
"Optional minimum does not match."
92107
)
93108
}
109+
110+
test("rowShrink") {
111+
val data = sc.parallelize(shrinkingData, 2)
112+
val res = data.rowShrink().collect()
113+
rowShrinkData.zip(res).foreach { case (lhs, rhs) =>
114+
assert(equivVector(lhs, rhs), "Row shrink error.")
115+
}
116+
}
117+
118+
test("columnShrink") {
119+
val data = sc.parallelize(shrinkingData, 2)
120+
val res = data.colShrink().collect()
121+
colShrinkData.zip(res).foreach { case (lhs, rhs) =>
122+
assert(equivVector(lhs, rhs), "Column shrink error.")
123+
}
124+
}
94125
}
95126

96127
object VectorRDDFunctionsSuite {

0 commit comments

Comments
 (0)