Skip to content

Commit 050b1c5

Browse files
committed
output shuffle data directly
1 parent 1870dba commit 050b1c5

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -347,19 +347,20 @@ class Word2Vec extends Serializable with Logging {
347347
}
348348
val syn0Local = model._1
349349
val syn1Local = model._2
350-
val synOut = mutable.ListBuffer.empty[(Int, Array[Float])]
351-
var index = 0
352-
while(index < vocabSize) {
353-
if (syn0Modify(index) != 0) {
354-
synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
350+
// Only output modified vectors.
351+
Iterator.tabulate(vocabSize) { index =>
352+
if (syn0Modify(index) > 0) {
353+
Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
354+
} else {
355+
None
355356
}
356-
if (syn1Modify(index) != 0) {
357-
synOut += ((index + vocabSize,
358-
syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
357+
}.flatten ++ Iterator.tabulate(vocabSize) { index =>
358+
if (syn1Modify(index) > 0) {
359+
Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
360+
} else {
361+
None
359362
}
360-
index += 1
361-
}
362-
synOut.toIterator
363+
}.flatten
363364
}
364365
val synAgg = partial.reduceByKey { case (v1, v2) =>
365366
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)

0 commit comments

Comments
 (0)