File tree Expand file tree Collapse file tree 2 files changed +19
-1
lines changed
main/scala/org/apache/spark/ml/util
test/scala/org/apache/spark/ml/clustering Expand file tree Collapse file tree 2 files changed +19
-1
lines changed Original file line number Diff line number Diff line change @@ -324,10 +324,16 @@ private[ml] object DefaultParamsWriter {
324
324
metadataJson
325
325
}
326
326
327
- def saveInitialModel [T <: HasInitialModel [_ <: MLWritable ]](instance : T , path : String ): Unit = {
327
+ def saveInitialModel [T <: HasInitialModel [_ <: MLWritable with Params ]](
328
+ instance : T , path : String ): Unit = {
328
329
if (instance.isDefined(instance.initialModel)) {
329
330
val initialModelPath = new Path (path, " initialModel" ).toString
330
331
val initialModel = instance.getOrDefault(instance.initialModel)
332
+ // When saving, only keep the direct initialModel by eliminating possible initialModels of the
333
+ // direct initialModel, to avoid unnecessary deep recursion of initialModel.
334
+ if (initialModel.hasParam(" initialModel" )) {
335
+ initialModel.clear(initialModel.getParam(" initialModel" ))
336
+ }
331
337
initialModel.save(initialModelPath)
332
338
}
333
339
}
Original file line number Diff line number Diff line change @@ -186,6 +186,18 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
186
186
// k is not ignored after initialModel is cleared
187
187
assert(kmeans.setK(k - 1 ).getK === k - 1 )
188
188
}
189
+
190
+ test(" Eliminate possible initialModels of the direct initialModel" ) {
191
+ val randomModel = KMeansSuite .generateRandomKMeansModel(dim, k)
192
+ val kmeans = new KMeans ().setK(k).setMaxIter(1 ).setInitialModel(randomModel)
193
+ val firstLevelModel = kmeans.fit(dataset)
194
+ val secondLevelModel = kmeans.setInitialModel(firstLevelModel).fit(dataset)
195
+ assert(secondLevelModel.getInitialModel
196
+ .isSet(secondLevelModel.getInitialModel.getParam(" initialModel" )))
197
+ val savedThenLoadedModel = testDefaultReadWrite(secondLevelModel, testParams = false )
198
+ assert(! savedThenLoadedModel.getInitialModel
199
+ .isSet(savedThenLoadedModel.getInitialModel.getParam(" initialModel" )))
200
+ }
189
201
}
190
202
191
203
object KMeansSuite {
You can’t perform that action at this time.
0 commit comments