@@ -33,7 +33,6 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
33
33
import org .apache .spark .rdd ._
34
34
import org .apache .spark .util .Utils
35
35
import org .apache .spark .util .random .XORShiftRandom
36
- import org .apache .spark .util .collection .PrimitiveKeyOpenHashMap
37
36
38
37
/**
39
38
* Entry in vocabulary
@@ -348,21 +347,21 @@ class Word2Vec extends Serializable with Logging {
348
347
}
349
348
val syn0Local = model._1
350
349
val syn1Local = model._2
351
- val synOut = new PrimitiveKeyOpenHashMap [ Int , Array [Float ]](vocabSize * 2 )
350
+ val synOut = mutable. ListBuffer .empty[( Int , Array [Float ])]
352
351
var index = 0
353
352
while (index < vocabSize) {
354
353
if (syn0Modify(index) != 0 ) {
355
- synOut.update( index, syn0Local.slice(index * vectorSize, (index + 1 ) * vectorSize))
354
+ synOut += (( index, syn0Local.slice(index * vectorSize, (index + 1 ) * vectorSize) ))
356
355
}
357
356
if (syn1Modify(index) != 0 ) {
358
- synOut.update (index + vocabSize,
359
- syn1Local.slice(index * vectorSize, (index + 1 ) * vectorSize))
357
+ synOut += ( (index + vocabSize,
358
+ syn1Local.slice(index * vectorSize, (index + 1 ) * vectorSize)))
360
359
}
361
360
index += 1
362
361
}
363
- Iterator ( synOut)
362
+ synOut.toIterator
364
363
}
365
- val synAgg = partial.flatMap(x => x). reduceByKey { case (v1, v2) =>
364
+ val synAgg = partial.reduceByKey { case (v1, v2) =>
366
365
blas.saxpy(vectorSize, 1.0f , v2, 1 , v1, 1 )
367
366
v1
368
367
}.collect()
0 commit comments