File tree Expand file tree Collapse file tree 2 files changed +35
-2
lines changed
main/scala/org/apache/spark/mllib/clustering
test/scala/org/apache/spark/mllib/clustering Expand file tree Collapse file tree 2 files changed +35
-2
lines changed Original file line number Diff line number Diff line change @@ -366,6 +366,34 @@ class LocalLDAModel private[clustering] (
366
366
}
367
367
}
368
368
369
+ /**
370
+ * Predicts the topic mixture distribution for a document (often called "theta" in the
371
+ * literature). Returns a vector of zeros for an empty document.
372
+ *
373
+ * Note this means to allow quick query for single document. For batch documents, please refer
374
+ * to [[topicDistributions(documents: RDD[(Long, Vector)])]] to avoid overhead.
375
+ *
376
+ * @param document document to predict topic mixture distributions for
377
+ * @return (document ID, topic mixture distribution for document)
378
+ */
379
+ @ Since (" 1.6.0" )
380
+ def topicDistributions (document : (Long , Vector )): (Long , Vector ) = {
381
+ val expElogbeta = exp(LDAUtils .dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
382
+ val id = document._1
383
+ val termCounts = document._2
384
+ if (termCounts.numNonzeros == 0 ) {
385
+ (id, Vectors .zeros(this .k))
386
+ } else {
387
+ val (gamma, _) = OnlineLDAOptimizer .variationalTopicInference(
388
+ termCounts,
389
+ expElogbeta,
390
+ this .docConcentration.toBreeze,
391
+ gammaShape,
392
+ this .k)
393
+ (id, Vectors .dense(normalize(gamma, 1.0 ).toArray))
394
+ }
395
+ }
396
+
369
397
/**
370
398
* Java-friendly version of [[topicDistributions ]]
371
399
*/
Original file line number Diff line number Diff line change @@ -374,9 +374,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
374
374
.values
375
375
.collect()
376
376
377
- expectedPredictions.zip(actualPredictions).forall { case (expected, actual) =>
378
- expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D )
377
+ expectedPredictions.zip(actualPredictions).foreach { case (expected, actual) =>
378
+ assert( expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D ) )
379
379
}
380
+
381
+ val topicsBz = ldaModel.topicDistributions(docs.first())._2.toBreeze.toDenseVector
382
+ val singlePrediction = (argmax(topicsBz), max(topicsBz))
383
+ assert(expectedPredictions(0 )._1 === singlePrediction._1
384
+ && (expectedPredictions(0 )._2 ~== singlePrediction._2 relTol 1E-3D ))
380
385
}
381
386
382
387
test(" OnlineLDAOptimizer with asymmetric prior" ) {
You can’t perform that action at this time.
0 commit comments