Skip to content

[SPARK-12685] [MLlib] [Backport to 1.4]word2vec trainWordsCount gets overflow #10721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,21 @@ class Word2Vec extends Serializable with Logging {
/** context words from [-window, window] */
private val window = 5

private var trainWordsCount = 0
private var trainWordsCount = 0L
private var vocabSize = 0
@transient private var vocab: Array[VocabWord] = null
@transient private var vocabHash = mutable.HashMap.empty[String, Int]

private def learnVocab(words: RDD[String]): Unit = {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
.map(x => VocabWord(
x._1,
x._2,
new Array[Int](MAX_CODE_LENGTH),
new Array[Int](MAX_CODE_LENGTH),
0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b) => a.cn > b.cn)

Expand All @@ -164,7 +164,7 @@ class Word2Vec extends Serializable with Logging {
trainWordsCount += vocab(a).cn
a += 1
}
logInfo("trainWordsCount = " + trainWordsCount)
logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")
}

private def createExpTable(): Array[Float] = {
Expand Down Expand Up @@ -313,7 +313,7 @@ class Word2Vec extends Serializable with Logging {
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
val syn0Modify = new Array[Int](vocabSize)
val syn1Modify = new Array[Int](vocabSize)
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
val model = iter.foldLeft((syn0Global, syn1Global, 0L, 0L)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount
Expand Down