Skip to content

Commit 127dec5

Browse files
committed
update argmax impl
1 parent 2ea6a55 commit 127dec5

File tree

1 file changed

+29
-7
lines changed

1 file changed

+29
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -724,22 +724,44 @@ class SparseVector(
724724
if (size == 0) {
725725
-1
726726
} else {
727+
// Find the max active entry.
727728
var maxIdx = indices(0)
728729
var maxValue = values(0)
729-
730-
foreachActive { (i, v) =>
730+
var maxJ = 0
731+
var j = 1
732+
val na = numActives
733+
while (j < na) {
734+
val v = values(j)
731735
if (v > maxValue) {
732-
maxIdx = i
733736
maxValue = v
737+
maxIdx = indices(j)
738+
maxJ = j
734739
}
740+
j += 1
735741
}
736742

737-
var k = 0
738-
while (k < indices.length && indices(k) == k && values(k) != 0.0) {
739-
k += 1
743+
// If the max active entry is nonpositive and there exists inactive ones, find the first zero.
744+
if (maxValue <= 0.0 && na < size) {
745+
if (maxValue == 0.0) {
746+
// If there exists an inactive entry before maxIdx, find it and return its index.
747+
if (maxJ < maxIdx) {
748+
var k = 0
749+
while (k < maxJ && indices(k) == k) {
750+
k += 1
751+
}
752+
maxIdx = k
753+
}
754+
} else {
755+
// If the max active value is negative, find and return the first inactive index.
756+
var k = 0
757+
while (k < na && indices(k) == k) {
758+
k += 1
759+
}
760+
maxIdx = k
761+
}
740762
}
741763

742-
if (maxValue <= 0.0 || k >= maxIdx) k else maxIdx
764+
maxIdx
743765
}
744766
}
745767
}

0 commit comments

Comments
 (0)