
Commit 6f59eed

update libSVMFile to determine number of features automatically
1 parent 3432e84

2 files changed: +35 −13 lines

mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala

Lines changed: 20 additions & 7 deletions
@@ -34,24 +34,37 @@ class MLContext(self: SparkContext) {
    * where the feature indices are converted to zero-based.
    *
    * @param path file or directory path in any Hadoop-supported file system URI
-   * @param numFeatures number of features
-   * @param labelParser parser for labels, default: _.toDouble
+   * @param numFeatures number of features; it will be determined from the input
+   *                    if a non-positive value is given
+   * @param labelParser parser for labels, default: _.toDouble
    * @return labeled data stored as an RDD[LabeledPoint]
    */
   def libSVMFile(
       path: String,
       numFeatures: Int,
       labelParser: String => Double = _.toDouble): RDD[LabeledPoint] = {
-    self.textFile(path).map(_.trim).filter(!_.isEmpty).map { line =>
-      val items = line.split(' ')
+    val parsed = self.textFile(path).map(_.trim).filter(!_.isEmpty).map(_.split(' '))
+    // Determine the number of features.
+    val d = if (numFeatures > 0) {
+      numFeatures
+    } else {
+      parsed.map { items =>
+        if (items.length > 1) {
+          items.last.split(':')(0).toInt
+        } else {
+          0
+        }
+      }.reduce(math.max)
+    }
+    parsed.map { items =>
       val label = labelParser(items.head)
-      val features = Vectors.sparse(numFeatures, items.tail.map { item =>
+      val (indices, values) = items.tail.map { item =>
         val indexAndValue = item.split(':')
         val index = indexAndValue(0).toInt - 1
         val value = indexAndValue(1).toDouble
         (index, value)
-      })
-      LabeledPoint(label, features)
+      }.unzip
+      LabeledPoint(label, Vectors.sparse(d, indices.toArray, values.toArray))
     }
   }
 }
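
For context, a minimal usage sketch of the two code paths above, constructing the MLContext wrapper directly from its constructor shown in the hunk header (the test suite below instead calls sc.libSVMFile, presumably via an implicit conversion); the master, app name, and file path are placeholders:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.MLContext

val sc = new SparkContext("local", "libSVMFile-example")  // placeholder setup
val mlc = new MLContext(sc)

// Positive numFeatures: vectors are built with the given dimension,
// and the data is parsed in a single pass.
val fixed = mlc.libSVMFile("data/sample.libsvm", numFeatures = 6)

// Non-positive numFeatures: an extra pass reduces over the largest
// feature index in the file, and that maximum becomes the dimension.
val inferred = mlc.libSVMFile("data/sample.libsvm", numFeatures = 0)

Note that the inference reads only items.last on each record, so it assumes feature indices within a record appear in ascending order, as the LIBSVM format prescribes; it also means parsed is traversed twice when the dimension is inferred, without being cached here.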

mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala

Lines changed: 15 additions & 6 deletions
@@ -33,17 +33,26 @@ class MLContextSuite extends FunSuite with LocalSparkContext {
     val lines =
       """
         |1 1:1.0 3:2.0 5:3.0
+        |0
         |0 2:4.0 4:5.0 6:6.0
       """.stripMargin
     val tempDir = Files.createTempDir()
     val file = new File(tempDir.getPath, "part-00000")
     Files.write(lines, file, Charsets.US_ASCII)
-    val points = sc.libSVMFile(tempDir.toURI.toString, 6).collect()
-    assert(points.length === 2)
-    assert(points(0).label === 1.0)
-    assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
-    assert(points(1).label === 0.0)
-    assert(points(1).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
+
+    val pointsWithNumFeatures = sc.libSVMFile(tempDir.toURI.toString, 6).collect()
+    val pointsWithoutNumFeatures = sc.libSVMFile(tempDir.toURI.toString, 0).collect()
+
+    for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
+      assert(points.length === 3)
+      assert(points(0).label === 1.0)
+      assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+      assert(points(1).label === 0.0)
+      assert(points(1).features === Vectors.sparse(6, Seq()))
+      assert(points(2).label === 0.0)
+      assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
+    }
+
     try {
       file.delete()
       tempDir.delete()
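
The new label-only record (|0) is the interesting case here: it parses to an empty sparse vector and contributes 0 to the dimension inference. A minimal, Spark-free sketch of that inference over the same three records (all names below are illustrative):

val records = Seq(
  "1 1:1.0 3:2.0 5:3.0",
  "0",                    // label only: no feature pairs
  "0 2:4.0 4:5.0 6:6.0")
val parsed = records.map(_.trim).filter(!_.isEmpty).map(_.split(' '))
val d = parsed.map { items =>
  if (items.length > 1) items.last.split(':')(0).toInt else 0
}.reduce(math.max)
assert(d == 6)  // agrees with the explicit numFeatures = 6 used in the test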
