Skip to content

Commit 6e4f8ca

Browse files
committed
add check for ascending order
1 parent 9956365 commit 6e4f8ca

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,19 @@ object MLUtils {
8282
val value = indexAndValue(1).toDouble
8383
(index, value)
8484
}.unzip
85-
require(indices.size == 0 || indices(0) >= 0,
86-
"indices should be one-based in LIBSVM format")
85+
86+
// check that indices are one-based and in strictly ascending order (duplicates are rejected)
87+
var previous = -1
88+
var i = 0
89+
val indicesLength = indices.size
90+
while (i < indicesLength) {
91+
if (indices(i) <= previous) {
92+
throw new IllegalArgumentException("indices should be one-based and in ascending order")
93+
}
94+
previous = indices(i)
95+
i += 1
96+
}
97+
8798
(label, indices.toArray, values.toArray)
8899
}
89100

mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
110110
Utils.deleteRecursively(tempDir)
111111
}
112112

113-
test("loadLibSVMFile throws SparkException when passing a zero-based vector") {
113+
test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") {
114114
val lines =
115115
"""
116116
|0
@@ -122,7 +122,24 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
122122
val path = tempDir.toURI.toString
123123

124124
intercept[SparkException] {
125-
val pointsWithoutNumFeatures = loadLibSVMFile(sc, path).collect()
125+
loadLibSVMFile(sc, path).collect()
126+
}
127+
Utils.deleteRecursively(tempDir)
128+
}
129+
130+
test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") {
131+
val lines =
132+
"""
133+
|0
134+
|0 3:4.0 2:5.0 6:6.0
135+
""".stripMargin
136+
val tempDir = Utils.createTempDir()
137+
val file = new File(tempDir.getPath, "part-00000")
138+
Files.write(lines, file, Charsets.US_ASCII)
139+
val path = tempDir.toURI.toString
140+
141+
intercept[SparkException] {
142+
loadLibSVMFile(sc, path).collect()
126143
}
127144
Utils.deleteRecursively(tempDir)
128145
}

0 commit comments

Comments
 (0)