-
Notifications
You must be signed in to change notification settings - Fork 28.6k
[SPARK-20082][ml] LDA incremental model learning #17461
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d66929f
847e69c
73919ba
1c850ac
b36f7e9
73f6a75
6e6782c
f5b5d38
b5e50eb
5562259
43c63a7
8df8da0
d3a4f16
6fe3a20
31cd11b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.examples.ml | ||
|
||
// scalastyle:off println | ||
// $example on$ | ||
import java.io.File | ||
|
||
import org.apache.spark.ml.clustering.{LDA, LDAModel, LocalLDAModel} | ||
import org.apache.spark.sql.DataFrame | ||
// $example off$ | ||
import org.apache.spark.sql.SparkSession | ||
|
||
/** | ||
* An example demonstrating incremental update of LDA, with setInitialModel parameter. | ||
* Run with | ||
* {{{ | ||
* bin/run-example ml.OnlineLDAIncrementalExample | ||
* }}} | ||
*/ | ||
object OnlineLDAIncrementalExample {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession
      .builder()
      .appName(s"${this.getClass.getSimpleName}")
      .getOrCreate()

    import spark.implicits._

    // $example on$
    // Loads data.
    val dataset: DataFrame = spark.read.format("libsvm")
      .load("data/mllib/sample_lda_libsvm_data.txt")

    // ---------------------------------------
    // Build an LDA model incrementally
    // - here we're simulating data arriving incrementally, in chunks
    // - this assumes the vocabulary is fixed (same words as columns of the matrix)

    val nbChunks = 3
    // Fixed seed so the split (and hence the example output) is reproducible.
    val chunks = dataset.randomSplit(Array.fill(nbChunks)(1D / nbChunks), 7L)

    // LDA model params
    val k = 10
    val iter = 30

    // To pass a trained LDA model from one chunk to the next, we persist it to disk.
    // The temp file itself is only used to derive unique sibling paths ("<path>-<idx>"),
    // so delete it on exit. Note: createTempFile prefixes must not contain path
    // separators (the original "./incrModel" relied on unspecified JDK behavior).
    val modelBaseFile = File.createTempFile("incrModel", null)
    modelBaseFile.deleteOnExit()
    val modelPath = modelBaseFile.getPath
    var previousModelPath: String = null

    var idx = 0

    for (chunk <- chunks) {
      idx += 1
      println(s"Incremental, chunk=$idx, k=$k, maxIterations=$iter")

      // Build LDA model as usual. Only the online optimizer supports initialModel.
      val lda = new LDA()
        .setK(k)
        .setMaxIter(iter)
        .setOptimizer("online")

      // and point to the previous model, when there's one
      if (previousModelPath != null) {
        lda.setInitialModel(previousModelPath)
      }

      // Train on the current chunk only — NOT the whole dataset. Fitting on
      // `dataset` here would retrain from all data every pass and defeat the
      // incremental update this example demonstrates.
      val model = lda.fit(chunk)

      // Check perplexity after each chunk, evaluated against the full dataset
      // so the numbers are comparable across iterations.
      val lp = model.logPerplexity(dataset)
      println(s"Log Perplexity=$lp")
      println("---------------------------------")

      // persist for next chunk
      previousModelPath = s"$modelPath-$idx"
      model.save(previousModelPath)
    }

    // Reload the last persisted model; `previousModelPath` is non-null because
    // nbChunks >= 1 guarantees at least one loop iteration.
    val finalModel = LocalLDAModel.load(previousModelPath)

    // Describe topics.
    val topics = finalModel.describeTopics(3)
    println("The topics described by their top-weighted terms:")
    topics.show(false)

    // Shows the result.
    val transformed = finalModel.transform(dataset)
    transformed.show(false)
    // $example off$

