Commit 581c623

separate API and adjust batch split
1 parent 37af91a commit 581c623

File tree: 5 files changed, +65 -69 lines changed

examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java

Lines changed: 1 addition & 1 deletion

```diff
@@ -58,7 +58,7 @@ public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
     corpus.cache();

     // Cluster the documents into three topics using LDA
-    DistributedLDAModel ldaModel = (DistributedLDAModel) new LDA().setK(3).run(corpus);
+    DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);

     // Output topics. Each is a distribution over words (matching word count vectors)
     System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
```

examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -159,7 +159,7 @@ object LDAExample {
       }
       println()
     }
-
+    sc.stop()
   }

   /**
```

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 61 additions & 65 deletions
```diff
@@ -32,7 +32,6 @@ import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
-import org.apache.spark.mllib.rdd.RDDFunctions._


 /**
@@ -223,10 +222,6 @@ class LDA private (
     this
   }

-  object LDAMode extends Enumeration {
-    val EM, Online = Value
-  }
-
   /**
    * Learn an LDA model using the given dataset.
    *
@@ -236,37 +231,30 @@ class LDA private (
    * Document IDs must be unique and >= 0.
    * @return Inferred LDA model
    */
-  def run(documents: RDD[(Long, Vector)], mode: LDAMode.Value = LDAMode.EM ): LDAModel = {
-    mode match {
-      case LDAMode.EM =>
-        val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-          checkpointInterval)
-        var iter = 0
-        val iterationTimes = Array.fill[Double](maxIterations)(0)
-        while (iter < maxIterations) {
-          val start = System.nanoTime()
-          state.next()
-          val elapsedSeconds = (System.nanoTime() - start) / 1e9
-          iterationTimes(iter) = elapsedSeconds
-          iter += 1
-        }
-        state.graphCheckpointer.deleteAllCheckpoints()
-        new DistributedLDAModel(state, iterationTimes)
-      case LDAMode.Online =>
-        val vocabSize = documents.first._2.size
-        val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, vocabSize)
-        var iter = 0
-        while (iter < onlineLDA.batchNumber) {
-          onlineLDA.next()
-          iter += 1
-        }
-        new LocalLDAModel(Matrices.fromBreeze(onlineLDA._lambda).transpose)
-      case _ => throw new IllegalArgumentException(s"Do not support mode $mode.")
+  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
+    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
+      checkpointInterval)
+    var iter = 0
+    val iterationTimes = Array.fill[Double](maxIterations)(0)
+    while (iter < maxIterations) {
+      val start = System.nanoTime()
+      state.next()
+      val elapsedSeconds = (System.nanoTime() - start) / 1e9
+      iterationTimes(iter) = elapsedSeconds
+      iter += 1
     }
+    state.graphCheckpointer.deleteAllCheckpoints()
+    new DistributedLDAModel(state, iterationTimes)
+  }
+
+  def runOnlineLDA(documents: RDD[(Long, Vector)]): LDAModel = {
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k)
+    (0 until onlineLDA.batchNumber).map(_ => onlineLDA.next())
+    new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
   }

   /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
     run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
   }
 }
```
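With the mode flag gone, the two training paths are separate entry points: `run` always returns a `DistributedLDAModel` (no cast needed), and the new `runOnlineLDA` returns an `LDAModel` backed by a `LocalLDAModel`. A minimal caller-side sketch of the split API; only `run` and `runOnlineLDA` come from this commit, while the toy corpus and context setup are illustrative assumptions:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA, LDAModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

object LDAApiSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "LDAApiSketch")
    // Toy corpus of (docId, termCountVector) pairs; ids must be unique and >= 0.
    val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 2.0, 0.0, 3.0)),
      (1L, Vectors.dense(0.0, 1.0, 4.0, 0.0)),
      (2L, Vectors.dense(2.0, 0.0, 1.0, 1.0))
    ))
    val lda = new LDA().setK(3).setMaxIterations(5)

    // EM path: run() now returns DistributedLDAModel directly.
    val emModel: DistributedLDAModel = lda.run(corpus)

    // Online path: a separate method instead of a mode argument.
    val onlineModel: LDAModel = lda.runOnlineLDA(corpus)

    println(s"EM model vocabSize = ${emModel.vocabSize}")
    sc.stop()
  }
}
```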
```diff
@@ -418,58 +406,66 @@ private[clustering] object LDA {

   }

-  // todo: add reference to paper and Hoffman
+  /**
+   * Optimizer for Online LDA algorithm which breaks corpus into mini-batches and scans only once.
+   * Hoffman, Blei and Bach, “Online Learning for Latent Dirichlet Allocation.” NIPS, 2010.
+   */
   private[clustering] class OnlineLDAOptimizer(
-      val documents: RDD[(Long, Vector)],
-      val k: Int,
-      val vocabSize: Int) extends Serializable{
+      private val documents: RDD[(Long, Vector)],
+      private val k: Int) extends Serializable{

-    private val kappa = 0.5 // (0.5, 1] how quickly old information is forgotten
-    private val tau0 = 1024 // down weights early iterations
-    private val D = documents.count()
+    private val vocabSize = documents.first._2.size
+    private val D = documents.count().toInt
     private val batchSize = if (D / 1000 > 4096) 4096
       else if (D / 1000 < 4) 4
       else D / 1000
-    val batchNumber = (D/batchSize + 1).toInt
-    private val batches = documents.sliding(batchNumber).collect()
+    val batchNumber = D/batchSize

     // Initialize the variational distribution q(beta|lambda)
-    var _lambda = getGammaMatrix(k, vocabSize) // K * V
-    private var _Elogbeta = dirichlet_expectation(_lambda) // K * V
-    private var _expElogbeta = exp(_Elogbeta) // K * V
+    var lambda = getGammaMatrix(k, vocabSize) // K * V
+    private var Elogbeta = dirichlet_expectation(lambda) // K * V
+    private var expElogbeta = exp(Elogbeta) // K * V

-    private var batchCount = 0
+    private var batchId = 0
     def next(): Unit = {
-      // weight of the mini-batch.
-      val rhot = math.pow(tau0 + batchCount, -kappa)
+      require(batchId < batchNumber)
+      // weight of the mini-batch. 1024 down weights early iterations
+      val weight = math.pow(1024 + batchId, -0.5)
+      val batch = documents.filter(doc => doc._1 % batchNumber == batchId)

+      // Given a mini-batch of documents, estimates the parameters gamma controlling the
+      // variational distribution over the topic weights for each document in the mini-batch.
       var stat = BDM.zeros[Double](k, vocabSize)
-      stat = batches(batchCount).aggregate(stat)(seqOp, _ += _)
-
-      stat = stat :* _expElogbeta
-      _lambda = _lambda * (1 - rhot) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * rhot
-      _Elogbeta = dirichlet_expectation(_lambda)
-      _expElogbeta = exp(_Elogbeta)
-      batchCount += 1
+      stat = batch.aggregate(stat)(seqOp, _ += _)
+      stat = stat :* expElogbeta
+
+      // Update lambda based on documents.
+      lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
+      Elogbeta = dirichlet_expectation(lambda)
+      expElogbeta = exp(Elogbeta)
+      batchId += 1
     }

-    private def seqOp(other: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+    // for each document d update that document's gamma and phi
+    private def seqOp(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
       val termCounts = doc._2
       val (ids, cts) = termCounts match {
         case v: DenseVector => (((0 until v.size).toList), v.values)
         case v: SparseVector => (v.indices.toList, v.values)
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }

+      // Initialize the variational distribution q(theta|gamma) for the mini-batch
       var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t // 1 * K
       var Elogthetad = vector_dirichlet_expectation(gammad.t).t // 1 * K
       var expElogthetad = exp(Elogthetad.t).t // 1 * K
-      val expElogbetad = _expElogbeta(::, ids).toDenseMatrix // K * ids
+      val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids

       var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids
       var meanchange = 1D
-      val ctsVector = new BDV[Double](cts).t // 1 * ids
+      val ctsVector = new BDV[Double](cts).t  // 1 * ids

+      // Iterate between gamma and phi until convergence
       while (meanchange > 1e-6) {
         val lastgamma = gammad
         //       1*K            1 * ids            ids * k
```
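As a reading aid, not part of the diff: with the `kappa` and `tau0` constants now inlined, the `next()` update above is the stochastic variational step from the Hoffman, Blei and Bach paper cited in the new scaladoc. A sketch in that paper's notation, writing t for `batchId` and |B_t| for `batchSize` (and noting that `stat` has already been multiplied elementwise by exp(E[log beta]) in the code):

```latex
% step size: tau_0 = 1024 down-weights early batches, kappa = 0.5
\rho_t = (\tau_0 + t)^{-\kappa}
% blend old lambda with rescaled mini-batch sufficient statistics
\lambda \leftarrow (1 - \rho_t)\,\lambda
  + \rho_t \left( \frac{D}{|B_t|}\,\mathrm{stat} + \frac{1}{k} \right)
```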
```diff
@@ -480,30 +476,30 @@ private[clustering] object LDA {
         meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
       }

-      val v1 = expElogthetad.t.toDenseMatrix.t
-      val v2 = (ctsVector / phinorm).t.toDenseMatrix
-      val outerResult = kron(v1, v2) // K * ids
+      val m1 = expElogthetad.t.toDenseMatrix.t
+      val m2 = (ctsVector / phinorm).t.toDenseMatrix
+      val outerResult = kron(m1, m2) // K * ids
       for (i <- 0 until ids.size) {
-        other(::, ids(i)) := (other(::, ids(i)) + outerResult(::, i))
+        stat(::, ids(i)) := (stat(::, ids(i)) + outerResult(::, i))
       }
-      other
+      stat
     }

-    def getGammaMatrix(row:Int, col:Int): BDM[Double] ={
+    private def getGammaMatrix(row:Int, col:Int): BDM[Double] ={
       val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
       val temp = gammaRandomGenerator.sample(row * col).toArray
       (new BDM[Double](col, row, temp)).t
     }

-    def dirichlet_expectation(alpha : BDM[Double]): BDM[Double] = {
+    private def dirichlet_expectation(alpha : BDM[Double]): BDM[Double] = {
       val rowSum = sum(alpha(breeze.linalg.*, ::))
       val digAlpha = digamma(alpha)
       val digRowSum = digamma(rowSum)
       val result = digAlpha(::, breeze.linalg.*) - digRowSum
       result
     }

-    def vector_dirichlet_expectation(v : BDV[Double]): (BDV[Double]) ={
+    private def vector_dirichlet_expectation(v : BDV[Double]): (BDV[Double]) ={
       digamma(v) - digamma(sum(v))
     }
   }
```
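The "adjust batch split" half of the commit replaces the old sliding-window batching (which collected batches to the driver) with pure arithmetic: a clamped batch size plus modulo assignment of document ids to batches. A standalone sketch of that arithmetic, with names mirroring the diff and a corpus size D made up for illustration:

```scala
object BatchSplitSketch {
  def main(args: Array[String]): Unit = {
    val D = 50000 // assumed corpus size, arbitrary for illustration
    // Target roughly D/1000 documents per batch, clamped to [4, 4096].
    val batchSize = if (D / 1000 > 4096) 4096
      else if (D / 1000 < 4) 4
      else D / 1000
    val batchNumber = D / batchSize
    // A document with id docId lands in the mini-batch satisfying
    // docId % batchNumber == batchId, i.e. each batch is a strided slice.
    val docId = 12345L
    println(s"batchSize=$batchSize, batchNumber=$batchNumber, " +
      s"doc $docId -> batch ${docId % batchNumber}")
  }
}
```

One corner the sketch makes visible: with integer division, a corpus smaller than the batch size gives `batchNumber == 0`, so `runOnlineLDA` would perform no updates at all.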

mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java

Lines changed: 1 addition & 1 deletion

```diff
@@ -88,7 +88,7 @@ public void distributedLDAModel() {
       .setMaxIterations(5)
       .setSeed(12345);

-    DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus);
+    DistributedLDAModel model = lda.run(corpus);

     // Check: basic parameters
     LocalLDAModel localModel = model.toLocal();
```

mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
       .setSeed(12345)
     val corpus = sc.parallelize(tinyCorpus, 2)

-    val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
+    val model: DistributedLDAModel = lda.run(corpus)

     // Check: basic parameters
     val localModel = model.toLocal
```
