Skip to content

Commit 3a995da

Browse files
committed
Added stats from cross validation as a val in the cross validation model to save them for user access
1 parent ad06727 commit 3a995da

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
135135
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
136136
logInfo(s"Best cross-validation metric: $bestMetric.")
137137
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
138-
copyValues(new CrossValidatorModel(uid, bestModel).setParent(this))
138+
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
139139
}
140140

141141
override def transformSchema(schema: StructType): StructType = {
@@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
158158
@Experimental
159159
class CrossValidatorModel private[ml] (
160160
override val uid: String,
161-
val bestModel: Model[_])
161+
val bestModel: Model[_],
162+
val crossValidationMetrics: Array[Double])
162163
extends Model[CrossValidatorModel] with CrossValidatorParams {
163164

164165
override def validateParams(): Unit = {
@@ -175,7 +176,7 @@ class CrossValidatorModel private[ml] (
175176
}
176177

177178
override def copy(extra: ParamMap): CrossValidatorModel = {
178-
val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
179+
val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]], crossValidationMetrics.clone())
179180
copyValues(copied, extra)
180181
}
181182
}

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
5656
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
5757
assert(parent.getRegParam === 0.001)
5858
assert(parent.getMaxIter === 10)
59+
assert(cvModel.crossValidationMetrics.length == 4)
5960
}
6061

6162
test("validateParams should check estimatorParamMaps") {

0 commit comments

Comments
 (0)