Skip to content

Commit fbe0108

Browse files
committed
Made the following changes
1. Calculate norms during initialization. 2. Use Blas calls from linalg.blas
1 parent 1350cf3 commit fbe0108

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import org.apache.spark.SparkContext
3434
import org.apache.spark.SparkContext._
3535
import org.apache.spark.annotation.Experimental
3636
import org.apache.spark.api.java.JavaRDD
37-
import org.apache.spark.mllib.linalg.{Vector, Vectors}
37+
import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector}
3838
import org.apache.spark.mllib.util.{Loader, Saveable}
3939
import org.apache.spark.rdd._
4040
import org.apache.spark.util.Utils
@@ -431,6 +431,14 @@ class Word2Vec extends Serializable with Logging {
431431
class Word2VecModel private[mllib] (
432432
private val model: Map[String, Array[Float]]) extends Serializable with Saveable {
433433

434+
private val numDim = model.head._2.size
435+
private val numWords = model.size
436+
private val flatVec = model.toSeq.flatMap { case(w, v) =>
437+
v.map(_.toDouble)}.toArray
438+
private val wordVecMat = new DenseMatrix(numWords, numDim, flatVec, isTransposed=true)
439+
private val wordVecNorms = model.map { case (word, vec) =>
440+
blas.snrm2(numDim, vec, 1)}.toArray
441+
434442
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
435443
require(v1.length == v2.length, "Vectors should have the same length")
436444
val n = v1.length
@@ -481,19 +489,13 @@ class Word2VecModel private[mllib] (
481489
require(num > 0, "Number of similar words should > 0")
482490

483491
val fVector = vector.toArray
484-
val flatVec = model.toSeq.flatMap { case(w, v) =>
485-
v.map(_.toDouble)}.toArray
486-
487-
val numDim = model.head._2.size
488-
val numWords = model.size
489-
val cosineArray = Array.fill[Double](numWords)(0)
490492

491-
blas.dgemv(
492-
"T", numDim, numWords, 1.0, flatVec, numDim, fVector, 1, 0.0, cosineArray, 1)
493+
val cosineVec = new DenseVector(Array.fill[Double](numWords)(0))
494+
BLAS.gemv(1.0, wordVecMat, vector.asInstanceOf[DenseVector], 0.0, cosineVec)
493495

494496
// Need not divide with the norm of the given vector since it is constant.
495497
val updatedCosines = model.zipWithIndex.map { case (vec, ind) =>
496-
cosineArray(ind) / blas.snrm2(numDim, vec._2, 1) }
498+
cosineVec(ind) / wordVecNorms(ind) }
497499

498500
model.keys.zip(updatedCosines)
499501
.toSeq

0 commit comments

Comments
 (0)