Skip to content

Commit 0cfc20a

Browse files
committed
add rawPrediction as an output column;
add numClasses and numFeatures to OneVsRestModel
1 parent 252468a commit 0cfc20a

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext
3232
import org.apache.spark.annotation.Since
3333
import org.apache.spark.ml._
3434
import org.apache.spark.ml.attribute._
35-
import org.apache.spark.ml.linalg.Vector
35+
import org.apache.spark.ml.linalg.{Vector, Vectors}
3636
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
3737
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
3838
import org.apache.spark.ml.util._
@@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait {
5555
/**
5656
* Params for [[OneVsRest]].
5757
*/
58-
private[ml] trait OneVsRestParams extends PredictorParams
58+
private[ml] trait OneVsRestParams extends ClassifierParams
5959
with ClassifierTypeTrait with HasWeightCol {
6060

6161
/**
@@ -138,6 +138,12 @@ final class OneVsRestModel private[ml] (
138138
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
139139
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
140140

141+
@Since("2.4.0")
142+
val numClasses: Int = models.length
143+
144+
@Since("2.4.0")
145+
val numFeatures: Int = models.head.numFeatures
146+
141147
/** @group setParam */
142148
@Since("2.1.0")
143149
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -146,6 +152,10 @@ final class OneVsRestModel private[ml] (
146152
@Since("2.1.0")
147153
def setPredictionCol(value: String): this.type = set(predictionCol, value)
148154

155+
/** @group setParam */
156+
@Since("2.4.0")
157+
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
158+
149159
@Since("1.4.0")
150160
override def transformSchema(schema: StructType): StructType = {
151161
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
@@ -195,14 +205,18 @@ final class OneVsRestModel private[ml] (
195205
newDataset.unpersist()
196206
}
197207

198-
// output the index of the classifier with highest confidence as prediction
199-
val labelUDF = udf { (predictions: Map[Int, Double]) =>
200-
predictions.maxBy(_._2)._1.toDouble
208+
// output the raw prediction as a vector
209+
val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
210+
Vectors.sparse(numClasses, predictions.toList )
201211
}
202212

203-
// output label and label metadata as prediction
213+
// output the index of the classifier with highest confidence as prediction
214+
val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble }
215+
216+
// output confidence as raw prediction, label and label metadata as prediction
204217
aggregatedDataset
205-
.withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
218+
.withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
219+
.withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
206220
.drop(accColName)
207221
}
208222

mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
7272
.setClassifier(new LogisticRegression)
7373
assert(ova.getLabelCol === "label")
7474
assert(ova.getPredictionCol === "prediction")
75+
assert(ova.getRawPredictionCol === "rawPrediction")
7576
val ovaModel = ova.fit(dataset)
7677

7778
MLTestingUtils.checkCopyAndUids(ova, ovaModel)
7879

79-
assert(ovaModel.models.length === numClasses)
80+
assert(ovaModel.numClasses === numClasses)
8081
val transformedDataset = ovaModel.transform(dataset)
8182

8283
// check for label metadata in prediction col
@@ -179,9 +180,10 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
179180
val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
180181
ovaModel.setFeaturesCol("fea")
181182
ovaModel.setPredictionCol("pred")
183+
ovaModel.setRawPredictionCol("rawpred")
182184
val transformedDataset = ovaModel.transform(dataset2)
183185
val outputFields = transformedDataset.schema.fieldNames.toSet
184-
assert(outputFields === Set("y", "fea", "pred"))
186+
assert(outputFields === Set("y", "fea", "pred", "rawpred"))
185187
}
186188

187189
test("SPARK-8049: OneVsRest shouldn't output temp columns") {
@@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
190192
val ovr = new OneVsRest()
191193
.setClassifier(logReg)
192194
val output = ovr.fit(dataset).transform(dataset)
193-
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
195+
assert(output.schema.fieldNames.toSet
196+
=== Set("label", "features", "prediction", "rawPrediction"))
194197
}
195198

196199
test("SPARK-21306: OneVsRest should support setWeightCol") {

0 commit comments

Comments
 (0)