Skip to content

Commit da1642d

Browse files
committed
Explicit types and indexing
1 parent 64575b0 commit da1642d

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -431,16 +431,23 @@ class Word2Vec extends Serializable with Logging {
431431
class Word2VecModel private[mllib] (
432432
model: Map[String, Array[Float]]) extends Serializable with Saveable {
433433

434-
val indexedModel = model.keys.zip(0 until model.size).toMap
434+
// Maintain an ordered list of words based on the index in the initial model.
435+
private val wordList: Array[String] = model.keys.toArray
436+
private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
435437

436-
private val (wordVectors, wordVecNorms) = {
438+
private val (wordVectors: DenseMatrix, wordVecNorms: Array[Double]) = {
437439
val numDim = model.head._2.size
438-
val numWords = indexedModel.size
440+
val numWords = wordIndex.size
439441
val flatVec = model.toSeq.flatMap { case(w, v) =>
440442
v.map(_.toDouble)}.toArray
441443
val wordVectors = new DenseMatrix(numWords, numDim, flatVec, isTransposed=true)
442-
val wordVecNorms = model.map { case (word, vec) =>
443-
blas.snrm2(numDim, vec, 1)}.toArray
444+
val wordVecNorms = new Array[Double](numWords)
445+
var i = 0
446+
while (i < numWords) {
447+
val vec = model.get(wordList(i)).get
448+
wordVecNorms(i) = blas.snrm2(numDim, vec, 1)
449+
i += 1
450+
}
444451
(wordVectors, wordVecNorms)
445452
}
446453

@@ -495,13 +502,16 @@ class Word2VecModel private[mllib] (
495502

496503
val numWords = wordVectors.numRows
497504
val cosineVec = Vectors.zeros(numWords).asInstanceOf[DenseVector]
498-
BLAS.gemv(1.0, wordVectors, vector.asInstanceOf[DenseVector], 0.0, cosineVec)
505+
BLAS.gemv(1.0, wordVectors, new DenseVector(vector.toArray), 0.0, cosineVec)
499506

500507
// No need to divide by the norm of the given vector since it is constant.
501-
val updatedCosines = indexedModel.map { case (_, ind) =>
502-
cosineVec(ind) / wordVecNorms(ind) }
503-
504-
indexedModel.keys.zip(updatedCosines)
508+
val updatedCosines = new Array[Double](numWords)
509+
var ind = 0
510+
while (ind < numWords) {
511+
updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind)
512+
ind += 1
513+
}
514+
wordList.zip(updatedCosines)
505515
.toSeq
506516
.sortBy(- _._2)
507517
.take(num + 1)
@@ -514,7 +524,7 @@ class Word2VecModel private[mllib] (
514524
*/
515525
def getVectors: Map[String, Array[Float]] = {
516526
val numDim = wordVectors.numCols
517-
indexedModel.map { case (word, ind) =>
527+
wordIndex.map { case (word, ind) =>
518528
val startInd = numDim * ind
519529
val endInd = startInd + numDim
520530
(word, wordVectors.values.slice(startInd, endInd).map(_.toFloat)) }

0 commit comments

Comments
 (0)