@@ -431,16 +431,23 @@ class Word2Vec extends Serializable with Logging {
431431class Word2VecModel private [mllib] (
432432 model : Map [String , Array [Float ]]) extends Serializable with Saveable {
433433
434- val indexedModel = model.keys.zip(0 until model.size).toMap
434+ // Maintain an ordered list of words based on the index in the initial model.
435+ private val wordList : Array [String ] = model.keys.toArray
436+ private val wordIndex : Map [String , Int ] = wordList.zip(0 until model.size).toMap
435437
436- private val (wordVectors, wordVecNorms) = {
438+ private val (wordVectors : DenseMatrix , wordVecNorms : Array [ Double ] ) = {
437439 val numDim = model.head._2.size
438- val numWords = indexedModel .size
440+ val numWords = wordIndex .size
439441 val flatVec = model.toSeq.flatMap { case (w, v) =>
440442 v.map(_.toDouble)}.toArray
441443 val wordVectors = new DenseMatrix (numWords, numDim, flatVec, isTransposed= true )
442- val wordVecNorms = model.map { case (word, vec) =>
443- blas.snrm2(numDim, vec, 1 )}.toArray
444+ val wordVecNorms = new Array [Double ](numWords)
445+ var i = 0
446+ while (i < numWords) {
447+ val vec = model.get(wordList(i)).get
448+ wordVecNorms(i) = blas.snrm2(numDim, vec, 1 )
449+ i += 1
450+ }
444451 (wordVectors, wordVecNorms)
445452 }
446453
@@ -495,13 +502,16 @@ class Word2VecModel private[mllib] (
495502
496503 val numWords = wordVectors.numRows
497504 val cosineVec = Vectors .zeros(numWords).asInstanceOf [DenseVector ]
498- BLAS .gemv(1.0 , wordVectors, vector.asInstanceOf [ DenseVector ] , 0.0 , cosineVec)
505+ BLAS .gemv(1.0 , wordVectors, new DenseVector ( vector.toArray) , 0.0 , cosineVec)
499506
500507 // Need not divide with the norm of the given vector since it is constant.
501- val updatedCosines = indexedModel.map { case (_, ind) =>
502- cosineVec(ind) / wordVecNorms(ind) }
503-
504- indexedModel.keys.zip(updatedCosines)
508+ val updatedCosines = new Array [Double ](numWords)
509+ var ind = 0
510+ while (ind < numWords) {
511+ updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind)
512+ ind += 1
513+ }
514+ wordList.zip(updatedCosines)
505515 .toSeq
506516 .sortBy(- _._2)
507517 .take(num + 1 )
@@ -514,7 +524,7 @@ class Word2VecModel private[mllib] (
514524 */
515525 def getVectors : Map [String , Array [Float ]] = {
516526 val numDim = wordVectors.numCols
517- indexedModel .map { case (word, ind) =>
527+ wordIndex .map { case (word, ind) =>
518528 val startInd = numDim * ind
519529 val endInd = startInd + numDim
520530 (word, wordVectors.values.slice(startInd, endInd).map(_.toFloat)) }