Skip to content

Commit 7148426

Browse files
author
Li Pu
committed
improve RowMatrix multiply
1 parent 5543cce commit 7148426

File tree

1 file changed

+12
-3
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed

1 file changed

+12
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -509,15 +509,24 @@ class RowMatrix(
509509
*/
510510
def multiply(B: Matrix): RowMatrix = {
511511
val n = numCols().toInt
512+
val k = B.numCols
512513
require(n == B.numRows, s"Dimension mismatch: $n vs ${B.numRows}")
513514

514515
require(B.isInstanceOf[DenseMatrix],
515516
s"Only support dense matrix at this time but found ${B.getClass.getName}.")
516517

517-
val Bb = rows.context.broadcast(B)
518+
val Bb = rows.context.broadcast(B.toBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray)
518519
val AB = rows.mapPartitions({ iter =>
519-
val Bi = Bb.value.toBreeze.asInstanceOf[BDM[Double]]
520-
iter.map(v => Vectors.fromBreeze(Bi.t * v.toBreeze))
520+
val Bi = Bb.value
521+
iter.map(row => {
522+
val v = BDV.zeros[Double](k)
523+
var i = 0
524+
while (i < k) {
525+
v(i) = row.toBreeze.dot(new BDV(Bi, i * n, 1, n))
526+
i += 1
527+
}
528+
Vectors.fromBreeze(v)
529+
})
521530
}, preservesPartitioning = true)
522531

523532
new RowMatrix(AB, nRows, B.numCols)

0 commit comments

Comments
 (0)