Skip to content

Commit bf4e685

Browse files
author
Feynman Liang
committed
Asymmetric docConcentration
1 parent 4cab972 commit bf4e685

File tree

4 files changed

+36
-21
lines changed

4 files changed

+36
-21
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class LDA private (
7979
*
8080
* This is the parameter to a Dirichlet distribution.
8181
*/
82-
def getDocConcentration: Vector = this.docConcentration
82+
def getAsymmetricDocConcentration: Vector = this.docConcentration
8383

8484
/**
8585
* Concentration parameter (commonly named "alpha") for the prior placed on documents'
@@ -105,22 +105,37 @@ class LDA private (
105105
* - default = uniformly (1.0 / k), following the implementation from
106106
* [[https://github.com/Blei-Lab/onlineldavb]].
107107
*/
108-
def setDocConcentration(docConcentration: Vector): this.type = {
108+
def setAsymmetricDocConcentration(docConcentration: Vector): this.type = {
109109
this.docConcentration = docConcentration
110110
this
111111
}
112112

113-
/** Replicates Double to create a symmetric prior */
113+
/**
114+
* Gets the concentration parameter, assuming the document-topic Dirichlet distribution is
115+
* symmetric. Included for backwards compatibility. This method should fail if
116+
* [[docConcentration]] is asymmetric.
117+
*/
118+
def getDocConcentration: Double = {
119+
val parameter = docConcentration(0)
120+
if (docConcentration.size == 1) {
121+
parameter
122+
} else {
123+
require(docConcentration.toArray.forall(_ == parameter))
124+
parameter
125+
}
126+
}
127+
128+
/** Replicates a [[Double]] docConcentration to create a symmetric prior. */
114129
def setDocConcentration(docConcentration: Double): this.type = {
115130
this.docConcentration = Vectors.dense(docConcentration)
116131
this
117132
}
118133

119-
/** Alias for [[getDocConcentration]] */
120-
def getAlpha: Vector = getDocConcentration
134+
/** Alias for [[getAsymmetricDocConcentration]] */
135+
def getAlpha: Vector = getAsymmetricDocConcentration
121136

122-
/** Alias for [[setDocConcentration()]] */
123-
def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha)
137+
/** Alias for [[setAsymmetricDocConcentration()]] */
138+
def setAlpha(alpha: Vector): this.type = setAsymmetricDocConcentration(alpha)
124139

125140
/** Alias for [[setDocConcentration()]] */
126141
def setAlpha(alpha: Double): this.type = setDocConcentration(alpha)

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
757757
val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
758758

759759
new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
760-
docConcentration, topicConcentration, gammaShape, iterationTimes)
760+
docConcentration, topicConcentration, iterationTimes, gammaShape)
761761
}
762762

763763
}

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,8 @@ final class EMLDAOptimizer extends LDAOptimizer {
9595
* Compute bipartite term/doc graph.
9696
*/
9797
override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
98-
val docConcentration = lda.getDocConcentration(0)
99-
require({
100-
lda.getDocConcentration.toArray.forall(_ == docConcentration)
101-
}, "EMLDAOptimizer currently only supports symmetric document-topic priors")
98+
// EMLDAOptimizer currently only supports symmetric document-topic priors
99+
val docConcentration = lda.getDocConcentration
102100

103101
val topicConcentration = lda.getTopicConcentration
104102
val k = lda.getK
@@ -378,18 +376,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
378376
this.k = lda.getK
379377
this.corpusSize = docs.count()
380378
this.vocabSize = docs.first()._2.size
381-
this.alpha = if (lda.getDocConcentration.size == 1) {
382-
if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
379+
this.alpha = if (lda.getAsymmetricDocConcentration.size == 1) {
380+
if (lda.getAsymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
383381
else {
384-
require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha")
385-
Vectors.dense(Array.fill(k)(lda.getDocConcentration(0)))
382+
require(lda.getAsymmetricDocConcentration(0) >= 0,
383+
s"all entries in alpha must be >=0, got: $alpha")
384+
Vectors.dense(Array.fill(k)(lda.getAsymmetricDocConcentration(0)))
386385
}
387386
} else {
388-
require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha")
389-
lda.getDocConcentration.foreachActive { case (_, x) =>
387+
require(lda.getAsymmetricDocConcentration.size == k,
388+
s"alpha must have length k, got: $alpha")
389+
lda.getAsymmetricDocConcentration.foreachActive { case (_, x) =>
390390
require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha")
391391
}
392-
lda.getDocConcentration
392+
lda.getAsymmetricDocConcentration
393393
}
394394
this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
395395
this.randomGenerator = new Random(lda.getSeed)

mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
161161
test("setter alias") {
162162
val lda = new LDA().setAlpha(2.0).setBeta(3.0)
163163
assert(lda.getAlpha.toArray.forall(_ === 2.0))
164-
assert(lda.getDocConcentration.toArray.forall(_ === 2.0))
164+
assert(lda.getAsymmetricDocConcentration.toArray.forall(_ === 2.0))
165165
assert(lda.getBeta === 3.0)
166166
assert(lda.getTopicConcentration === 3.0)
167167
}
@@ -364,7 +364,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
364364
val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
365365
.setGammaShape(1e10)
366366
val lda = new LDA().setK(2)
367-
.setDocConcentration(Vectors.dense(0.00001, 0.1))
367+
.setAsymmetricDocConcentration(Vectors.dense(0.00001, 0.1))
368368
.setTopicConcentration(0.01)
369369
.setMaxIterations(100)
370370
.setOptimizer(op)

0 commit comments

Comments
 (0)