@@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
135
135
logInfo(s " Best set of parameters: \n ${epm(bestIndex)}" )
136
136
logInfo(s " Best cross-validation metric: $bestMetric. " )
137
137
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf [Model [_]]
138
- copyValues(new CrossValidatorModel (uid, bestModel).setParent(this ))
138
+ copyValues(new CrossValidatorModel (uid, bestModel, metrics ).setParent(this ))
139
139
}
140
140
141
141
override def transformSchema (schema : StructType ): StructType = {
@@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
158
158
@ Experimental
159
159
class CrossValidatorModel private [ml] (
160
160
override val uid : String ,
161
- val bestModel : Model [_])
161
+ val bestModel : Model [_],
162
+ val crossValidationMetrics : Array [Double ])
162
163
extends Model [CrossValidatorModel ] with CrossValidatorParams {
163
164
164
165
override def validateParams (): Unit = {
@@ -175,7 +176,7 @@ class CrossValidatorModel private[ml] (
175
176
}
176
177
177
178
override def copy (extra : ParamMap ): CrossValidatorModel = {
178
- val copied = new CrossValidatorModel (uid, bestModel.copy(extra).asInstanceOf [Model [_]])
179
+ val copied = new CrossValidatorModel (uid, bestModel.copy(extra).asInstanceOf [Model [_]], crossValidationMetrics.clone() )
179
180
copyValues(copied, extra)
180
181
}
181
182
}
0 commit comments