Skip to content

Commit 1350cf3

Browse files
committed
[SPARK-6065] Optimize Word2Vec.findSynonyms using BLAS calls
1 parent 8220d52 commit 1350cf3

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,23 @@ class Word2VecModel private[mllib] (
479479
*/
480480
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
481481
require(num > 0, "Number of similar words should > 0")
482-
// TODO: optimize top-k
483-
val fVector = vector.toArray.map(_.toFloat)
484-
model.mapValues(vec => cosineSimilarity(fVector, vec))
482+
483+
val fVector = vector.toArray
484+
val flatVec = model.toSeq.flatMap { case(w, v) =>
485+
v.map(_.toDouble)}.toArray
486+
487+
val numDim = model.head._2.size
488+
val numWords = model.size
489+
val cosineArray = Array.fill[Double](numWords)(0)
490+
491+
blas.dgemv(
492+
"T", numDim, numWords, 1.0, flatVec, numDim, fVector, 1, 0.0, cosineArray, 1)
493+
494+
// No need to divide by the norm of the given vector, since it is constant.
495+
val updatedCosines = model.zipWithIndex.map { case (vec, ind) =>
496+
cosineArray(ind) / blas.snrm2(numDim, vec._2, 1) }
497+
498+
model.keys.zip(updatedCosines)
485499
.toSeq
486500
.sortBy(- _._2)
487501
.take(num + 1)

0 commit comments

Comments
 (0)