Skip to content

Commit cb5d823

Browse files
committed
add predict for single doc
1 parent 2692bdb commit cb5d823

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,34 @@ class LocalLDAModel private[clustering] (
366366
}
367367
}
368368

369+
/**
370+
* Predicts the topic mixture distribution for a document (often called "theta" in the
371+
* literature). Returns a vector of zeros for an empty document.
372+
*
373+
* Note this means to allow quick query for single document. For batch documents, please refer
374+
* to [[topicDistributions(documents: RDD[(Long, Vector)])]] to avoid overhead.
375+
*
376+
* @param document document to predict topic mixture distributions for
377+
* @return (document ID, topic mixture distribution for document)
378+
*/
379+
@Since("1.6.0")
380+
def topicDistributions(document: (Long, Vector)): (Long, Vector) = {
381+
val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
382+
val id = document._1
383+
val termCounts = document._2
384+
if (termCounts.numNonzeros == 0) {
385+
(id, Vectors.zeros(this.k))
386+
} else {
387+
val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
388+
termCounts,
389+
expElogbeta,
390+
this.docConcentration.toBreeze,
391+
gammaShape,
392+
this.k)
393+
(id, Vectors.dense(normalize(gamma, 1.0).toArray))
394+
}
395+
}
396+
369397
/**
370398
* Java-friendly version of [[topicDistributions]]
371399
*/

mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
374374
.values
375375
.collect()
376376

377-
expectedPredictions.zip(actualPredictions).forall { case (expected, actual) =>
378-
expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D)
377+
expectedPredictions.zip(actualPredictions).foreach { case (expected, actual) =>
378+
assert(expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D))
379379
}
380+
381+
val topicsBz = ldaModel.topicDistributions(docs.first())._2.toBreeze.toDenseVector
382+
val singlePrediction = (argmax(topicsBz), max(topicsBz))
383+
assert(expectedPredictions(0)._1 === singlePrediction._1
384+
&& (expectedPredictions(0)._2 ~== singlePrediction._2 relTol 1E-3D))
380385
}
381386

382387
test("OnlineLDAOptimizer with asymmetric prior") {

0 commit comments

Comments
 (0)