@@ -32,7 +32,6 @@ import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
-import org.apache.spark.mllib.rdd.RDDFunctions._
 
 
 /**
@@ -223,10 +222,6 @@ class LDA private (
     this
   }
 
-  object LDAMode extends Enumeration {
-    val EM, Online = Value
-  }
-
   /**
    * Learn an LDA model using the given dataset.
    *
@@ -236,37 +231,30 @@ class LDA private (
    *                   Document IDs must be unique and >= 0.
    * @return  Inferred LDA model
    */
-  def run(documents: RDD[(Long, Vector)], mode: LDAMode.Value = LDAMode.EM): LDAModel = {
-    mode match {
-      case LDAMode.EM =>
-        val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-          checkpointInterval)
-        var iter = 0
-        val iterationTimes = Array.fill[Double](maxIterations)(0)
-        while (iter < maxIterations) {
-          val start = System.nanoTime()
-          state.next()
-          val elapsedSeconds = (System.nanoTime() - start) / 1e9
-          iterationTimes(iter) = elapsedSeconds
-          iter += 1
-        }
-        state.graphCheckpointer.deleteAllCheckpoints()
-        new DistributedLDAModel(state, iterationTimes)
-      case LDAMode.Online =>
-        val vocabSize = documents.first._2.size
-        val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, vocabSize)
-        var iter = 0
-        while (iter < onlineLDA.batchNumber) {
-          onlineLDA.next()
-          iter += 1
-        }
-        new LocalLDAModel(Matrices.fromBreeze(onlineLDA._lambda).transpose)
-      case _ => throw new IllegalArgumentException(s"Do not support mode $mode.")
+  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
+    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
+      checkpointInterval)
+    var iter = 0
+    val iterationTimes = Array.fill[Double](maxIterations)(0)
+    while (iter < maxIterations) {
+      val start = System.nanoTime()
+      state.next()
+      val elapsedSeconds = (System.nanoTime() - start) / 1e9
+      iterationTimes(iter) = elapsedSeconds
+      iter += 1
     }
+    state.graphCheckpointer.deleteAllCheckpoints()
+    new DistributedLDAModel(state, iterationTimes)
+  }
+
+  def runOnlineLDA(documents: RDD[(Long, Vector)]): LDAModel = {
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k)
+    (0 until onlineLDA.batchNumber).map(_ => onlineLDA.next())
+    new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
   }
 
   /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
     run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
   }
 }
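
With this hunk the `mode` flag is gone: `run` always performs EM and is now statically typed as `DistributedLDAModel`, while online variational Bayes moves to the separate `runOnlineLDA` entry point. A minimal caller-side sketch of the new split; the corpus, the `sc` SparkContext, and the parameter values are illustrative, not part of the patch:

```scala
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

// Hypothetical toy corpus: (docId, termCountVector) pairs; IDs must be unique and >= 0.
val corpus = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 2.0, 0.0)),
  (1L, Vectors.dense(0.0, 3.0, 1.0))))

val lda = new LDA().setK(10).setMaxIterations(20)
val emModel = lda.run(corpus)               // EM path; returns DistributedLDAModel
val onlineModel = lda.runOnlineLDA(corpus)  // online VB path; returns an LDAModel (LocalLDAModel)
```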
@@ -418,58 +406,66 @@ private[clustering] object LDA {
 
   }
 
-  // todo: add reference to paper and Hoffman
+  /**
+   * Optimizer for Online LDA algorithm which breaks corpus into mini-batches and scans only once.
+   * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
+   */
   private[clustering] class OnlineLDAOptimizer(
-      val documents: RDD[(Long, Vector)],
-      val k: Int,
-      val vocabSize: Int) extends Serializable {
+      private val documents: RDD[(Long, Vector)],
+      private val k: Int) extends Serializable {
 
-    private val kappa = 0.5   // (0.5, 1] how quickly old information is forgotten
-    private val tau0 = 1024   // down weights early iterations
-    private val D = documents.count()
+    private val vocabSize = documents.first._2.size
+    private val D = documents.count().toInt
     private val batchSize = if (D / 1000 > 4096) 4096
       else if (D / 1000 < 4) 4
       else D / 1000
-    val batchNumber = (D / batchSize + 1).toInt
-    private val batches = documents.sliding(batchNumber).collect()
+    val batchNumber = D / batchSize
 
     // Initialize the variational distribution q(beta|lambda)
-    var _lambda = getGammaMatrix(k, vocabSize)               // K * V
-    private var _Elogbeta = dirichlet_expectation(_lambda)   // K * V
-    private var _expElogbeta = exp(_Elogbeta)                // K * V
+    var lambda = getGammaMatrix(k, vocabSize)                // K * V
+    private var Elogbeta = dirichlet_expectation(lambda)     // K * V
+    private var expElogbeta = exp(Elogbeta)                  // K * V
 
-    private var batchCount = 0
+    private var batchId = 0
     def next(): Unit = {
-      // weight of the mini-batch.
-      val rhot = math.pow(tau0 + batchCount, -kappa)
+      require(batchId < batchNumber)
+      // weight of the mini-batch. 1024 down weights early iterations
+      val weight = math.pow(1024 + batchId, -0.5)
+      val batch = documents.filter(doc => doc._1 % batchNumber == batchId)
 
+      // Given a mini-batch of documents, estimates the parameters gamma controlling the
+      // variational distribution over the topic weights for each document in the mini-batch.
       var stat = BDM.zeros[Double](k, vocabSize)
-      stat = batches(batchCount).aggregate(stat)(seqOp, _ += _)
-
-      stat = stat :* _expElogbeta
-      _lambda = _lambda * (1 - rhot) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * rhot
-      _Elogbeta = dirichlet_expectation(_lambda)
-      _expElogbeta = exp(_Elogbeta)
-      batchCount += 1
+      stat = batch.aggregate(stat)(seqOp, _ += _)
+      stat = stat :* expElogbeta
+
+      // Update lambda based on documents.
+      lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
+      Elogbeta = dirichlet_expectation(lambda)
+      expElogbeta = exp(Elogbeta)
+      batchId += 1
     }
 
-    private def seqOp(other: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+    // for each document d update that document's gamma and phi
+    private def seqOp(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
       val termCounts = doc._2
       val (ids, cts) = termCounts match {
         case v: DenseVector => (((0 until v.size).toList), v.values)
         case v: SparseVector => (v.indices.toList, v.values)
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
 
+      // Initialize the variational distribution q(theta|gamma) for the mini-batch
       var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t   // 1 * K
       var Elogthetad = vector_dirichlet_expectation(gammad.t).t     // 1 * K
       var expElogthetad = exp(Elogthetad.t).t                       // 1 * K
-      val expElogbetad = _expElogbeta(::, ids).toDenseMatrix        // K * ids
+      val expElogbetad = expElogbeta(::, ids).toDenseMatrix         // K * ids
 
       var phinorm = expElogthetad * expElogbetad + 1e-100           // 1 * ids
       var meanchange = 1D
-      val ctsVector = new BDV[Double](cts).t                       // 1 * ids
+      val ctsVector = new BDV[Double](cts).t                        // 1 * ids
 
+      // Iterate between gamma and phi until convergence
       while (meanchange > 1e-6) {
         val lastgamma = gammad
         //        1*K                  1 * ids               ids * k
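
The loop above is the per-document E-step of Hoffman et al.'s algorithm: with lambda held fixed, gamma_d is iterated to a fixed point, where expElogthetad and expElogbetad are exp E[log theta_d] and exp E[log beta]. A condensed, self-contained sketch of that fixed-point iteration under the patch's 1/k prior; the function and variable names are illustrative, not from the patch:

```scala
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, sum}
import breeze.numerics.{abs, digamma, exp}

// One document's E-step: iterate gamma to a fixed point. cts holds the term counts for
// the document's word ids; expElogbetad is the K * ids slice of exp(E[log beta]).
def fixedPointGamma(cts: BDV[Double], expElogbetad: BDM[Double], k: Int): BDV[Double] = {
  var gammad = BDV.fill(k)(1.0)   // the patch instead samples from Gamma(100, 1/100)
  var meanchange = 1.0
  while (meanchange > 1e-6) {
    val lastgamma = gammad
    val expElogthetad = exp(digamma(gammad) - digamma(sum(gammad)))   // exp(E[log theta_d])
    val phinorm = expElogbetad.t * expElogthetad + 1e-100             // ids-length normalizer
    gammad = (expElogthetad :* (expElogbetad * (cts / phinorm))) + 1.0 / k
    meanchange = sum(abs(gammad - lastgamma)) / k
  }
  gammad
}
```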
@@ -480,30 +476,30 @@ private[clustering] object LDA {
         meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
       }
 
-      val v1 = expElogthetad.t.toDenseMatrix.t
-      val v2 = (ctsVector / phinorm).t.toDenseMatrix
-      val outerResult = kron(v1, v2)   // K * ids
+      val m1 = expElogthetad.t.toDenseMatrix.t
+      val m2 = (ctsVector / phinorm).t.toDenseMatrix
+      val outerResult = kron(m1, m2)   // K * ids
       for (i <- 0 until ids.size) {
-        other(::, ids(i)) := (other(::, ids(i)) + outerResult(::, i))
+        stat(::, ids(i)) := (stat(::, ids(i)) + outerResult(::, i))
       }
-      other
+      stat
     }
 
-    def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
+    private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
       val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
       val temp = gammaRandomGenerator.sample(row * col).toArray
       (new BDM[Double](col, row, temp)).t
     }
 
-    def dirichlet_expectation(alpha: BDM[Double]): BDM[Double] = {
+    private def dirichlet_expectation(alpha: BDM[Double]): BDM[Double] = {
       val rowSum = sum(alpha(breeze.linalg.*, ::))
       val digAlpha = digamma(alpha)
       val digRowSum = digamma(rowSum)
       val result = digAlpha(::, breeze.linalg.*) - digRowSum
       result
     }
 
-    def vector_dirichlet_expectation(v: BDV[Double]): (BDV[Double]) = {
+    private def vector_dirichlet_expectation(v: BDV[Double]): (BDV[Double]) = {
       digamma(v) - digamma(sum(v))
     }
   }
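
For reference, the `lambda` update in `next()` is the stochastic natural-gradient step from the Hoffman et al. paper: lambda <- (1 - rho_t) * lambda + rho_t * lambdaHat, where rho_t = (tau0 + t)^(-kappa) and lambdaHat is the estimate computed from the mini-batch alone (here `stat * D / batchSize + 1/k`). The patch inlines tau0 = 1024 and kappa = 0.5, which is why the `tau0`/`kappa` fields are removed. A standalone sketch of the schedule and blend, with illustrative names:

```scala
import breeze.linalg.DenseMatrix

// Step size rho_t = (tau0 + t)^(-kappa); the patch fixes tau0 = 1024 and kappa = 0.5.
def rho(t: Int, tau0: Double = 1024.0, kappa: Double = 0.5): Double =
  math.pow(tau0 + t, -kappa)

// One online update: convex combination of the old lambda and the mini-batch estimate.
def blend(lambda: DenseMatrix[Double], lambdaHat: DenseMatrix[Double], t: Int): DenseMatrix[Double] =
  lambda * (1.0 - rho(t)) + lambdaHat * rho(t)
```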