Commit a570c9a

use sample to pick up batch
1 parent 4a3f27e commit a570c9a

1 file changed: mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala (+36, -13)
@@ -247,9 +247,34 @@ class LDA private (
     new DistributedLDAModel(state, iterationTimes)
   }
 
-  def runOnlineLDA(documents: RDD[(Long, Vector)]): LDAModel = {
-    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k)
-    (0 until onlineLDA.batchNumber).map(_ => onlineLDA.next())
+
+  /**
+   * Learn an LDA model from the given dataset, using the online variational Bayes (VB) algorithm.
+   * Hoffman, Blei and Bach, “Online Learning for Latent Dirichlet Allocation.” NIPS, 2010.
+   *
+   * @param documents RDD of documents, which are term (word) count vectors paired with IDs.
+   *                  The term count vectors are "bags of words" with a fixed-size vocabulary
+   *                  (where the vocabulary size is the length of the vector).
+   *                  Document IDs must be unique and >= 0.
+   * @param batchNumber Number of batches. The recommended size for each batch is [4, 16384].
+   *                    Use -1 to choose batchNumber automatically.
+   * @return Inferred LDA model
+   */
+  def runOnlineLDA(documents: RDD[(Long, Vector)], batchNumber: Int = -1): LDAModel = {
+    val D = documents.count().toInt
+    val batchSize =
+      if (batchNumber == -1) { // auto mode
+        if (D / 100 > 16384) 16384
+        else if (D / 100 < 4) 4
+        else D / 100
+      }
+      else {
+        require(batchNumber > 0, "batchNumber should be positive or -1")
+        D / batchNumber
+      }
+
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchSize)
+    (0 until onlineLDA.actualBatchNumber).map(_ => onlineLDA.next())
     new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
   }
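For context, a minimal sketch of how the new entry point might be invoked. The SparkContext (`sc`), the toy 5-term corpus, and the parameter values below are illustrative assumptions, not part of this commit:

import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Toy corpus: (uniqueDocId, termCountVector) over a 5-term vocabulary (made-up data).
val docs: RDD[(Long, Vector)] = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 0.0, 2.0, 0.0, 1.0)),
  (1L, Vectors.dense(0.0, 3.0, 0.0, 1.0, 0.0)),
  (2L, Vectors.dense(2.0, 1.0, 0.0, 0.0, 4.0))
))

// batchNumber = -1 lets runOnlineLDA pick batchSize = D / 100, clipped to [4, 16384].
val model = new LDA().setK(3).runOnlineLDA(docs, batchNumber = -1)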

@@ -411,28 +436,26 @@ private[clustering] object LDA {
    * Hoffman, Blei and Bach, “Online Learning for Latent Dirichlet Allocation.” NIPS, 2010.
    */
   private[clustering] class OnlineLDAOptimizer(
-      private val documents: RDD[(Long, Vector)],
-      private val k: Int) extends Serializable{
+      private val documents: RDD[(Long, Vector)],
+      private val k: Int,
+      private val batchSize: Int) extends Serializable{
 
     private val vocabSize = documents.first._2.size
     private val D = documents.count().toInt
-    private val batchSize = if (D / 1000 > 4096) 4096
-                            else if (D / 1000 < 4) 4
-                            else D / 1000
-    val batchNumber = D/batchSize
+    val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
 
-    // Initialize the variational distribution q(beta|lambda)
+    //Initialize the variational distribution q(beta|lambda)
     var lambda = getGammaMatrix(k, vocabSize)  // K * V
     private var Elogbeta = dirichlet_expectation(lambda)  // K * V
     private var expElogbeta = exp(Elogbeta)  // K * V
 
     private var batchId = 0
     def next(): Unit = {
-      require(batchId < batchNumber)
+      require(batchId < actualBatchNumber)
       // weight of the mini-batch. 1024 down weights early iterations
       val weight = math.pow(1024 + batchId, -0.5)
-      val batch = documents.filter(doc => doc._1 % batchNumber == batchId)
-
+      val batch = documents.sample(true, batchSize.toDouble / D)
+      batch.cache()
       // Given a mini-batch of documents, estimates the parameters gamma controlling the
       // variational distribution over the topic weights for each document in the mini-batch.
       var stat = BDM.zeros[Double](k, vocabSize)
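A side note on the behavioral change in next(): the old code partitioned documents deterministically by ID modulo the batch count, while the new code draws a fresh random mini-batch on each call, whose size equals batchSize only in expectation. The sketch below illustrates the difference on a plain local collection standing in for the RDD, with made-up sizes; it only roughly mimics documents.sample(true, batchSize.toDouble / D), which samples with replacement:

import scala.util.Random

val D = 1000                                               // number of documents (made up)
val batchSize = 50                                         // target mini-batch size (made up)
val actualBatchNumber = math.ceil(D.toDouble / batchSize).toInt
val docs = (0L until D.toLong).map(id => (id, s"doc-$id")) // local stand-in for the RDD

// Old selection: deterministic slice; every document lands in exactly one batch.
val batchId = 3
val oldBatch = docs.filter { case (id, _) => id % actualBatchNumber == batchId }

// New selection: independent random draw per call; size is batchSize only on average.
val rng = new Random(42)
val newBatch = docs.filter(_ => rng.nextDouble() < batchSize.toDouble / D)

println(s"old batch size = ${oldBatch.size}, new batch size ≈ ${newBatch.size}")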
