Skip to content

Commit e529972

Browse files
committed
eliminate possible initialModels of the direct initialModel
1 parent 2c9cd51 commit e529972

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,16 @@ private[ml] object DefaultParamsWriter {
324324
metadataJson
325325
}
326326

327-
def saveInitialModel[T <: HasInitialModel[_ <: MLWritable]](instance: T, path: String): Unit = {
327+
def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]](
328+
instance: T, path: String): Unit = {
328329
if (instance.isDefined(instance.initialModel)) {
329330
val initialModelPath = new Path(path, "initialModel").toString
330331
val initialModel = instance.getOrDefault(instance.initialModel)
332+
// When saving, only keep the direct initialModel by eliminating possible initialModels of the
333+
// direct initialModel, to avoid unnecessary deep recursion of initialModel.
334+
if (initialModel.hasParam("initialModel")) {
335+
initialModel.clear(initialModel.getParam("initialModel"))
336+
}
331337
initialModel.save(initialModelPath)
332338
}
333339
}

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,18 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
186186
// k is not ignored after initialModel is cleared
187187
assert(kmeans.setK(k - 1).getK === k - 1)
188188
}
189+
190+
test("Eliminate possible initialModels of the direct initialModel") {
191+
val randomModel = KMeansSuite.generateRandomKMeansModel(dim, k)
192+
val kmeans = new KMeans().setK(k).setMaxIter(1).setInitialModel(randomModel)
193+
val firstLevelModel = kmeans.fit(dataset)
194+
val secondLevelModel = kmeans.setInitialModel(firstLevelModel).fit(dataset)
195+
assert(secondLevelModel.getInitialModel
196+
.isSet(secondLevelModel.getInitialModel.getParam("initialModel")))
197+
val savedThenLoadedModel = testDefaultReadWrite(secondLevelModel, testParams = false)
198+
assert(!savedThenLoadedModel.getInitialModel
199+
.isSet(savedThenLoadedModel.getInitialModel.getParam("initialModel")))
200+
}
189201
}
190202

191203
object KMeansSuite {

0 commit comments

Comments
 (0)