Skip to content

Commit 3f7de7d

Browse files
GeorgeDittmar authored and mengxr committed
[SPARK-7422] [MLLIB] Add argmax to Vector, SparseVector
Modifying Vector, DenseVector, and SparseVector to implement argmax functionality. This work is to set the stage for changes to be done in Spark-7423. Author: George Dittmar <georgedittmar@gmail.com> Author: George <dittmar@Georges-MacBook-Pro.local> Author: dittmarg <george.dittmar@webtrends.com> Author: Xiangrui Meng <meng@databricks.com> Closes apache#6112 from GeorgeDittmar/SPARK-7422 and squashes the following commits: 3e0a939 [George Dittmar] Merge pull request #1 from mengxr/SPARK-7422 127dec5 [Xiangrui Meng] update argmax impl 2ea6a55 [George Dittmar] Added MimaExcludes for Vectors.argmax 98058f4 [George Dittmar] Merge branch 'master' of github.com:apache/spark into SPARK-7422 5fd9380 [George Dittmar] fixing style check error 42341fb [George Dittmar] refactoring arg max check to better handle zero values b22af46 [George Dittmar] Fixing spaces between commas in unit test f2eba2f [George Dittmar] Cleaning up unit tests to be fewer lines aa330e3 [George Dittmar] Fixing some last if else spacing issues ac53c55 [George Dittmar] changing dense vector argmax unit test to be one line call vs 2 d5b5423 [George Dittmar] Fixing code style and updating if logic on when to check for zero values ee1a85a [George Dittmar] Cleaning up unit tests a bit and modifying a few cases 3ee8711 [George Dittmar] Fixing corner case issue with zeros in the active values of the sparse vector. Updated unit tests b1f059f [George Dittmar] Added comment before we start arg max calculation. Updated unit tests to cover corner cases f21dcce [George Dittmar] commit af17981 [dittmarg] Initial work fixing bug that was made clear in pr eeda560 [George] Fixing SparseVector argmax function to ignore zero values while doing the calculation. 
4526acc [George] Merge branch 'master' of github.com:apache/spark into SPARK-7422 df9538a [George] Added argmax to sparse vector and added unit test 3cffed4 [George] Adding unit tests for argmax functions for Dense and Sparse vectors 04677af [George] initial work on adding argmax to Vector and SparseVector
1 parent 79ec072 commit 3f7de7d

File tree

3 files changed

+95
-5
lines changed

3 files changed

+95
-5
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

+52-5
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ sealed trait Vector extends Serializable {
150150
toDense
151151
}
152152
}
153+
154+
/**
155+
* Find the index of a maximal element. Returns the index of the first maximal element in case
* of a tie.
156+
* Returns -1 if vector has length 0.
157+
*
* @return the index of the first maximal element, or -1 if the vector has length 0
*/
158+
def argmax: Int
153159
}
154160

