@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext
32
32
import org.apache.spark.annotation.Since
33
33
import org.apache.spark.ml._
34
34
import org.apache.spark.ml.attribute._
35
- import org.apache.spark.ml.linalg.Vector
35
+ import org.apache.spark.ml.linalg.{Vector, Vectors}
36
36
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
37
37
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
38
38
import org.apache.spark.ml.util._
@@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait {
55
55
/**
56
56
* Params for [[OneVsRest ]].
57
57
*/
58
- private[ml] trait OneVsRestParams extends PredictorParams
58
+ private[ml] trait OneVsRestParams extends ClassifierParams
59
59
with ClassifierTypeTrait with HasWeightCol {
60
60
61
61
/**
@@ -138,6 +138,12 @@ final class OneVsRestModel private[ml] (
138
138
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
139
139
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
140
140
141
+ @Since("2.4.0")
142
+ val numClasses: Int = models.length
143
+
144
+ @Since("2.4.0")
145
+ val numFeatures: Int = models.head.numFeatures
146
+
141
147
/** @group setParam */
142
148
@Since("2.1.0")
143
149
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -146,6 +152,10 @@ final class OneVsRestModel private[ml] (
146
152
@Since("2.1.0")
147
153
def setPredictionCol(value: String): this.type = set(predictionCol, value)
148
154
155
+ /** @group setParam */
156
+ @Since("2.4.0")
157
+ def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
158
+
149
159
@Since("1.4.0")
150
160
override def transformSchema(schema: StructType): StructType = {
151
161
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
@@ -195,14 +205,18 @@ final class OneVsRestModel private[ml] (
195
205
newDataset.unpersist()
196
206
}
197
207
198
- // output the index of the classifier with highest confidence as prediction
199
- val labelUDF = udf { (predictions: Map[Int, Double]) =>
200
- predictions.maxBy(_._2)._1.toDouble
208
+ // output the RawPrediction as vector
209
+ val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
210
+ Vectors.sparse(numClasses, predictions.toList)
201
211
}
202
212
203
- // output label and label metadata as prediction
213
+ // output the index of the classifier with highest confidence as prediction
214
+ val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble }
215
+
216
+ // output confidence as raw prediction, label and label metadata as prediction
204
217
aggregatedDataset
205
- .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
218
+ .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
219
+ .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
206
220
.drop(accColName)
207
221
}
208
222
0 commit comments