Skip to content

Commit be3e271

Browse files
Feynman Liang authored and jkbradley committed
[SPARK-9788] [MLLIB] Fix LDA Binary Compatibility
1. Add “asymmetricDocConcentration” and revert docConcentration changes. If the (internal) doc concentration vector is a single value, “getDocConcentration” returns it. If it is a constant vector, getDocConcentration returns the first item, and fails otherwise. 2. Give `LDAModel.gammaShape` a default value in `LDAModel` concrete class constructors. jkbradley Author: Feynman Liang <fliang@databricks.com> Closes #8077 from feynmanliang/SPARK-9788 and squashes the following commits: 6b07bc8 [Feynman Liang] Code review changes 9d6a71e [Feynman Liang] Add asymmetricAlpha alias bf4e685 [Feynman Liang] Asymmetric docConcentration 4cab972 [Feynman Liang] Default gammaShape
1 parent 423cdfd commit be3e271

File tree

4 files changed

+46
-24
lines changed

4 files changed

+46
-24
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,24 @@ class LDA private (
7979
*
8080
* This is the parameter to a Dirichlet distribution.
8181
*/
82-
def getDocConcentration: Vector = this.docConcentration
82+
def getAsymmetricDocConcentration: Vector = this.docConcentration
83+
84+
/**
85+
* Concentration parameter (commonly named "alpha") for the prior placed on documents'
86+
* distributions over topics ("theta").
87+
*
88+
* This method assumes the Dirichlet distribution is symmetric and can be described by a single
89+
* [[Double]] parameter. It should fail if docConcentration is asymmetric.
90+
*/
91+
def getDocConcentration: Double = {
92+
val parameter = docConcentration(0)
93+
if (docConcentration.size == 1) {
94+
parameter
95+
} else {
96+
require(docConcentration.toArray.forall(_ == parameter))
97+
parameter
98+
}
99+
}
83100

84101
/**
85102
* Concentration parameter (commonly named "alpha") for the prior placed on documents'
@@ -106,18 +123,22 @@ class LDA private (
106123
* [[https://github.com/Blei-Lab/onlineldavb]].
107124
*/
108125
def setDocConcentration(docConcentration: Vector): this.type = {
126+
require(docConcentration.size > 0, "docConcentration must have > 0 elements")
109127
this.docConcentration = docConcentration
110128
this
111129
}
112130

113-
/** Replicates Double to create a symmetric prior */
131+
/** Replicates a [[Double]] docConcentration to create a symmetric prior. */
114132
def setDocConcentration(docConcentration: Double): this.type = {
115133
this.docConcentration = Vectors.dense(docConcentration)
116134
this
117135
}
118136

137+
/** Alias for [[getAsymmetricDocConcentration]] */
138+
def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration
139+
119140
/** Alias for [[getDocConcentration]] */
120-
def getAlpha: Vector = getDocConcentration
141+
def getAlpha: Double = getDocConcentration
121142

122143
/** Alias for [[setDocConcentration()]] */
123144
def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha)

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import org.json4s.jackson.JsonMethods._
2727
import org.apache.spark.SparkContext
2828
import org.apache.spark.annotation.Experimental
2929
import org.apache.spark.api.java.JavaPairRDD
30-
import org.apache.spark.broadcast.Broadcast
3130
import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
3231
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
3332
import org.apache.spark.mllib.util.{Loader, Saveable}
@@ -190,7 +189,8 @@ class LocalLDAModel private[clustering] (
190189
val topics: Matrix,
191190
override val docConcentration: Vector,
192191
override val topicConcentration: Double,
193-
override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable {
192+
override protected[clustering] val gammaShape: Double = 100)
193+
extends LDAModel with Serializable {
194194

195195
override def k: Int = topics.numCols
196196

@@ -455,8 +455,9 @@ class DistributedLDAModel private[clustering] (
455455
val vocabSize: Int,
456456
override val docConcentration: Vector,
457457
override val topicConcentration: Double,
458-
override protected[clustering] val gammaShape: Double,
459-
private[spark] val iterationTimes: Array[Double]) extends LDAModel {
458+
private[spark] val iterationTimes: Array[Double],
459+
override protected[clustering] val gammaShape: Double = 100)
460+
extends LDAModel {
460461

461462
import LDA._
462463

@@ -756,7 +757,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
756757
val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
757758

758759
new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
759-
docConcentration, topicConcentration, gammaShape, iterationTimes)
760+
docConcentration, topicConcentration, iterationTimes, gammaShape)
760761
}
761762

762763
}

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,8 @@ final class EMLDAOptimizer extends LDAOptimizer {
9595
* Compute bipartite term/doc graph.
9696
*/
9797
override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
98-
val docConcentration = lda.getDocConcentration(0)
99-
require({
100-
lda.getDocConcentration.toArray.forall(_ == docConcentration)
101-
}, "EMLDAOptimizer currently only supports symmetric document-topic priors")
98+
// EMLDAOptimizer currently only supports symmetric document-topic priors
99+
val docConcentration = lda.getDocConcentration
102100

103101
val topicConcentration = lda.getTopicConcentration
104102
val k = lda.getK
@@ -209,11 +207,11 @@ final class EMLDAOptimizer extends LDAOptimizer {
209207
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
210208
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
211209
this.graphCheckpointer.deleteAllCheckpoints()
212-
// This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal
213-
// conversion
210+
// The constructor's default arguments assume gammaShape = 100 to ensure equivalence in
211+
// LDAModel.toLocal conversion
214212
new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
215213
Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
216-
100, iterationTimes)
214+
iterationTimes)
217215
}
218216
}
219217

@@ -378,18 +376,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
378376
this.k = lda.getK
379377
this.corpusSize = docs.count()
380378
this.vocabSize = docs.first()._2.size
381-
this.alpha = if (lda.getDocConcentration.size == 1) {
382-
if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
379+
this.alpha = if (lda.getAsymmetricDocConcentration.size == 1) {
380+
if (lda.getAsymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
383381
else {
384-
require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha")
385-
Vectors.dense(Array.fill(k)(lda.getDocConcentration(0)))
382+
require(lda.getAsymmetricDocConcentration(0) >= 0,
383+
s"all entries in alpha must be >=0, got: $alpha")
384+
Vectors.dense(Array.fill(k)(lda.getAsymmetricDocConcentration(0)))
386385
}
387386
} else {
388-
require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha")
389-
lda.getDocConcentration.foreachActive { case (_, x) =>
387+
require(lda.getAsymmetricDocConcentration.size == k,
388+
s"alpha must have length k, got: $alpha")
389+
lda.getAsymmetricDocConcentration.foreachActive { case (_, x) =>
390390
require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha")
391391
}
392-
lda.getDocConcentration
392+
lda.getAsymmetricDocConcentration
393393
}
394394
this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
395395
this.randomGenerator = new Random(lda.getSeed)

mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
160160

161161
test("setter alias") {
162162
val lda = new LDA().setAlpha(2.0).setBeta(3.0)
163-
assert(lda.getAlpha.toArray.forall(_ === 2.0))
164-
assert(lda.getDocConcentration.toArray.forall(_ === 2.0))
163+
assert(lda.getAsymmetricAlpha.toArray.forall(_ === 2.0))
164+
assert(lda.getAsymmetricDocConcentration.toArray.forall(_ === 2.0))
165165
assert(lda.getBeta === 3.0)
166166
assert(lda.getTopicConcentration === 3.0)
167167
}

0 commit comments

Comments
 (0)