Various cleanups, use random seed, optimization #2

Merged 1 commit on May 3, 2015
@@ -21,7 +21,7 @@ import java.util.Random

import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
import breeze.numerics.{digamma, exp, abs}
import breeze.stats.distributions.Gamma
import breeze.stats.distributions.{Gamma, RandBasis}

import org.apache.spark.annotation.Experimental
import org.apache.spark.graphx._
@@ -227,20 +227,37 @@ class OnlineLDAOptimizer extends LDAOptimizer {
private var k: Int = 0
private var corpusSize: Long = 0
private var vocabSize: Int = 0
private[clustering] var alpha: Double = 0
private[clustering] var eta: Double = 0

/** alias for docConcentration */
private var alpha: Double = 0

/** (private[clustering] for debugging) Get docConcentration */
private[clustering] def getAlpha: Double = alpha

/** alias for topicConcentration */
private var eta: Double = 0

/** (private[clustering] for debugging) Get topicConcentration */
private[clustering] def getEta: Double = eta

private var randomGenerator: java.util.Random = null

// Online LDA specific parameters
// Learning rate is: (tau_0 + t)^{-kappa}
private var tau_0: Double = 1024
private var kappa: Double = 0.51
private var miniBatchFraction: Double = 0.01
private var miniBatchFraction: Double = 0.05

// internal data structure
private var docs: RDD[(Long, Vector)] = null
private[clustering] var lambda: BDM[Double] = null

// count of invocation to next, which helps deciding the weight for each iteration
/** Dirichlet parameter for the posterior over topics */
private var lambda: BDM[Double] = null

/** (private[clustering] for debugging) Get parameter for topics */
private[clustering] def getLambda: BDM[Double] = lambda

/** Current iteration (count of invocations of [[next()]]) */
private var iteration: Int = 0
private var gammaShape: Double = 100

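As an aside on the learning-rate comment above, here is a tiny sketch of the per-iteration weight rho_t = (tau_0 + t)^(-kappa) using the defaults shown in this file (illustrative only, not part of this PR):

```scala
// Weight given to minibatch t in online LDA: rho_t = (tau_0 + t)^(-kappa).
def rho(tau0: Double, kappa: Double)(t: Int): Double =
  math.pow(tau0 + t, -kappa)

val w = rho(1024, 0.51) _
// w(0) ~= 0.029 and w(100) ~= 0.028: later minibatches are weighted
// slightly less, so early iterations move lambda the most.
```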
@@ -285,7 +302,12 @@ class OnlineLDAOptimizer extends LDAOptimizer {
/**
* Mini-batch fraction in (0, 1], which sets the fraction of document sampled and used in
* each iteration.
* Default: 0.01, i.e., 1% of total documents
*
* Note that this should be adjusted in sync with [[LDA.setMaxIterations()]]
* so the entire corpus is used. Specifically, set both so that
* maxIterations * miniBatchFraction >= 1.
*
* Default: 0.05, i.e., 5% of total documents.
*/
def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0,
@@ -295,15 +317,20 @@ class OnlineLDAOptimizer extends LDAOptimizer {
}
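To make the miniBatchFraction/maxIterations relationship above concrete, a minimal sketch of a consistent configuration (the parameter values are illustrative, not from this PR):

```scala
// With miniBatchFraction = 0.05, at least 1 / 0.05 = 20 iterations are
// needed so that maxIterations * miniBatchFraction >= 1, i.e. the whole
// corpus is (in expectation) sampled at least once.
val optimizer = new OnlineLDAOptimizer().setMiniBatchFraction(0.05)
val lda = new LDA()
  .setK(10)
  .setMaxIterations(20) // 20 * 0.05 = 1.0
  .setOptimizer(optimizer)
```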

/**
* The function is for test only now. In the future, it can help support training stop/resume
* (private[clustering])
* Set the Dirichlet parameter for the posterior over topics.
* This is only used for testing now. In the future, it can help support training stop/resume.
*/
private[clustering] def setLambda(lambda: BDM[Double]): this.type = {
this.lambda = lambda
this
}

/**
* Used to control the gamma distribution. Larger value produces values closer to 1.0.
* (private[clustering])
* Used for random initialization of the variational parameters.
* Larger value produces values closer to 1.0.
* This is only used for testing currently.
*/
private[clustering] def setGammaShape(shape: Double): this.type = {
this.gammaShape = shape
@@ -380,12 +407,11 @@ class OnlineLDAOptimizer extends LDAOptimizer {
meanchange = sum(abs(gammad - lastgamma)) / k
}

val m1 = expElogthetad.t.toDenseMatrix.t
val m2 = (ctsVector / phinorm).t.toDenseMatrix
val outerResult = kron(m1, m2) // K * ids
val m1 = expElogthetad.t
val m2 = (ctsVector / phinorm).t.toDenseVector
var i = 0
while (i < ids.size) {
stat(::, ids(i)) := (stat(::, ids(i)) + outerResult(::, i))
stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
i += 1
}
}
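The replacement of kron with a per-column update works because the Kronecker product of a K x 1 matrix and a 1 x N matrix is just the K x N outer product, whose i-th column is m1 * m2(i). A small standalone check of that identity (a sketch, not the PR's code):

```scala
import breeze.linalg.{DenseVector => BDV, kron}

val m1 = BDV(1.0, 2.0)                 // length K = 2
val m2 = BDV(3.0, 4.0, 5.0)            // length N = 3
// kron of a (K x 1) and a (1 x N) matrix is the K x N outer product.
val outer = kron(m1.toDenseMatrix.t, m2.toDenseMatrix)
assert(outer(::, 1) == m1 * m2(1))     // column i equals m1 scaled by m2(i)
```

Updating only the columns for word ids that occur in the document avoids materializing the full K x ids intermediate matrix.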
@@ -423,7 +449,9 @@ class OnlineLDAOptimizer extends LDAOptimizer {
* Get a random matrix to initialize lambda
*/
private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)
val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
randomGenerator.nextLong()))
val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)(randBasis)
val temp = gammaRandomGenerator.sample(row * col).toArray
new BDM[Double](col, row, temp).t
}
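For context on the seeded initialization above: Gamma(shape, 1/shape) has mean 1 and variance 1/shape, which is why a large gammaShape yields initial values close to 1.0, and seeding breeze's RandBasis from the optimizer's randomGenerator makes initialization reproducible. A standalone sketch of the same pattern (an illustrative helper, not the PR's code):

```scala
import breeze.linalg.{DenseMatrix => BDM}
import breeze.stats.distributions.{Gamma, RandBasis}
import org.apache.commons.math3.random.MersenneTwister

// Deterministic Gamma(shape, 1/shape) matrix: the same seed always
// produces the same matrix, making LDA runs reproducible.
def seededGammaMatrix(rows: Int, cols: Int, shape: Double, seed: Long): BDM[Double] = {
  val randBasis = new RandBasis(new MersenneTwister(seed))
  val gamma = new Gamma(shape, 1.0 / shape)(randBasis)
  new BDM[Double](cols, rows, gamma.sample(rows * cols).toArray).t
}
```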
@@ -20,7 +20,6 @@
import java.io.Serializable;
import java.util.ArrayList;

import org.apache.spark.api.java.JavaRDD;
import scala.Tuple2;

import org.junit.After;
@@ -30,6 +29,7 @@
import org.junit.Test;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
@@ -148,6 +148,6 @@ public void OnlineOptimizerCompatibility() {
private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
private static Tuple2<int[], double[]>[] tinyTopicDescription =
LDASuite$.MODULE$.tinyTopicDescription();
JavaPairRDD<Long, Vector> corpus;
private JavaPairRDD<Long, Vector> corpus;

}
@@ -39,7 +39,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {

// Check: describeTopics() with all terms
val fullTopicSummary = model.describeTopics()
assert(fullTopicSummary.size === tinyK)
assert(fullTopicSummary.length === tinyK)
fullTopicSummary.zip(tinyTopicDescription).foreach {
case ((algTerms, algTermWeights), (terms, termWeights)) =>
assert(algTerms === terms)
@@ -101,7 +101,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
// Check: per-doc topic distributions
val topicDistributions = model.topicDistributions.collect()
// Ensure all documents are covered.
assert(topicDistributions.size === tinyCorpus.size)
assert(topicDistributions.length === tinyCorpus.length)
assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
// Ensure we have proper distributions
topicDistributions.foreach { case (docId, topicDistribution) =>
@@ -139,8 +139,8 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
val corpus = sc.parallelize(tinyCorpus, 2)
val op = new OnlineLDAOptimizer().initialize(corpus, lda)
op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau_0(567)
assert(op.alpha == 0.5) // default 1.0 / k
assert(op.eta == 0.5) // default 1.0 / k
assert(op.getAlpha == 0.5) // default 1.0 / k
assert(op.getEta == 0.5) // default 1.0 / k
assert(op.getKappa == 0.9876)
assert(op.getMiniBatchFraction == 0.123)
assert(op.getTau_0 == 567)
@@ -154,14 +154,14 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {

def docs: Array[(Long, Vector)] = Array(
Vectors.sparse(vocabSize, Array(0, 1, 2), Array(1, 1, 1)), // apple, orange, banana
Vectors.sparse(vocabSize, Array(3, 4, 5), Array(1, 1, 1))) // tiger, cat, dog
.zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
Vectors.sparse(vocabSize, Array(3, 4, 5), Array(1, 1, 1)) // tiger, cat, dog
).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
val corpus = sc.parallelize(docs, 2)

// setGammaShape large so to avoid the stochastic impact.
// Set GammaShape large to avoid the stochastic impact.
val op = new OnlineLDAOptimizer().setTau_0(1024).setKappa(0.51).setGammaShape(1e40)
.setMiniBatchFraction(1)
val lda = new LDA().setK(k).setMaxIterations(1).setOptimizer(op)
val lda = new LDA().setK(k).setMaxIterations(1).setOptimizer(op).setSeed(12345)

val state = op.initialize(corpus, lda)
// override lambda to simulate an intermediate state
@@ -175,8 +175,8 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {

// verify the result; note this generates the identical result as
// [[https://github.com/Blei-Lab/onlineldavb]]
val topic1 = op.lambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
val topic2 = op.lambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1)
assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2)
}
@@ -186,7 +186,6 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
Vectors.sparse(6, Array(0, 1), Array(1, 1)),
Vectors.sparse(6, Array(1, 2), Array(1, 1)),
Vectors.sparse(6, Array(0, 2), Array(1, 1)),

Vectors.sparse(6, Array(3, 4), Array(1, 1)),
Vectors.sparse(6, Array(3, 5), Array(1, 1)),
Vectors.sparse(6, Array(4, 5), Array(1, 1))
@@ -200,6 +199,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
.setTopicConcentration(0.01)
.setMaxIterations(100)
.setOptimizer(op)
.setSeed(12345)

val ldaModel = lda.run(docs)
val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
@@ -208,10 +208,10 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
}

// check distribution for each topic, typical distribution is (0.3, 0.3, 0.3, 0.02, 0.02, 0.02)
topics.foreach(topic =>{
val smalls = topic.filter(t => (t._2 < 0.1)).map(_._2)
assert(smalls.size == 3 && smalls.sum < 0.2)
})
topics.foreach { topic =>
val smalls = topic.filter(t => t._2 < 0.1).map(_._2)
assert(smalls.length == 3 && smalls.sum < 0.2)
}
}

}