
[SPARK-10780][ML] Add an initial model to kmeans #11119


Closed
wants to merge 47 commits

Commits
cc13c1e
add initial model to kmeans
yinxusen Feb 8, 2016
36b1729
add two setters for initial model
yinxusen Feb 8, 2016
125ac76
revert to previous codegen and add a separate sharedParams for genera…
yinxusen Feb 10, 2016
658c4c9
add model check
yinxusen Feb 10, 2016
abfe0e2
add more testsuite
yinxusen Feb 10, 2016
b87e07e
add new save/load to kmeans
yinxusen Feb 10, 2016
9a4a55e
add new model save/load for KMeansModel
yinxusen Feb 10, 2016
65f4237
fix side effect
yinxusen Feb 11, 2016
166a6ff
add hashcode and equals
yinxusen Feb 11, 2016
f56e443
merge with master
yinxusen Mar 7, 2016
f3f9226
add }
yinxusen Mar 7, 2016
08afa4c
filter initialModel out
yinxusen Mar 9, 2016
31f7b94
new equal
yinxusen Mar 10, 2016
7c1c8f7
reinse test
yinxusen Mar 14, 2016
95627b5
refine KMeans
yinxusen Mar 15, 2016
a0629b5
merge with master
yinxusen Apr 15, 2016
b0eb111
merge with master
yinxusen Apr 22, 2016
58bf1cf
change back to DefaultParamsWritable/Readable
yinxusen Apr 22, 2016
b7856e1
add initialmodel metadata to default read write
yinxusen Apr 28, 2016
9f5e698
add save/load for initial model
yinxusen Apr 28, 2016
914d319
remove validateParams
yinxusen Apr 28, 2016
c40192b
remove useless DefaultFormats
yinxusen Apr 28, 2016
6526c08
merge with master
yinxusen Aug 29, 2016
23a78d6
fix vector issue
yinxusen Aug 31, 2016
d4f59d9
multi fixes
yinxusen Sep 2, 2016
47f182b
fix not set initialmodel
yinxusen Sep 2, 2016
78ed9a1
refine tests
yinxusen Sep 12, 2016
03575bf
remove some setters
yinxusen Sep 14, 2016
c21ffa2
change the implementation of initialModel
yinxusen Sep 14, 2016
eb7fbbe
remove hashcode and equal check
yinxusen Sep 16, 2016
f6e024a
fix errors
yinxusen Sep 16, 2016
92cf83d
fix errors
yinxusen Sep 27, 2016
95bf12f
add TODO with JIRA
yinxusen Oct 5, 2016
7fc6918
add infering K from initial model
yinxusen Oct 6, 2016
5fbb132
add more assert
yinxusen Oct 7, 2016
4bba7c1
fix logics of test
yinxusen Oct 14, 2016
127ca06
add new test of using different initial model
yinxusen Oct 14, 2016
0e93fda
get rid of metadata.hasInitialModel
yinxusen Oct 14, 2016
261fcfa
fix nits and errors
yinxusen Oct 17, 2016
b3ea01a
fix small errors
yinxusen Oct 17, 2016
47de1fe
fix load model excpetion
yinxusen Oct 18, 2016
2c9cd51
add more comments
yinxusen Oct 18, 2016
e529972
eliminate possible initialModels of the direct initialModel
yinxusen Oct 19, 2016
939ebe5
merge with master
yinxusen Nov 8, 2016
7046913
first update KMeans
yinxusen Nov 8, 2016
8516a2c
update test
yinxusen Nov 8, 2016
6f169eb
fix mima test
yinxusen Nov 8, 2016
128 changes: 120 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,7 +17,10 @@

package org.apache.spark.ml.clustering

import scala.util.{Failure, Success}
Contributor: please revert these


import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.InvalidInputException

import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
@@ -35,7 +38,25 @@ import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}

/**
* Common params for KMeans and KMeansModel
* Params for KMeans
*/

private[clustering] trait KMeansInitialModelParams extends HasInitialModel[KMeansModel] {
Contributor: If we follow the convention in ALS, then we should have KMeansModelParams and KMeansParams extends KMeansModelParams with .... I think it would be good to do the same here.

/**
* Param for KMeansModel to use for warm start.
* Whenever initialModel is set:
* 1. the initialModel k will override the param k;
* 2. the param initMode is set to initialModel and manually set is ignored;
Contributor:

  1. the param initMode is set to "initialModel" and manually setting initMode will be ignored

nit: Let's just remove the punctuation from the numbered list

* 3. other params are untouched.
* @group param
*/
final val initialModel: Param[KMeansModel] =
Contributor: override final val

new Param[KMeansModel](this, "initialModel", "A KMeansModel for warm start.")

}

/**
* Params for KMeansModel
*/
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
with HasSeed with HasPredictionCol with HasTol {
@@ -58,6 +79,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
* Param for the initialization algorithm. This can be either "random" to choose random points as
* initial cluster centers, or "k-means||" to use a parallel variant of k-means++
* (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
* The param initMode will be ignored if the param initialModel is set.
* @group expertParam
*/
@Since("1.5.0")
@@ -82,6 +104,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Since("1.5.0")
def getInitSteps: Int = $(initSteps)


/**
* Validates and transforms the input schema.
* @param schema input schema
@@ -103,7 +126,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Experimental
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
private val parentModel: MLlibKMeansModel)
private[ml] val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansParams with MLWritable {

@Since("1.5.0")
@@ -124,7 +147,8 @@ class KMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val predictUDF = udf((vector: Vector) => predict(vector))
val tmpParent: MLlibKMeansModel = parentModel
Contributor: maybe a comment would be useful? // avoid encapsulating the entire model in the closure
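The local-val pattern this hunk adopts (`val tmpParent = parentModel` before building the UDF) is the standard Spark trick for keeping the enclosing object out of a serialized closure. A self-contained sketch of the idea, using toy classes rather than the actual Spark types:

```scala
object ClosureCaptureSketch {
  // Stand-in for MLlibKMeansModel: small and cheap to serialize.
  final case class InnerModel(centers: Array[Array[Double]]) {
    def predict(v: Array[Double]): Int =
      centers.indices.minBy { i =>
        centers(i).zip(v).map { case (c, x) => (c - x) * (c - x) }.sum
      }
  }

  // Stand-in for the ml.KMeansModel wrapper.
  class WrapperModel(private val parentModel: InnerModel) {
    // If the lambda referenced `this.parentModel` directly, the closure
    // would capture `this` (the whole wrapper, params and all). Copying
    // the field to a local val first means only InnerModel is captured.
    def predictFn: Array[Double] => Int = {
      val tmpParent = parentModel
      (v: Array[Double]) => tmpParent.predict(v)
    }
  }

  def main(args: Array[String]): Unit = {
    val m = new WrapperModel(InnerModel(Array(Array(0.0, 0.0), Array(5.0, 5.0))))
    println(m.predictFn(Array(4.0, 4.5))) // nearest center is index 1
  }
}
```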

val predictUDF = udf((vector: Vector) => tmpParent.predict(vector))
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

@@ -133,8 +157,6 @@ class KMeansModel private[ml] (
validateAndTransformSchema(schema)
}

private[clustering] def predict(features: Vector): Int = parentModel.predict(features)

@Since("2.0.0")
def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)

@@ -210,6 +232,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)

// Save model data: cluster centers
val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
Data(idx, center)
@@ -244,6 +267,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
}
val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
DefaultParamsReader.getAndSetParams(model, metadata)

model
}
}
Expand All @@ -259,7 +283,8 @@ object KMeansModel extends MLReadable[KMeansModel] {
@Experimental
class KMeans @Since("1.5.0") (
@Since("1.5.0") override val uid: String)
extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {
extends Estimator[KMeansModel]
with KMeansParams with KMeansInitialModelParams with DefaultParamsWritable {

setDefault(
k -> 2,
@@ -284,11 +309,26 @@ class KMeans @Since("1.5.0") (

/** @group setParam */
@Since("1.5.0")
def setK(value: Int): this.type = set(k, value)
def setK(value: Int): this.type = {
if (isSet(initialModel)) {
logWarning("initialModel is set, so k will be ignored. Clear initialModel first.")
this
} else {
set(k, value)
}
}

/** @group expertSetParam */
@Since("1.5.0")
def setInitMode(value: String): this.type = set(initMode, value)
def setInitMode(value: String): this.type = {
if (isSet(initialModel)) {
logWarning(s"initialModel is set, so initMode will be ignored. Clear initialModel first.")
Contributor: We say it will be ignored, but then still set it below.

}
if (value == MLlibKMeans.K_MEANS_INITIAL_MODEL) {
logWarning(s"initMode of $value is not supported here, please use setInitialModel.")
Contributor: From the discussion, I think we decided to throw an error for setInitMode("initialModel") if initialModel wasn't already set. If initialModel has been set, then we'd just update the initMode as normal.

}
set(initMode, value)
}

/** @group expertSetParam */
@Since("1.5.0")
@@ -306,6 +346,25 @@
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)

/** @group setParam */
Contributor: Above, we may want to also log a warning in setK if initialModel has already been set?

Contributor: Yeah, so I think Nick and I were under the impression that we would simply ignore any calls to setK if initialModel is set. So,

  def setK(value: Int): this.type = {
    if (isSet(initialModel)) {
      logWarning("initialModel is set, so k will be ignored. Clear initialModel first.")
      this
    } else {
      set(k, value)
    }
  }

@MLnick @dbtsai does that seem correct?

Contributor: Agree with @sethah on this - initial model should take precedence - essentially ignoring any setK call whether before or after setting initial model.

Member: My personal preference is throwing an exception to make it clear for users; but I don't have a strong opinion about this.

Contributor: I don't feel really strongly about it. But I think it is not a critical failure, so it does not really require the hard stop of an exception. Since we make it clear in the doc that the initial model's k takes precedence, if a user sets a different k anyway we can still proceed, and just let them know that we are ignoring their setting.

@Since("2.1.0")
def setInitialModel(value: KMeansModel): this.type = {
val kOfInitialModel = value.parentModel.clusterCenters.length
if (isSet(k)) {
if ($(k) != kOfInitialModel) {
val previousK = $(k)
set(k, kOfInitialModel)
logWarning(s"Param K is set to $kOfInitialModel by the initialModel." +
Contributor: The previous value is lost because you set it before logging this warning. Need to reorder the code.

Contributor: nit: Maybe s"Param k was changed from $previousK to $kOfInitialModel to match the initialModel"

s" Previous value is $previousK.")
}
} else {
set(k, kOfInitialModel)
logWarning(s"Param K is set to $kOfInitialModel by the initialModel.")
}
set(initMode, "initialModel")
set(initialModel, value)
}

@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = {
transformSchema(dataset.schema, logging = true)
@@ -323,6 +382,24 @@
.setMaxIterations($(maxIter))
.setSeed($(seed))
.setEpsilon($(tol))

if (isDefined(initialModel)) {
// Check that the feature dimensions are equal
val dimOfData = rdd.first().size
val dimOfInitialModel = $(initialModel).clusterCenters.head.size
require(dimOfData == dimOfInitialModel,
s"mismatched dimension, $dimOfData in data while $dimOfInitialModel in the initial model.")

// Check that the number of clusters are equal
val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length
if (kOfInitialModel != $(k)) {
logWarning(s"mismatched cluster count, ${$(k)} cluster centers required but" +
s" $kOfInitialModel found in the initial model.")
}

algo.setInitialModel($(initialModel).parentModel)
Member: does MLlibKMeans check if the dimensionalities match? Also, why do you set it to parentModel?

Contributor (author):

  1. MLlibKMeans doesn't check dimension match. If we want to check the dim, there are two ways: either use rdd.first().size to get the dim, or extract the dim from the attributes of the DataFrame. The latter is better, but it needs support from Attribute.
  2. The parentModel of a KMeansModel is an MLlibKMeansModel, which holds the cluster centers for KMeans. And algo is an MLlibKMeans. You can refer to the definition of KMeansModel in the ML package. In fact, the name parentModel is misleading.
}

val parentModel = algo.run(rdd, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
@@ -336,13 +413,48 @@
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

@Since("2.1.0")
override def write: MLWriter = new KMeans.KMeansWriter(this)
}

@Since("1.6.0")
object KMeans extends DefaultParamsReadable[KMeans] {

// TODO: [SPARK-17784]: Add a fromCenters method

@Since("1.6.0")
override def load(path: String): KMeans = super.load(path)

@Since("1.6.0")
override def read: MLReader[KMeans] = new KMeansReader
Member: This has the same signature as before, right? If so, then the Since version can be 1.6.0.


/** [[MLWriter]] instance for [[KMeans]] */
private[KMeans] class KMeansWriter(instance: KMeans) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveInitialModel(instance, path)
DefaultParamsWriter.saveMetadata(instance, path, sc)
}
}

private class KMeansReader extends MLReader[KMeans] {

/** Checked against metadata when loading estimator */
private val className = classOf[KMeans].getName

override def load(path: String): KMeans = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val instance = new KMeans(metadata.uid)

DefaultParamsReader.getAndSetParams(instance, metadata)
DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match {
case Success(v) => instance.setInitialModel(v)
case Failure(_: InvalidInputException) => // initialModel doesn't exist, do nothing
case Failure(e) => throw e
}
instance
}
}
}

/**
@@ -0,0 +1,29 @@
Member: I think the filename should be GenericTypeParams, right? It would be nice to integrate this into the other part of Params. Maybe we can use a Scala macro?

Contributor (author): Fixed others. Will handle this soon.

Contributor (author): @dbtsai Sorry for the long delay. I'm not so sure what you mean by "integrate into the other part of Params" - what's it for?

/*

* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.param.shared

import org.apache.spark.ml.Model
import org.apache.spark.ml.param._

private[ml] trait HasInitialModel[T <: Model[T]] extends Params {

def initialModel: Param[T]

/** @group getParam */
final def getInitialModel: T = $(initialModel)
}
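Taken together, `HasInitialModel` and `setInitialModel` enable a warm-start workflow. A hedged sketch of how the API proposed in this PR would be used (the `data` DataFrame and its column layout are assumptions, and this PR's API never shipped in Spark as-is):

```scala
import org.apache.spark.ml.clustering.KMeans

// Fit a coarse model cheaply first.
val coarse = new KMeans()
  .setK(4)
  .setMaxIter(5)
  .fit(data)

// Refine from its centers. Per the param doc above, setInitialModel
// copies k from the model (4 here), forces initMode to "initialModel",
// and leaves all other params untouched.
val refined = new KMeans()
  .setMaxIter(50)
  .setInitialModel(coarse)
  .fit(data)
```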
26 changes: 25 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.ml.util

import java.io.IOException

import scala.util.Try

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.{DefaultFormats, JObject}
@@ -32,6 +34,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.param.shared.HasInitialModel
import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.util.Utils
@@ -300,7 +303,8 @@ private[ml] object DefaultParamsWriter {
paramMap: Option[JValue] = None): String = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
val params = instance.extractParamMap().toSeq
.filter(_.param.name != "initialModel").asInstanceOf[Seq[ParamPair[Any]]]
Contributor (@MLnick, Oct 11, 2016): Is it not possible to check if the param is an instance of Param[Model[_]] rather than the name?

Contributor (author): I tried _.isInstanceOf[Param[Model[_]]] and it failed because of the erasure of Model[_] at Java runtime.

Contributor: It's also possible that we introduce other params that can be of type model in the future. Thus, being an instance of Param[Model[_]] may not be sufficient to determine if it's the initialModel.

Contributor: Fair enough
val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList))
Expand All @@ -309,6 +313,7 @@ private[ml] object DefaultParamsWriter {
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("paramMap" -> jsonParams)

val metadata = extraMetadata match {
case Some(jObject) =>
basicMetadata ~ jObject
@@ -318,6 +323,20 @@
val metadataJson: String = compact(render(metadata))
metadataJson
}

def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]](
instance: T, path: String): Unit = {
if (instance.isDefined(instance.initialModel)) {
val initialModelPath = new Path(path, "initialModel").toString
val initialModel = instance.getOrDefault(instance.initialModel)
// When saving, only keep the direct initialModel by eliminating possible initialModels of the
Contributor (author): @MLnick That's what I mean: when saving, we eliminate possible initialModels of the direct initialModel.

// direct initialModel, to avoid unnecessary deep recursion of initialModel.
if (initialModel.hasParam("initialModel")) {
initialModel.clear(initialModel.getParam("initialModel"))
}
initialModel.save(initialModelPath)
}
}
}

/**
@@ -446,6 +465,11 @@
val cls = Utils.classForName(metadata.className)
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
}

def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Try[M] = {
val initialModelPath = new Path(path, "initialModel").toString
Try(loadParamsInstance[M](initialModelPath, sc))
Contributor: Note, now we always try to load the initial model, which I would guess adds some overhead in the cases when there is none, but we have to fail and catch the exception. Let's leave this for now since it's cleaner, and only change it if it becomes a problem.

Contributor: I don't think the overhead is that large, and if there is no model then it should fail fast as the directory won't exist.

Member: Also, this is not a critical path, so I think the overhead is okay.

}
}

/**
@@ -414,6 +414,8 @@ object KMeans {
val RANDOM = "random"
@Since("0.8.0")
val K_MEANS_PARALLEL = "k-means||"
@Since("2.1.0")
val K_MEANS_INITIAL_MODEL = "initialModel"
Contributor: Does it need to be public? This only serves a purpose when used with ML, I think.


/**
* Trains a k-means model using the given set of parameters.
@@ -589,6 +591,7 @@
initMode match {
case KMeans.RANDOM => true
case KMeans.K_MEANS_PARALLEL => true
case KMeans.K_MEANS_INITIAL_MODEL => true
case _ => false
}
}