Skip to content

Commit 2979450

Browse files
MechCoder authored and nemccarthy committed
[SPARK-6065] [MLlib] Optimize word2vec.findSynonyms using blas calls
1. Use blas calls to find the dot product between two vectors. 2. Prevent re-computing the L2 norm of the given vector for each word in model. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes apache#5467 from MechCoder/spark-6065 and squashes the following commits: dd0b0b2 [MechCoder] Preallocate wordVectors ffc9240 [MechCoder] Minor 6b74c81 [MechCoder] Switch back to native blas calls da1642d [MechCoder] Explicit types and indexing 64575b0 [MechCoder] Save indexedmap and a wordvecmat instead of matrix fbe0108 [MechCoder] Made the following changes 1. Calculate norms during initialization. 2. Use Blas calls from linalg.blas 1350cf3 [MechCoder] [SPARK-6065] Optimize word2vec.findSynonynms using blas calls
1 parent b30d4ec commit 2979450

File tree

1 file changed

+51
-6
lines changed

1 file changed

+51
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import org.apache.spark.SparkContext
3434
import org.apache.spark.SparkContext._
3535
import org.apache.spark.annotation.Experimental
3636
import org.apache.spark.api.java.JavaRDD
37-
import org.apache.spark.mllib.linalg.{Vector, Vectors}
37+
import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector}
3838
import org.apache.spark.mllib.util.{Loader, Saveable}
3939
import org.apache.spark.rdd._
4040
import org.apache.spark.util.Utils
@@ -429,7 +429,36 @@ class Word2Vec extends Serializable with Logging {
429429
*/
430430
@Experimental
431431
class Word2VecModel private[mllib] (
    model: Map[String, Array[Float]]) extends Serializable with Saveable {

  // wordList: Ordered list of words obtained from model.
  private val wordList: Array[String] = model.keys.toArray

  // wordIndex: Maps each word to an index, which can retrieve the corresponding
  // vector from wordVectors (see below).
  // zipWithIndex is the idiomatic (and allocation-free) form of
  // zip(0 until model.size).
  private val wordIndex: Map[String, Int] = wordList.zipWithIndex.toMap

  // vectorSize: Dimension of each word's vector.
  // NOTE(review): model.head assumes the model is non-empty — same assumption
  // as the original code; confirm callers never pass an empty map.
  private val vectorSize = model.head._2.size
  private val numWords = wordIndex.size

  // wordVectors: Array of length numWords * vectorSize; the vector for the word
  // with index i occupies the slice
  // (i * vectorSize, i * vectorSize + vectorSize).
  // wordVecNorms: Array of length numWords, each value being the Euclidean
  // (L2) norm of the corresponding word vector, precomputed once so
  // findSynonyms does not recompute it per query.
  private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
    val wordVectors = new Array[Float](vectorSize * numWords)
    val wordVecNorms = new Array[Double](numWords)
    var i = 0
    while (i < numWords) {
      // wordList(i) is by construction a key of model, so apply directly
      // instead of model.get(...).get (avoids Option.get anti-pattern).
      val vec = model(wordList(i))
      Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize)
      wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
      i += 1
    }
    (wordVectors, wordVecNorms)
  }
433462

434463
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
435464
require(v1.length == v2.length, "Vectors should have the same length")
@@ -443,7 +472,7 @@ class Word2VecModel private[mllib] (
443472
override protected def formatVersion = "1.0"
444473

445474
def save(sc: SparkContext, path: String): Unit = {
446-
Word2VecModel.SaveLoadV1_0.save(sc, path, model)
475+
Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors)
447476
}
448477

449478
/**
@@ -479,9 +508,23 @@ class Word2VecModel private[mllib] (
479508
*/
480509
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
481510
require(num > 0, "Number of similar words should > 0")
482-
// TODO: optimize top-k
511+
483512
val fVector = vector.toArray.map(_.toFloat)
484-
model.mapValues(vec => cosineSimilarity(fVector, vec))
513+
val cosineVec = Array.fill[Float](numWords)(0)
514+
val alpha: Float = 1
515+
val beta: Float = 0
516+
517+
blas.sgemv(
518+
"T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
519+
520+
// Need not divide with the norm of the given vector since it is constant.
521+
val updatedCosines = new Array[Double](numWords)
522+
var ind = 0
523+
while (ind < numWords) {
524+
updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind)
525+
ind += 1
526+
}
527+
wordList.zip(updatedCosines)
485528
.toSeq
486529
.sortBy(- _._2)
487530
.take(num + 1)
@@ -493,7 +536,9 @@ class Word2VecModel private[mllib] (
493536
* Returns a map of words to their vector representations.
494537
*/
495538
def getVectors: Map[String, Array[Float]] = {
496-
model
539+
wordIndex.map { case (word, ind) =>
540+
(word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
541+
}
497542
}
498543
}
499544

0 commit comments

Comments (0)