Skip to content

Commit

Permalink
[SPARK-8049] [MLLIB] drop tmp col from OneVsRest output
Browse files Browse the repository at this point in the history
The temporary column should be dropped after we get the prediction column. cc @harsha2010

Author: Xiangrui Meng <meng@databricks.com>

Closes apache#6592 from mengxr/SPARK-8049 and squashes the following commits:

1d89107 [Xiangrui Meng] use SparkFunSuite
6ee70de [Xiangrui Meng] drop tmp col from OneVsRest output
  • Loading branch information
mengxr committed Jun 2, 2015
1 parent 605ddbb commit 89f21f6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] (
// output label and label metadata as prediction
val labelUdf = callUDF(label, DoubleType, col(accColName))
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
.drop(accColName)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
ova.fit(datasetWithLabelMetadata)
}

// Regression test for SPARK-8049: after fit + transform on the same dataset, the
// output schema should contain only the label, features, and prediction columns,
// i.e. no temporary column is left behind.
test("SPARK-8049: OneVsRest shouldn't output temp columns") {
  // A single iteration is enough here; only the output schema is inspected.
  val classifier = new LogisticRegression().setMaxIter(1)
  val oneVsRest = new OneVsRest().setClassifier(classifier)
  val transformed = oneVsRest.fit(dataset).transform(dataset)
  val columnNames = transformed.schema.fieldNames.toSet
  assert(columnNames === Set("label", "features", "prediction"))
}
}

private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
Expand Down

0 comments on commit 89f21f6

Please sign in to comment.