155161
/**
@@ -588,11 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
588594
new SparseVector(size, ii, vv)
589595
}
590596

591-
/**
592-
* Find the index of a maximal element. Returns the first maximal element in case of a tie.
593-
* Returns -1 if vector has length 0.
594-
*/
595-
private[spark] def argmax: Int = {
597+
override def argmax: Int = {
596598
if (size == 0) {
597599
-1
598600
} else {
@@ -717,6 +719,51 @@ class SparseVector(
717719
new SparseVector(size, ii, vv)
718720
}
719721
}
722+
723+
/**
 * Find the index of a maximal element, treating inactive (implicit) entries as 0.0.
 * Returns the index of the first maximal element in case of a tie.
 * Returns -1 if vector has length 0.
 */
override def argmax: Int = {
  if (size == 0) {
    -1
  } else if (numActives == 0) {
    // No explicitly stored entries: every element is an implicit zero, so the first index
    // is the first maximal element. Without this guard indices(0) below would throw.
    0
  } else {
    // Find the max active entry. The strict `>` keeps the first one in case of a tie.
    var maxIdx = indices(0)
    var maxValue = values(0)
    var maxJ = 0
    var j = 1
    val na = numActives
    while (j < na) {
      val v = values(j)
      if (v > maxValue) {
        maxValue = v
        maxIdx = indices(j)
        maxJ = j
      }
      j += 1
    }

    // If the max active entry is nonpositive and there exists inactive ones, find the first zero.
    if (maxValue <= 0.0 && na < size) {
      if (maxValue == 0.0) {
        // If there exists an inactive entry before maxIdx, find it and return its index.
        // Indices are increasing, so indices(k) == k holds exactly while the prefix is
        // fully active; maxJ < maxIdx means some position before maxIdx is inactive.
        if (maxJ < maxIdx) {
          var k = 0
          while (k < maxJ && indices(k) == k) {
            k += 1
          }
          maxIdx = k
        }
      } else {
        // If the max active value is negative, find and return the first inactive index,
        // because an implicit zero is larger than every active value.
        var k = 0
        while (k < na && indices(k) == k) {
          k += 1
        }
        maxIdx = k
      }
    }

    maxIdx
  }
}
720767
}
721768

722769
object SparseVector {

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

+39
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,50 @@ class VectorsSuite extends SparkFunSuite with Logging {
6262
assert(vec.toArray.eq(arr))
6363
}
6464

65+
test("dense argmax") {
  // A zero-length vector has no maximal element, which is reported as -1.
  val emptyVec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]
  assert(emptyVec.argmax === -1)

  // Vector built from the suite-level `arr` fixture; its expected argmax is 3.
  val fixtureVec = Vectors.dense(arr).asInstanceOf[DenseVector]
  assert(fixtureVec.argmax === 3)

  // Mixed-sign values: the single positive entry sits at the last index.
  val mixedSignVec = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector]
  assert(mixedSignVec.argmax === 3)
}
75+
6576
test("sparse to array") {
6677
val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
6778
assert(vec.toArray === arr)
6879
}
6980

81+
test("sparse argmax") {
  // A zero-length vector has no maximal element, which is reported as -1.
  val emptyVec =
    Vectors.sparse(0, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
  assert(emptyVec.argmax === -1)

  // Vector built from the suite-level `n`, `indices` and `values` fixtures.
  val fixtureVec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
  assert(fixtureVec.argmax === 3)

  // {0.0, 0.0, 1.0, 0.0, -0.7}: a positive active entry beats every zero.
  val positiveMax = Vectors.sparse(5, Array(2, 3, 4), Array(1.0, 0.0, -0.7))
  assert(positiveMax.argmax === 2)

  // check for case that sparse vector is created with
  // only negative values {0.0, 0.0, -1.0, -0.7, 0.0}
  val onlyNegativeActives = Vectors.sparse(5, Array(2, 3), Array(-1.0, -0.7))
  assert(onlyNegativeActives.argmax === 0)

  // The explicit zero is stored at index 10, but the implicit zero at index 1 comes first.
  val implicitZeroFirst = Vectors.sparse(11, Array(0, 3, 10), Array(-1.0, -0.7, 0.0))
  assert(implicitZeroFirst.argmax === 1)

  // Indices 0..2 are all active, so the explicit zero at index 2 is the first maximum.
  val explicitZeroFirst = Vectors.sparse(11, Array(0, 1, 2), Array(-1.0, -0.7, 0.0))
  assert(explicitZeroFirst.argmax === 2)

  // The explicit zero at index 1 precedes the implicit zero at index 2.
  val explicitZeroBeforeGap = Vectors.sparse(5, Array(0, 1, 3), Array(-1.0, 0.0, -0.7))
  assert(explicitZeroBeforeGap.argmax === 1)

  // The implicit zero at index 0 precedes the explicit zero at index 1.
  val leadingImplicitZero = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
  assert(leadingImplicitZero.argmax === 0)
}
108+
70109
test("vector equals") {
71110
val dv1 = Vectors.dense(arr.clone())
72111
val dv2 = Vectors.dense(arr.clone())

project/MimaExcludes.scala

+4
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ object MimaExcludes {
9898
"org.apache.spark.api.r.StringRRDD.this"),
9999
ProblemFilters.exclude[MissingMethodProblem](
100100
"org.apache.spark.api.r.BaseRRDD.this")
101+
) ++ Seq(
102+
// SPARK-7422 add argmax for sparse vectors
103+
ProblemFilters.exclude[MissingMethodProblem](
104+
"org.apache.spark.mllib.linalg.Vector.argmax")
101105
)
102106

103107
case v if v.startsWith("1.4") =>

0 commit comments

Comments
 (0)