Skip to content

Commit 720b5a3

Browse files
author
Liquan Pei
committed
Add test for Word2Vec algorithm, minor fixes
1 parent 2e92b59 commit 720b5a3

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,27 @@ private case class VocabWord(
5050
* natural language processing and machine learning algorithms.
5151
*
5252
* We used skip-gram model in our implementation and hierarchical softmax
53-
* method to train the model.
53+
* method to train the model. The variable names in the implementation
54+
* match the original C implementation.
5455
*
5556
* For original C implementation, see https://code.google.com/p/word2vec/
5657
* For research papers, see
5758
* Efficient Estimation of Word Representations in Vector Space
5859
* and
59-
* Distributed Representations of Words and Phrases and their Compositionality
60+
* Distributed Representations of Words and Phrases and their Compositionality.
6061
* @param size vector dimension
6162
* @param startingAlpha initial learning rate
6263
* @param window context words from [-window, window]
6364
* @param minCount minimum frequency to consider a vocabulary word
65+
* @param parallelism number of partitions to run Word2Vec
6466
*/
6567
@Experimental
6668
class Word2Vec(
6769
val size: Int,
6870
val startingAlpha: Double,
6971
val window: Int,
70-
val minCount: Int)
72+
val minCount: Int,
73+
val parallelism:Int = 1)
7174
extends Serializable with Logging {
7275

7376
private val EXP_TABLE_SIZE = 1000
@@ -237,7 +240,7 @@ class Word2Vec(
237240
}
238241
}
239242

240-
val newSentences = sentences.repartition(1).cache()
243+
val newSentences = sentences.repartition(parallelism).cache()
241244
val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
242245
val (aggSyn0, _, _, _) =
243246
// TODO: broadcast temp instead of serializing it directly
@@ -248,7 +251,7 @@ class Word2Vec(
248251
var wc = wordCount
249252
if (wordCount - lastWordCount > 10000) {
250253
lwc = wordCount
251-
alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1))
254+
alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
252255
if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
253256
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
254257
}
@@ -296,7 +299,7 @@ class Word2Vec(
296299
val n = syn0_1.length
297300
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
298301
blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
299-
(syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
302+
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
300303
})
301304

302305
val wordMap = new Array[(String, Array[Double])](vocabSize)
@@ -309,7 +312,7 @@ class Word2Vec(
309312
i += 1
310313
}
311314
val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
312-
.partitionBy(new HashPartitioner(modelPartitionNum))
315+
.partitionBy(new HashPartitioner(modelPartitionNum)).cache()
313316
new Word2VecModel(modelRDD)
314317
}
315318
}

mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,27 @@ import org.apache.spark.SparkContext._
2323
import org.apache.spark.mllib.util.LocalSparkContext
2424

2525
class Word2VecSuite extends FunSuite with LocalSparkContext {
26-
test("word2vec") {
26+
test("Word2Vec") {
27+
val sentence = "a b " * 100 + "a c " * 10
28+
val localDoc = Seq(sentence, sentence)
29+
val doc = sc.parallelize(localDoc)
30+
.map(line => line.split(" ").toSeq)
31+
val size = 10
32+
val startingAlpha = 0.025
33+
val window = 2
34+
val minCount = 2
35+
val num = 2
36+
val word = "a"
37+
38+
val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
39+
val synons = model.findSynonyms("a", 2)
40+
assert(synons.length == num)
41+
assert(synons(0)._1 == "b")
42+
assert(synons(1)._1 == "c")
43+
}
44+
45+
46+
test("Word2VecModel") {
2747
val num = 2
2848
val localModel = Seq(
2949
("china" , Array(0.50, 0.50, 0.50, 0.50)),

0 commit comments

Comments
 (0)