
[SPARK-20082][ml] LDA incremental model learning #17461


Closed
3 changes: 3 additions & 0 deletions docs/mllib-clustering.md
@@ -243,6 +243,9 @@ configuration), this parameter specifies the frequency with which
checkpoints will be created. If `maxIterations` is large, using
checkpointing can help reduce shuffle file sizes on disk and help with
failure recovery.
* `initialModel`: this parameter, only supported by `OnlineLDAOptimizer`,
specifies a previously trained `LocalLDAModel` as a starting point instead of
a random initialization. This can be used for incremental learning.
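As a rough illustration (not part of this patch), the RDD-based `spark.mllib` API could be warm-started as follows. Here `chunk1`, `chunk2`, `k = 10` and the helper name are assumptions; both runs must use the same k, vocabulary size and topic concentration.

import org.apache.spark.mllib.clustering.{LDA, LocalLDAModel, OnlineLDAOptimizer}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// `chunk1` and `chunk2` are corpora over the same vocabulary: RDD[(docId, termCounts)].
def incrementalTrain(chunk1: RDD[(Long, Vector)], chunk2: RDD[(Long, Vector)]): LocalLDAModel = {
  val first = new LDA()
    .setK(10)
    .setOptimizer(new OnlineLDAOptimizer())
    .run(chunk1)
    .asInstanceOf[LocalLDAModel] // the online optimizer always produces a LocalLDAModel

  // Warm-start the second run from the first model instead of a random initialization.
  new LDA()
    .setK(10)
    .setOptimizer(new OnlineLDAOptimizer())
    .setInitialModel(first)
    .run(chunk2)
    .asInstanceOf[LocalLDAModel]
}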


All of `spark.mllib`'s LDA models support:
examples/src/main/scala/org/apache/spark/examples/ml/OnlineLDAIncrementalExample.scala
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml

// scalastyle:off println
// $example on$
import java.io.File

import org.apache.spark.ml.clustering.{LDA, LDAModel, LocalLDAModel}
import org.apache.spark.sql.DataFrame
// $example off$
import org.apache.spark.sql.SparkSession

/**
* An example demonstrating incremental update of LDA, with the setInitialModel parameter.
* Run with
* {{{
* bin/run-example ml.OnlineLDAIncrementalExample
* }}}
*/
object OnlineLDAIncrementalExample {

def main(args: Array[String]): Unit = {

val spark = SparkSession
.builder()
.appName(s"${this.getClass.getSimpleName}")
.getOrCreate()

import spark.implicits._

// $example on$
// Loads data.
val dataset: DataFrame = spark.read.format("libsvm")
.load("data/mllib/sample_lda_libsvm_data.txt")

// ---------------------------------------
// Build an LDA model incrementally
// - here we're simulating data coming incrementally, in chunks
// - this assumes the vocabulary is fixed (same words as columns of the matrix);
//   see the vectorization sketch after this example

val nbChunks = 3
val chunks = dataset.randomSplit(Array.fill(nbChunks)(1D / nbChunks), 7L)

// LDA model params
val k = 10
val iter = 30

// To pass a trained LDA from one iteration to the other, we persist it to a file
val modelPath = File.createTempFile("./incrModel", null).getPath
var previousModelPath: String = null

var idx = 0

for (chunk <- chunks) {
idx += 1
println(s"Incremental, chunk=$idx, k=$k, maxIterations=$iter")

// Build LDA model as usual
val lda = new LDA()
.setK(k)
.setMaxIter(iter)
.setOptimizer("online")

// and point to the previous model, when there's one
if (previousModelPath != null) {
lda.setInitialModel(previousModelPath)
}

val model = lda.fit(chunk)

// Check perplexity on the full dataset after each chunk
val lp = model.logPerplexity(dataset)
println(s"Log Perplexity=$lp")
println("---------------------------------")

// persist for next chunk
previousModelPath = s"$modelPath-$idx"
model.save(previousModelPath)
}

val finalModel = LocalLDAModel.load(previousModelPath)

// Describe topics.
val topics = finalModel.describeTopics(3)
println("The topics described by their top-weighted terms:")
topics.show(false)

// Shows the result.
val transformed = finalModel.transform(dataset)
transformed.show(false)
// $example off$

spark.stop()
}
}

// scalastyle:on println
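The example above assumes every chunk already shares one fixed vocabulary (it reuses the same libsvm file). When chunks arrive as raw text, one way to keep the term columns aligned across chunks is to fit a single CountVectorizerModel once and reuse it for every chunk. A hedged sketch, not part of this patch; the column names and vocabulary size are assumptions:

import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}
import org.apache.spark.sql.DataFrame

// Fit the vocabulary once, on the first chunk (expects a "words" column of Seq[String]).
def fitVocabulary(firstChunk: DataFrame): CountVectorizerModel =
  new CountVectorizer()
    .setInputCol("words")
    .setOutputCol("features")
    .setVocabSize(10000)
    .fit(firstChunk)

// Reuse the same fitted model for every later chunk so term indices stay consistent,
// which is what incremental training with setInitialModel requires.
def vectorize(vectorizer: CountVectorizerModel, chunk: DataFrame): DataFrame =
  vectorizer.transform(chunk)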
46 changes: 41 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -19,6 +19,8 @@ package org.apache.spark.ml.clustering

import java.util.Locale

import scala.util.{Failure, Success, Try}

import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.JsonAST.JObject
@@ -32,10 +34,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
OnlineLDAOptimizer => OldOnlineLDAOptimizer}
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer}
Contributor: Better follow the original format

import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.MatrixImplicits._
import org.apache.spark.mllib.linalg.VectorImplicits._
@@ -180,6 +179,29 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
@Since("1.6.0")
def getOptimizer: String = $(optimizer)


/**
* For the online optimizer only (currently): [[optimizer]] = "online".
*
* An initial model to be used as a starting point for learning, instead of a random
* initialization. Provide the path to a serialized trained LDAModel.
Contributor: LDAModel => LocalLDAModel

Author: Hi @hhbyyh
After a quick look, it seems ok to follow this. I'll try to merge locally and see how it fits. I'll keep you updated. Do you think #18610 will be merged shortly? (cc @yanboliang)

*
* @group param
*/
@Since("2.3.0")
final val initialModel = new Param[String](this, "initialModel", "Path to a serialized " +
"LDAModel to use as a starting point, instead of a random initialization. Only " +
"supported by the online optimizer.", (value: String) => validateInitialModel(value))

/** @group getParam */
@Since("2.3.0")
def getInitialModel: String = $(initialModel)
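// Note: the initialModel param validator (below) tries to load the model, so an invalid
// path is rejected as soon as setInitialModel is called rather than at fit() time.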

protected def validateInitialModel(value: String): Boolean = {
Try(LocalLDAModel.load(value)).isSuccess
}


/**
* Output column with estimates of the topic mixture distribution for each document (often called
* "theta" in the literature). Returns a vector of zeros for an empty document.
@@ -345,6 +367,10 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
s" must be >= 1. Found value: $getTopicConcentration")
}
}
if (isSet(initialModel)) {
require(getOptimizer.toLowerCase(Locale.ROOT) == "online", "initialModel is currently " +
"supported only by Online LDA Optimizer")
}
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}
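A small sketch (not part of this patch) of the behaviour this check enforces: combining initialModel with the EM optimizer is rejected when the schema is validated at fit time. The path and dataset are hypothetical; the path is assumed to hold a previously saved LocalLDAModel, since setInitialModel itself loads it for validation.

import org.apache.spark.ml.clustering.LDA

val lda = new LDA()
  .setK(10)
  .setOptimizer("em")
  .setInitialModel("/tmp/lda/previous-model") // hypothetical path to a saved LocalLDAModel

// lda.fit(dataset) would then fail with:
//   java.lang.IllegalArgumentException: requirement failed:
//   initialModel is currently supported only by Online LDA Optimizer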
@@ -863,6 +889,10 @@ class LDA @Since("1.6.0") (
@Since("1.6.0")
def setOptimizer(value: String): this.type = set(optimizer, value)

/** @group setParam */
@Since("2.3.0")
def setInitialModel(value: String): this.type = set(initialModel, value)

/** @group setParam */
@Since("1.6.0")
def setTopicDistributionCol(value: String): this.type = set(topicDistributionCol, value)
@@ -897,7 +927,7 @@ class LDA @Since("1.6.0") (
val instr = Instrumentation.create(this, dataset)
instr.logParams(featuresCol, topicDistributionCol, k, maxIter, subsamplingRate,
checkpointInterval, keepLastCheckpoint, optimizeDocConcentration, topicConcentration,
learningDecay, optimizer, learningOffset, seed, initialModel)

val oldLDA = new OldLDA()
.setK($(k))
@@ -907,6 +937,12 @@
.setSeed($(seed))
.setCheckpointInterval($(checkpointInterval))
.setOptimizer(getOldOptimizer)
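// When initialModel is set, the old mllib LDA is seeded below with the previously trained
// model, so the online optimizer starts from its topics matrix instead of a random init.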

if (isSet(initialModel)) {
val init = LocalLDAModel.load($(initialModel))
oldLDA.setInitialModel(init.oldLocalModel)
}

// TODO: persist here, or in old LDA?
val oldData = LDA.getOldDataset(dataset, $(featuresCol))
val oldModel = oldLDA.run(oldData)
mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala
@@ -43,4 +43,6 @@ private[python] class LDAModelWrapper(model: LDAModel) {
}

def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
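// Expose the wrapped mllib model so PythonMLLibAPI can reuse it,
// e.g. as the initial model passed in from the Python API.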

def getModel: LDAModel = model
}
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -531,7 +531,8 @@ private[python] class PythonMLLibAPI extends Serializable {
topicConcentration: Double,
seed: java.lang.Long,
checkpointInterval: Int,
optimizer: String,
initialModel: LDAModelWrapper): LDAModelWrapper = {
val algo = new LDA()
.setK(k)
.setMaxIterations(maxIterations)
@@ -542,6 +543,13 @@

if (seed != null) algo.setSeed(seed)

if (initialModel != null) {
if (optimizer != "online") {
throw new IllegalArgumentException("initialModel is only supported by online optimizer.")
}
algo.setInitialModel(initialModel.getModel)
}

val documents = data.rdd.map(_.asScala.toArray).map { r =>
r(0) match {
case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector])
24 changes: 22 additions & 2 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -52,15 +52,16 @@ class LDA private (
private var topicConcentration: Double,
private var seed: Long,
private var checkpointInterval: Int,
private var ldaOptimizer: LDAOptimizer,
private var initialModel: Option[LDAModel]) extends Logging {

/**
* Constructs a LDA instance with default parameters.
*/
@Since("1.3.0")
def this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1),
topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10,
ldaOptimizer = new EMLDAOptimizer, initialModel = None)

/**
* Number of topics to infer, i.e., the number of soft cluster centers.
@@ -317,6 +318,25 @@ class LDA private (
this
}


/**
* Returns the initial model that has been provided, if any.
*/
@Since("2.3.0")
def getInitialModel: Option[LDAModel] = this.initialModel

/**
* Set the initial starting point, bypassing the random initialization.
* This can be used for incremental learning.
* This is supported only by the online optimizer. The provided model must have the same
* parameters (k, vocabulary size, topic concentration) as the LDA being trained.
*/
@Since("2.3.0")
def setInitialModel(model: LDAModel): this.type = {
initialModel = Some(model)
this
}

/**
* Learn an LDA model using the given dataset.
*
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -50,7 +50,7 @@ trait LDAOptimizer {

/**
* Initializer for the optimizer. LDA passes the common parameters to the optimizer and
* the internal structure can be initialized properly.
*/
private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer

@@ -126,6 +126,9 @@ final class EMLDAOptimizer extends LDAOptimizer {
val topicConcentration = lda.getTopicConcentration
val k = lda.getK

require(lda.getInitialModel.isEmpty,
"Only online optimizer supports initialization with a previous model.")

// Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
// but values in (0,1) are not yet supported.
require(docConcentration > 1.0 || docConcentration == -1.0, s"LDA docConcentration must be" +
@@ -246,7 +249,6 @@ final class EMLDAOptimizer extends LDAOptimizer {
}
}


/**
* :: DeveloperApi ::
*
@@ -416,30 +418,54 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
this.k = lda.getK
this.corpusSize = docs.count()
this.vocabSize = docs.first()._2.size

this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
this.randomGenerator = new Random(lda.getSeed)

lda.getInitialModel match {
case Some(initModel: LocalLDAModel) =>
require(initModel.k == this.k,
"Mismatched number of topics with provided initial model")
require(initModel.vocabSize == docs.first()._2.size,
"Mismatched vocabulary size with provided initial model")
require(initModel.topicConcentration == this.eta,
"Mismatched topic concentration with provided initial model")
require(initModel.gammaShape == this.gammaShape,
"Mismatched gamma shape with provided initial model")
// get alpha from previously trained model
this.alpha = initModel.docConcentration
// Initialize the variational distribution from the initial model
this.lambda = initModel.topicsMatrix.transpose.asBreeze.toDenseMatrix
case None =>
this.alpha = initAlpha(lda.getAsymmetricDocConcentration)
// Initialize the variational distribution q(beta|lambda) randomly
this.lambda = getGammaMatrix(k, vocabSize)
case Some(other) =>
throw new IllegalArgumentException(
s"Local (online) LDA model expected for initial model, got $other")
}

this.docs = docs
this.iteration = 0
this
}

private def initAlpha(asymmetricDocConcentration: Vector) = {
if (asymmetricDocConcentration.size == 1) {
if (asymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
else {
require(asymmetricDocConcentration(0) >= 0,
s"all entries in alpha must be >=0, got: $alpha")
Vectors.dense(Array.fill(k)(asymmetricDocConcentration(0)))
}
} else {
require(asymmetricDocConcentration.size == k,
s"alpha must have length k, got: $alpha")
asymmetricDocConcentration.foreachActive { case (_, x) =>
require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha")
}
asymmetricDocConcentration
}
}

override private[clustering] def next(): OnlineLDAOptimizer = {
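For reference, a short sketch of the matrix layout the seeding above relies on (the sizes are illustrative): topicsMatrix in a LocalLDAModel is vocabSize x k with topics as columns, while lambda in OnlineLDAOptimizer is k x vocabSize with topics as rows, which is why the initial model's topics matrix is transposed when it is used to initialize lambda.

import breeze.linalg.DenseMatrix

val vocabSize = 5 // illustrative
val k = 3         // illustrative

// topicsMatrix: vocabSize x k (topics as columns)
val topicsMatrix = DenseMatrix.zeros[Double](vocabSize, k)

// lambda: k x vocabSize (topics as rows), obtained by transposing topicsMatrix
val lambda = topicsMatrix.t.toDenseMatrix
assert(lambda.rows == k && lambda.cols == vocabSize)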