@@ -50,24 +50,27 @@ private case class VocabWord(
  * natural language processing and machine learning algorithms.
  *
  * We used the skip-gram model in our implementation and the hierarchical softmax
- * method to train the model.
+ * method to train the model. The variable names in the implementation
+ * match those in the original C implementation.
  *
  * For the original C implementation, see https://code.google.com/p/word2vec/
  * For research papers, see
  * Efficient Estimation of Word Representations in Vector Space
  * and
- * Distributed Representations of Words and Phrases and their Compositionality
+ * Distributed Representations of Words and Phrases and their Compositionality.
  * @param size vector dimension
  * @param startingAlpha initial learning rate
  * @param window context words from [-window, window]
  * @param minCount minimum frequency to consider a vocabulary word
+ * @param parallelism number of partitions to run Word2Vec
  */
 @Experimental
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
     val window: Int,
-    val minCount: Int)
+    val minCount: Int,
+    val parallelism: Int = 1)
   extends Serializable with Logging {
 
   private val EXP_TABLE_SIZE = 1000
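For context, a minimal sketch of how a caller might construct the estimator with the new parameter. All values are illustrative, not part of this diff:

// Hypothetical caller of the widened constructor.
val word2vec = new Word2Vec(
  size = 100,            // vector dimension
  startingAlpha = 0.025, // initial learning rate
  window = 5,            // use context words from [-5, 5]
  minCount = 5,          // drop words seen fewer than 5 times
  parallelism = 4)       // new: train on 4 partitions; defaults to 1 (old behavior)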
@@ -237,7 +240,7 @@ class Word2Vec(
       }
     }
 
-    val newSentences = sentences.repartition(1).cache()
+    val newSentences = sentences.repartition(parallelism).cache()
     val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
     val (aggSyn0, _, _, _) =
       // TODO: broadcast temp instead of serializing it directly
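Repartitioning by `parallelism` is what turns the previously single-partition training loop into data-parallel training. A schematic of that pattern, with illustrative names only (this is not the PR's exact loop): each partition trains a private copy of the weights over its shard of sentences, and the partial results are merged afterwards.

// Schematic sketch, not the actual implementation.
val partial = newSentences.mapPartitions { sentenceIter =>
  val localSyn0 = temp.clone() // per-partition copy of the weights
  // ... skip-gram / hierarchical-softmax updates over sentenceIter,
  // mutating localSyn0 in place ...
  Iterator(localSyn0)
}
val merged = partial.reduce { (a, b) =>
  var i = 0
  while (i < a.length) { a(i) += b(i); i += 1 } // elementwise sum of partials
  a
}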
@@ -248,7 +251,7 @@ class Word2Vec(
           var wc = wordCount
           if (wordCount - lastWordCount > 10000) {
             lwc = wordCount
-            alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1))
+            alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
             if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
             logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
           }
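The extra factor of `parallelism` compensates for each partition seeing only its own shard: the local `wordCount` understates global progress by roughly that factor, so without the scaling the learning rate would decay far too slowly. A worked sketch of the decay (the helper function is hypothetical, mirroring the line above):

// With parallelism = 4, trainWordsCount = 1000000 and a local wordCount
// of 50000, estimated global progress is 4 * 50000 / 1000001 ~= 0.2, so
// alpha decays to about 0.8 * startingAlpha instead of barely moving.
def decayedAlpha(startingAlpha: Double, wordCount: Long,
                 trainWordsCount: Long, parallelism: Int): Double = {
  val a = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
  math.max(a, startingAlpha * 0.0001) // same floor as the code above
}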
@@ -296,7 +299,7 @@ class Word2Vec(
           val n = syn0_1.length
           blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
           blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
-          (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
+          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
         })
 
     val wordMap = new Array[(String, Array[Double])](vocabSize)
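This small-looking change is a real bug fix. `daxpy` accumulates in place (`y := a*x + y`), so after the two calls `syn0_1` and `syn1_1` hold the merged weights; returning `syn0_2` discarded the merged hidden-layer weights and put an unmerged copy of the other partition's input layer in their slot. A minimal demonstration of the in-place semantics, assuming the netlib-java binding that MLlib uses:

import com.github.fommil.netlib.BLAS.{getInstance => blas}

val y = Array(1.0, 2.0)
val x = Array(10.0, 20.0)
blas.daxpy(2, 1.0, x, 1, y, 1) // y := 1.0 * x + y, in place
// y is now [11.0, 22.0]; x is unchanged, like syn0_2/syn1_2 above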
@@ -309,7 +312,7 @@ class Word2Vec(
       i += 1
     }
     val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
-      .partitionBy(new HashPartitioner(modelPartitionNum))
+      .partitionBy(new HashPartitioner(modelPartitionNum)).cache()
     new Word2VecModel(modelRDD)
   }
 }
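Caching the partitioned model RDD matters if `Word2VecModel` answers queries by operating on this RDD: without `.cache()`, each query would re-evaluate the `parallelize` plus `partitionBy` lineage. A hedged end-to-end sketch, where `fit` and `findSynonyms` are assumed from the surrounding MLlib API rather than shown in this diff:

// Illustrative use only; method names are assumptions.
val model = word2vec.fit(sentences)            // returns a Word2VecModel
val synonyms = model.findSynonyms("china", 40) // served from the cached modelRDD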