-
Notifications
You must be signed in to change notification settings - Fork 28.6k
[SPARK-10780][ML] Add an initial model to kmeans #11119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cc13c1e
36b1729
125ac76
658c4c9
abfe0e2
b87e07e
9a4a55e
65f4237
166a6ff
f56e443
f3f9226
08afa4c
31f7b94
7c1c8f7
95627b5
a0629b5
b0eb111
58bf1cf
b7856e1
9f5e698
914d319
c40192b
6526c08
23a78d6
d4f59d9
47f182b
78ed9a1
03575bf
c21ffa2
eb7fbbe
f6e024a
92cf83d
95bf12f
7fc6918
5fbb132
4bba7c1
127ca06
0e93fda
261fcfa
b3ea01a
47de1fe
2c9cd51
e529972
939ebe5
7046913
8516a2c
6f169eb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,10 @@ | |
|
||
package org.apache.spark.ml.clustering | ||
|
||
import scala.util.{Failure, Success} | ||
|
||
import org.apache.hadoop.fs.Path | ||
import org.apache.hadoop.mapred.InvalidInputException | ||
|
||
import org.apache.spark.SparkException | ||
import org.apache.spark.annotation.{Experimental, Since} | ||
|
@@ -35,7 +38,25 @@ import org.apache.spark.sql.functions.{col, udf} | |
import org.apache.spark.sql.types.{IntegerType, StructType} | ||
|
||
/** | ||
* Common params for KMeans and KMeansModel | ||
* Params for KMeans | ||
*/ | ||
|
||
private[clustering] trait KMeansInitialModelParams extends HasInitialModel[KMeansModel] { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we follow the convention in ALS, then we should have |
||
/** | ||
* Param for KMeansModel to use for warm start. | ||
* Whenever initialModel is set: | ||
* 1. the initialModel k will override the param k; | ||
* 2. the param initMode is set to initialModel, and any manually set value is ignored; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
nit: Let's just remove the punctuation from the numbered list |
||
* 3. other params are untouched. | ||
* @group param | ||
*/ | ||
final val initialModel: Param[KMeansModel] = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
new Param[KMeansModel](this, "initialModel", "A KMeansModel for warm start.") | ||
|
||
} | ||
|
||
/** | ||
* Common params for KMeans and KMeansModel | ||
*/ | ||
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol | ||
with HasSeed with HasPredictionCol with HasTol { | ||
|
@@ -58,6 +79,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe | |
* Param for the initialization algorithm. This can be either "random" to choose random points as | ||
* initial cluster centers, or "k-means||" to use a parallel variant of k-means++ | ||
* (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. | ||
* The param initMode will be ignored if the param initialModel is set. | ||
* @group expertParam | ||
*/ | ||
@Since("1.5.0") | ||
|
@@ -82,6 +104,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe | |
@Since("1.5.0") | ||
def getInitSteps: Int = $(initSteps) | ||
|
||
|
||
/** | ||
* Validates and transforms the input schema. | ||
* @param schema input schema | ||
|
@@ -103,7 +126,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe | |
@Experimental | ||
class KMeansModel private[ml] ( | ||
@Since("1.5.0") override val uid: String, | ||
private val parentModel: MLlibKMeansModel) | ||
private[ml] val parentModel: MLlibKMeansModel) | ||
extends Model[KMeansModel] with KMeansParams with MLWritable { | ||
|
||
@Since("1.5.0") | ||
|
@@ -124,7 +147,8 @@ class KMeansModel private[ml] ( | |
@Since("2.0.0") | ||
override def transform(dataset: Dataset[_]): DataFrame = { | ||
transformSchema(dataset.schema, logging = true) | ||
val predictUDF = udf((vector: Vector) => predict(vector)) | ||
val tmpParent: MLlibKMeansModel = parentModel | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe a comment would be useful? |
||
val predictUDF = udf((vector: Vector) => tmpParent.predict(vector)) | ||
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) | ||
} | ||
|
||
|
@@ -133,8 +157,6 @@ class KMeansModel private[ml] ( | |
validateAndTransformSchema(schema) | ||
} | ||
|
||
private[clustering] def predict(features: Vector): Int = parentModel.predict(features) | ||
|
||
@Since("2.0.0") | ||
def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML) | ||
|
||
|
@@ -210,6 +232,7 @@ object KMeansModel extends MLReadable[KMeansModel] { | |
override protected def saveImpl(path: String): Unit = { | ||
// Save metadata and Params | ||
DefaultParamsWriter.saveMetadata(instance, path, sc) | ||
|
||
// Save model data: cluster centers | ||
val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => | ||
Data(idx, center) | ||
|
@@ -244,6 +267,7 @@ object KMeansModel extends MLReadable[KMeansModel] { | |
} | ||
val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) | ||
DefaultParamsReader.getAndSetParams(model, metadata) | ||
|
||
model | ||
} | ||
} | ||
|
@@ -259,7 +283,8 @@ object KMeansModel extends MLReadable[KMeansModel] { | |
@Experimental | ||
class KMeans @Since("1.5.0") ( | ||
@Since("1.5.0") override val uid: String) | ||
extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { | ||
extends Estimator[KMeansModel] | ||
with KMeansParams with KMeansInitialModelParams with DefaultParamsWritable { | ||
|
||
setDefault( | ||
k -> 2, | ||
|
@@ -284,11 +309,26 @@ class KMeans @Since("1.5.0") ( | |
|
||
/** @group setParam */ | ||
@Since("1.5.0") | ||
def setK(value: Int): this.type = set(k, value) | ||
def setK(value: Int): this.type = {
  // k is only writable while no initial model is set; otherwise the request is
  // dropped with a warning and the estimator is returned unchanged.
  if (!isSet(initialModel)) {
    set(k, value)
  } else {
    logWarning("initialModel is set, so k will be ignored. Clear initialModel first.")
    this
  }
}
|
||
/** @group expertSetParam */ | ||
@Since("1.5.0") | ||
def setInitMode(value: String): this.type = set(initMode, value) | ||
def setInitMode(value: String): this.type = { | ||
if (isSet(initialModel)) { | ||
logWarning(s"initialModel is set, so initMode will be ignored. Clear initialModel first.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We say it will be ignored, but then still set it below. |
||
} | ||
if (value == MLlibKMeans.K_MEANS_INITIAL_MODEL) { | ||
logWarning(s"initMode of $value is not supported here, please use setInitialModel.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From the discussion, I think we decided to throw an error for |
||
} | ||
set(initMode, value) | ||
} | ||
|
||
/** @group expertSetParam */ | ||
@Since("1.5.0") | ||
|
@@ -306,6 +346,25 @@ class KMeans @Since("1.5.0") ( | |
@Since("1.5.0") | ||
def setSeed(value: Long): this.type = set(seed, value) | ||
|
||
/** @group setParam */ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Above, we may want to also log warning in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, so I think myself and Nick were under the impression that we would simply ignore any calls to def setK(value: Int): this.type = {
if (isSet(initialModel)) {
logWarning("initialModel is set, so k will be ignored. Clear initialModel first.")
this
} else {
set(k, value)
}
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree with @sethah on this - initial model should take precedence - essentially ignoring any There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My personal preference is to throw an exception to make it clear for users, but I don't have a strong opinion about this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't feel really strongly about it. But I think it is not a critical failure, so does not really require a hard stop of an exception. Since we make it clear in the doc that initial model |
||
@Since("2.1.0") | ||
def setInitialModel(value: KMeansModel): this.type = { | ||
val kOfInitialModel = value.parentModel.clusterCenters.length | ||
if (isSet(k)) { | ||
if ($(k) != kOfInitialModel) { | ||
val previousK = $(k) | ||
set(k, kOfInitialModel) | ||
logWarning(s"Param K is set to $kOfInitialModel by the initialModel." + | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The previous value is lost because you set it before logging this warning. Need to reorder the code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Maybe |
||
s" Previous value is $previousK.") | ||
} | ||
} else { | ||
set(k, kOfInitialModel) | ||
logWarning(s"Param K is set to $kOfInitialModel by the initialModel.") | ||
} | ||
set(initMode, "initialModel") | ||
set(initialModel, value) | ||
} | ||
|
||
@Since("2.0.0") | ||
override def fit(dataset: Dataset[_]): KMeansModel = { | ||
transformSchema(dataset.schema, logging = true) | ||
|
@@ -323,6 +382,24 @@ class KMeans @Since("1.5.0") ( | |
.setMaxIterations($(maxIter)) | ||
.setSeed($(seed)) | ||
.setEpsilon($(tol)) | ||
|
||
if (isDefined(initialModel)) { | ||
// Check that the feature dimensions are equal | ||
val dimOfData = rdd.first().size | ||
val dimOfInitialModel = $(initialModel).clusterCenters.head.size | ||
require(dimOfData == dimOfInitialModel, | ||
s"mismatched dimension, $dimOfData in data while $dimOfInitialModel in the initial model.") | ||
|
||
// Check that the number of clusters is equal | ||
val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length | ||
if (kOfInitialModel != $(k)) { | ||
logWarning(s"mismatched cluster count, ${$(k)} cluster centers required but" + | ||
s" $kOfInitialModel found in the initial model.") | ||
} | ||
|
||
algo.setInitialModel($(initialModel).parentModel) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does MLlibKMeans check if the dimensionalities match? Also, why do you set it to parentModel? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
} | ||
|
||
val parentModel = algo.run(rdd, Option(instr)) | ||
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) | ||
val summary = new KMeansSummary( | ||
|
@@ -336,13 +413,48 @@ class KMeans @Since("1.5.0") ( | |
override def transformSchema(schema: StructType): StructType = { | ||
validateAndTransformSchema(schema) | ||
} | ||
|
||
@Since("2.1.0") | ||
override def write: MLWriter = new KMeans.KMeansWriter(this) | ||
} | ||
|
||
@Since("1.6.0") | ||
object KMeans extends DefaultParamsReadable[KMeans] { | ||
|
||
// TODO: [SPARK-17784]: Add a fromCenters method | ||
|
||
@Since("1.6.0") | ||
override def load(path: String): KMeans = super.load(path) | ||
|
||
@Since("1.6.0") | ||
override def read: MLReader[KMeans] = new KMeansReader | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This has the same signature as before, right? If so, then the Since version can be 1.6.0. |
||
|
||
/** [[MLWriter]] instance for [[KMeans]] */ | ||
private[KMeans] class KMeansWriter(instance: KMeans) extends MLWriter { | ||
// Persists the estimator under `path` in two steps: first the initial model
// (delegated to DefaultParamsWriter.saveInitialModel), then the estimator metadata
// (delegated to DefaultParamsWriter.saveMetadata).
override protected def saveImpl(path: String): Unit = { | ||
DefaultParamsWriter.saveInitialModel(instance, path) | ||
DefaultParamsWriter.saveMetadata(instance, path, sc) | ||
} | ||
} | ||
|
||
private class KMeansReader extends MLReader[KMeans] {

  /** Class name that the stored metadata is checked against when loading the estimator. */
  private val className = classOf[KMeans].getName

  /**
   * Rebuilds a [[KMeans]] estimator from `path`: reads the metadata, applies the stored
   * params, and attaches the initial model when one can be loaded.
   */
  override def load(path: String): KMeans = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
    val estimator = new KMeans(metadata.uid)
    DefaultParamsReader.getAndSetParams(estimator, metadata)

    // An InvalidInputException means no initialModel exists at the path, which is
    // not an error; every other failure is rethrown unchanged.
    val warmStart: Option[KMeansModel] =
      DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match {
        case Success(model) => Some(model)
        case Failure(_: InvalidInputException) => None
        case Failure(cause) => throw cause
      }
    warmStart.foreach(model => estimator.setInitialModel(model))
    estimator
  }
}
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
/* | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the filename should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed others. Will handle this soon. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @dbtsai Sorry for the long delay. I'm not sure what you mean by "integrate into the other part of Params"; what is it for? |
||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.param.shared | ||
|
||
import org.apache.spark.ml.Model | ||
import org.apache.spark.ml.param._ | ||
|
||
/**
 * Shared param trait for components that accept a model of type `T` through an
 * `initialModel` param.
 *
 * @tparam T type of the model carried by `initialModel`
 */
private[ml] trait HasInitialModel[T <: Model[T]] extends Params { | ||
|
||
// Abstract on purpose: the class or trait mixing this in supplies the concrete Param.
def initialModel: Param[T] | ||
|
||
/** @group getParam */ | ||
final def getInitialModel: T = $(initialModel) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,8 @@ package org.apache.spark.ml.util | |
|
||
import java.io.IOException | ||
|
||
import scala.util.Try | ||
|
||
import org.apache.hadoop.fs.Path | ||
import org.json4s._ | ||
import org.json4s.{DefaultFormats, JObject} | ||
|
@@ -32,6 +34,7 @@ import org.apache.spark.ml._ | |
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} | ||
import org.apache.spark.ml.feature.RFormulaModel | ||
import org.apache.spark.ml.param.{ParamPair, Params} | ||
import org.apache.spark.ml.param.shared.HasInitialModel | ||
import org.apache.spark.ml.tuning.ValidatorParams | ||
import org.apache.spark.sql.{SparkSession, SQLContext} | ||
import org.apache.spark.util.Utils | ||
|
@@ -300,7 +303,8 @@ private[ml] object DefaultParamsWriter { | |
paramMap: Option[JValue] = None): String = { | ||
val uid = instance.uid | ||
val cls = instance.getClass.getName | ||
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] | ||
val params = instance.extractParamMap().toSeq | ||
.filter(_.param.name != "initialModel").asInstanceOf[Seq[ParamPair[Any]]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it not possible to check if the param is an instance of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's also possible that we introduce other params that can be of type model in the future. Thus, being an instance of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair enough |
||
val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => | ||
p.name -> parse(p.jsonEncode(v)) | ||
}.toList)) | ||
|
@@ -309,6 +313,7 @@ private[ml] object DefaultParamsWriter { | |
("sparkVersion" -> sc.version) ~ | ||
("uid" -> uid) ~ | ||
("paramMap" -> jsonParams) | ||
|
||
val metadata = extraMetadata match { | ||
case Some(jObject) => | ||
basicMetadata ~ jObject | ||
|
@@ -318,6 +323,20 @@ private[ml] object DefaultParamsWriter { | |
val metadataJson: String = compact(render(metadata)) | ||
metadataJson | ||
} | ||
|
||
def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]]( | ||
instance: T, path: String): Unit = { | ||
if (instance.isDefined(instance.initialModel)) { | ||
val initialModelPath = new Path(path, "initialModel").toString | ||
val initialModel = instance.getOrDefault(instance.initialModel) | ||
// When saving, only keep the direct initialModel by eliminating possible initialModels of the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @MLnick That's what I mean: When saving, we eliminate possible initialModels of the direct initialModel. |
||
// direct initialModel, to avoid unnecessary deep recursion of initialModel. | ||
if (initialModel.hasParam("initialModel")) { | ||
initialModel.clear(initialModel.getParam("initialModel")) | ||
} | ||
initialModel.save(initialModelPath) | ||
} | ||
} | ||
} | ||
|
||
/** | ||
|
@@ -446,6 +465,11 @@ private[ml] object DefaultParamsReader { | |
val cls = Utils.classForName(metadata.className) | ||
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) | ||
} | ||
|
||
def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Try[M] = { | ||
val initialModelPath = new Path(path, "initialModel").toString | ||
Try(loadParamsInstance[M](initialModelPath, sc)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note, now we always try to load the initial model, which I would guess adds some overhead in the cases when there is none, but we have to fail and catch the exception. Let's leave this for now since it's cleaner, and only change it if it becomes a problem. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think the overhead is that large, and if there is no model then it should fail fast as the directory won't exist. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, this is not a critical path, so I think the overhead is okay. |
||
} | ||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -414,6 +414,8 @@ object KMeans { | |
val RANDOM = "random" | ||
@Since("0.8.0") | ||
val K_MEANS_PARALLEL = "k-means||" | ||
@Since("2.1.0") | ||
val K_MEANS_INITIAL_MODEL = "initialModel" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it need to be public? This only serves a purpose when used with ML I think. |
||
|
||
/** | ||
* Trains a k-means model using the given set of parameters. | ||
|
@@ -589,6 +591,7 @@ object KMeans { | |
initMode match { | ||
case KMeans.RANDOM => true | ||
case KMeans.K_MEANS_PARALLEL => true | ||
case KMeans.K_MEANS_INITIAL_MODEL => true | ||
case _ => false | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please revert these