@@ -247,9 +247,34 @@ class LDA private (
     new DistributedLDAModel(state, iterationTimes)
   }
 
-  def runOnlineLDA(documents: RDD[(Long, Vector)]): LDAModel = {
-    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k)
-    (0 until onlineLDA.batchNumber).map(_ => onlineLDA.next())
+
+  /**
+   * Learn an LDA model using the given dataset, using the online variational Bayes (VB) algorithm.
+   * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
+   *
+   * @param documents    RDD of documents, which are term (word) count vectors paired with IDs.
+   *                     The term count vectors are "bags of words" with a fixed-size vocabulary
+   *                     (where the vocabulary size is the length of the vector).
+   *                     Document IDs must be unique and >= 0.
+   * @param batchNumber  Number of batches. The recommended size for each batch is [4, 16384].
+   *                     Pass -1 to choose batchNumber automatically.
+   * @return  Inferred LDA model
+   */
+  def runOnlineLDA(documents: RDD[(Long, Vector)], batchNumber: Int = -1): LDAModel = {
+    val D = documents.count().toInt
+    val batchSize =
+      if (batchNumber == -1) { // auto mode
+        if (D / 100 > 16384) 16384
+        else if (D / 100 < 4) 4
+        else D / 100
+      }
+      else {
+        require(batchNumber > 0, "batchNumber should be positive or -1")
+        D / batchNumber
+      }
+
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchSize)
+    (0 until onlineLDA.actualBatchNumber).map(_ => onlineLDA.next())
     new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
   }
 
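A minimal usage sketch of the runOnlineLDA entry point added in this hunk. It assumes the existing public LDA builder (new LDA().setK(...)) and a toy corpus of (documentId, termCountVector) pairs loaded with a local SparkContext; the object name and corpus are illustrative, not part of the patch.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.{LDA, LDAModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

object OnlineLDAUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "OnlineLDAUsageSketch")

    // Toy corpus: term-count vectors over a 5-term vocabulary, paired with
    // unique, non-negative document IDs, as required by the scaladoc above.
    val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 0.0, 2.0, 0.0, 1.0)),
      (1L, Vectors.dense(0.0, 3.0, 0.0, 1.0, 0.0)),
      (2L, Vectors.dense(2.0, 1.0, 0.0, 0.0, 1.0))))

    // batchNumber = -1 lets runOnlineLDA pick the mini-batch size automatically.
    val model: LDAModel = new LDA().setK(2).runOnlineLDA(corpus, batchNumber = -1)
    println(s"Learned ${model.k} topics over a vocabulary of ${model.vocabSize} terms")

    sc.stop()
  }
}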
@@ -411,28 +436,26 @@ private[clustering] object LDA {
    * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
    */
   private[clustering] class OnlineLDAOptimizer(
-      private val documents: RDD[(Long, Vector)],
-      private val k: Int) extends Serializable {
+      private val documents: RDD[(Long, Vector)],
+      private val k: Int,
+      private val batchSize: Int) extends Serializable {
 
     private val vocabSize = documents.first._2.size
     private val D = documents.count().toInt
-    private val batchSize = if (D / 1000 > 4096) 4096
-                            else if (D / 1000 < 4) 4
-                            else D / 1000
-    val batchNumber = D / batchSize
+    val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
 
-    // Initialize the variational distribution q(beta|lambda)
+    // Initialize the variational distribution q(beta|lambda)
     var lambda = getGammaMatrix(k, vocabSize) // K * V
     private var Elogbeta = dirichlet_expectation(lambda) // K * V
     private var expElogbeta = exp(Elogbeta) // K * V
 
     private var batchId = 0
     def next(): Unit = {
-      require(batchId < batchNumber)
+      require(batchId < actualBatchNumber)
       // weight of the mini-batch. 1024 down weights early iterations
       val weight = math.pow(1024 + batchId, -0.5)
-      val batch = documents.filter(doc => doc._1 % batchNumber == batchId)
-
+      val batch = documents.sample(true, batchSize.toDouble / D)
+      batch.cache()
       // Given a mini-batch of documents, estimates the parameters gamma controlling the
       // variational distribution over the topic weights for each document in the mini-batch.
       var stat = BDM.zeros[Double](k, vocabSize)
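For clarity, a small self-contained sketch (not part of the patch) of the batching arithmetic the new optimizer relies on: the automatic batch-size rule from runOnlineLDA (D / 100 clamped to [4, 16384]), actualBatchNumber = ceil(D / batchSize), and the mini-batch weight pow(1024 + batchId, -0.5). The object name and the example corpus size are illustrative assumptions.

object OnlineLDABatchingSketch {
  // Automatic batch-size rule used by runOnlineLDA when batchNumber == -1:
  // aim for roughly 100 batches, but keep each batch within [4, 16384] documents.
  def autoBatchSize(d: Int): Int =
    if (d / 100 > 16384) 16384
    else if (d / 100 < 4) 4
    else d / 100

  def main(args: Array[String]): Unit = {
    val d = 50000                                          // illustrative corpus size
    val batchSize = autoBatchSize(d)                       // 500 documents per mini-batch
    val batches = math.ceil(d.toDouble / batchSize).toInt  // 100 mini-batches in total

    // The 1024 offset keeps the earliest mini-batches from dominating lambda
    // (weight = 1024^-0.5 = 0.03125 at batchId = 0), and the weight decays slowly after that.
    val firstWeights = (0 until 3).map(b => math.pow(1024 + b, -0.5))

    println(s"batchSize = $batchSize, batches = $batches")
    println(s"first mini-batch weights = ${firstWeights.mkString(", ")}")
  }
}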