    spark.stop()
  }
}
|
||
// scalastyle:on println |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,8 @@ package org.apache.spark.ml.clustering | |
|
||
import java.util.Locale | ||
|
||
import scala.util.{Failure, Success, Try} | ||
|
||
import org.apache.hadoop.fs.Path | ||
import org.json4s.DefaultFormats | ||
import org.json4s.JsonAST.JObject | ||
|
@@ -32,10 +34,7 @@ import org.apache.spark.ml.param._ | |
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} | ||
import org.apache.spark.ml.util._ | ||
import org.apache.spark.ml.util.DefaultParamsReader.Metadata | ||
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, | ||
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, | ||
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, | ||
OnlineLDAOptimizer => OldOnlineLDAOptimizer} | ||
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} | ||
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} | ||
import org.apache.spark.mllib.linalg.MatrixImplicits._ | ||
import org.apache.spark.mllib.linalg.VectorImplicits._ | ||
|
@@ -180,6 +179,29 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM | |
@Since("1.6.0") | ||
def getOptimizer: String = $(optimizer) | ||
|
||
|
||
/** | ||
* For Online optimizer only (currently): [[optimizer]] = "online". | ||
|
||
* An initial model to be used as a starting point for the learning, instead of a random | ||
* initialization. Provide the path to a serialized trained LDAModel. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LDAModel => LocalLDAModel There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @hhbyyh |
||
* | ||
* @group param | ||
*/ | ||
@Since("2.3.0") | ||
final val initialModel = new Param[String](this, "initialModel", "Path to a serialized " + | ||
"LDAModel to use as a starting point, instead of a random initilization. Only" + | ||
"supported by online model.", (value: String) => validateInitialModel(value)) | ||
|
||
/** @group getParam */ | ||
@Since("2.3.0") | ||
def getInitialModel : String = $(initialModel) | ||
|
||
protected def validateInitialModel(value: String): Boolean = { | ||
Try(LocalLDAModel.load(value)).isSuccess | ||
} | ||
|
||
|
||
/** | ||
* Output column with estimates of the topic mixture distribution for each document (often called | ||
* "theta" in the literature). Returns a vector of zeros for an empty document. | ||
|
@@ -345,6 +367,10 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM | |
s" must be >= 1. Found value: $getTopicConcentration") | ||
} | ||
} | ||
if (isSet(initialModel)) { | ||
require(getOptimizer.toLowerCase(Locale.ROOT) == "online", "initialModel is currently " + | ||
"supported only by Online LDA Optimizer") | ||
} | ||
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) | ||
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) | ||
} | ||
|
@@ -863,6 +889,10 @@ class LDA @Since("1.6.0") ( | |
@Since("1.6.0") | ||
def setOptimizer(value: String): this.type = set(optimizer, value) | ||
|
||
/** @group setParam */ | ||
@Since("2.3.0") | ||
def setInitialModel(value: String): this.type = set(initialModel, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.6.0") | ||
def setTopicDistributionCol(value: String): this.type = set(topicDistributionCol, value) | ||
|
@@ -897,7 +927,7 @@ class LDA @Since("1.6.0") ( | |
val instr = Instrumentation.create(this, dataset) | ||
instr.logParams(featuresCol, topicDistributionCol, k, maxIter, subsamplingRate, | ||
checkpointInterval, keepLastCheckpoint, optimizeDocConcentration, topicConcentration, | ||
learningDecay, optimizer, learningOffset, seed) | ||
learningDecay, optimizer, learningOffset, seed, initialModel) | ||
|
||
val oldLDA = new OldLDA() | ||
.setK($(k)) | ||
|
@@ -907,6 +937,12 @@ class LDA @Since("1.6.0") ( | |
.setSeed($(seed)) | ||
.setCheckpointInterval($(checkpointInterval)) | ||
.setOptimizer(getOldOptimizer) | ||
|
||
if (isSet(initialModel)) { | ||
val init = LocalLDAModel.load($(initialModel)) | ||
oldLDA.setInitialModel(init.oldLocalModel) | ||
} | ||
|
||
// TODO: persist here, or in old LDA? | ||
val oldData = LDA.getOldDataset(dataset, $(featuresCol)) | ||
val oldModel = oldLDA.run(oldData) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better follow the original format