Commit f378e16

Burak Yavuz committed

[SPARK-3974] Block Matrix Abstractions ready

1 parent b693209 commit f378e16

File tree

1 file changed: +85 -98 lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala

Lines changed: 85 additions & 98 deletions
@@ -20,18 +20,32 @@ package org.apache.spark.mllib.linalg.distributed
 import breeze.linalg.{DenseMatrix => BDM}
 
 import org.apache.spark._
-import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices }
+import org.apache.spark.mllib.linalg.DenseMatrix
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
-case class BlockPartition(
-    blockIdRow: Int,
-    blockIdCol: Int,
-    mat: DenseMatrix) extends Serializable
+/**
+ * Represents a local matrix that makes up one block of a distributed BlockMatrix
+ *
+ * @param blockIdRow The row index of this block
+ * @param blockIdCol The column index of this block
+ * @param mat The underlying local matrix
+ */
+case class BlockPartition(blockIdRow: Int, blockIdCol: Int, mat: DenseMatrix) extends Serializable
 
-// Information about BlockMatrix maintained on the driver
+/**
+ * Information about the BlockMatrix maintained on the driver
+ *
+ * @param partitionId The id of the partition the block is found in
+ * @param blockIdRow The row index of this block
+ * @param blockIdCol The column index of this block
+ * @param startRow The starting row index with respect to the distributed BlockMatrix
+ * @param numRows The number of rows in this block
+ * @param startCol The starting column index with respect to the distributed BlockMatrix
+ * @param numCols The number of columns in this block
+ */
 case class BlockPartitionInfo(
     partitionId: Int,
     blockIdRow: Int,
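
A minimal usage sketch for the new one-line BlockPartition constructor above, assuming only the mllib DenseMatrix already imported in this file (DenseMatrix stores its values column-major):

    import org.apache.spark.mllib.linalg.DenseMatrix

    // Column-major values, so this is the 2x2 matrix [[1, 3], [2, 4]].
    val mat = new DenseMatrix(2, 2, Array(1.0, 2.0, 3.0, 4.0))

    // The block at grid position (row 0, column 1) of the distributed matrix.
    val block = BlockPartition(blockIdRow = 0, blockIdCol = 1, mat = mat)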
@@ -41,6 +55,13 @@ case class BlockPartitionInfo(
     startCol: Long,
     numCols: Int) extends Serializable
 
+/**
+ * A partitioner that decides how the matrix is distributed in the cluster
+ *
+ * @param numPartitions Number of partitions
+ * @param rowPerBlock Number of rows that make up each block.
+ * @param colPerBlock Number of columns that make up each block.
+ */
 abstract class BlockMatrixPartitioner(
     override val numPartitions: Int,
     val rowPerBlock: Int,
@@ -52,6 +73,14 @@ abstract class BlockMatrixPartitioner(
   }
 }
 
+/**
+ * A grid partitioner, which stores every block in a separate partition.
+ *
+ * @param numRowBlocks Number of blocks that form the rows of the matrix.
+ * @param numColBlocks Number of blocks that form the columns of the matrix.
+ * @param rowPerBlock Number of rows that make up each block.
+ * @param colPerBlock Number of columns that make up each block.
+ */
 class GridPartitioner(
     val numRowBlocks: Int,
     val numColBlocks: Int,
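
A note on the grid layout: each block is keyed by the column-major flattening of its (row, column) position, matching the GridPartitioner case of keyBy further down this diff, so with one partition per block the key doubles as the partition id. A self-contained sketch of that arithmetic (gridKey is a hypothetical helper name, not part of this commit):

    // Column-major block index, as in keyBy's GridPartitioner case.
    def gridKey(blockIdRow: Int, blockIdCol: Int, numRowBlocks: Int): Int =
      blockIdRow + numRowBlocks * blockIdCol

    // For a 3 x 2 grid the keys run 0..5 with no collisions, so every
    // block can occupy its own partition.
    for (col <- 0 until 2; row <- 0 until 3)
      println(s"block ($row, $col) -> key ${gridKey(row, col, numRowBlocks = 3)}")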
@@ -74,6 +103,14 @@ class GridPartitioner(
   }
 }
 
+/**
+ * A specialized partitioner that stores all blocks in the same row in just one partition.
+ *
+ * @param numPartitions Number of partitions. Should be set as the number of blocks that form
+ *                      the rows of the matrix.
+ * @param rowPerBlock Number of rows that make up each block.
+ * @param colPerBlock Number of columns that make up each block.
+ */
 class RowBasedPartitioner(
     override val numPartitions: Int,
     override val rowPerBlock: Int,
@@ -93,6 +130,14 @@ class RowBasedPartitioner(
   }
 }
 
+/**
+ * A specialized partitioner that stores all blocks in the same column in just one partition.
+ *
+ * @param numPartitions Number of partitions. Should be set as the number of blocks that form
+ *                      the columns of the matrix.
+ * @param rowPerBlock Number of rows that make up each block.
+ * @param colPerBlock Number of columns that make up each block.
+ */
 class ColumnBasedPartitioner(
     override val numPartitions: Int,
     override val rowPerBlock: Int,
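
Why both row- and column-based layouts exist: the multiply removed at the bottom of this diff required the left matrix to be column-partitioned and the right matrix row-partitioned. The left block (i, k) is then keyed by k and the right block (k, j) is also keyed by k, so a join pairs up exactly the blocks whose product contributes to result block (i, j). A tiny sketch of the two keying rules (hypothetical helper names):

    // Keying rules matching the Row/ColumnBasedPartitioner cases in keyBy.
    def leftKey(i: Int, k: Int): Int = k   // column-based: key by blockIdCol
    def rightKey(k: Int, j: Int): Int = k  // row-based: key by blockIdRow
    assert(leftKey(0, 2) == rightKey(2, 1)) // blocks (0,2) and (2,1) meet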
@@ -114,39 +159,44 @@ class ColumnBasedPartitioner(
   }
 }
 
+/**
+ * Represents a distributed matrix in blocks of local matrices.
+ *
+ * @param numRowBlocks Number of blocks that form the rows of this matrix
+ * @param numColBlocks Number of blocks that form the columns of this matrix
+ * @param rdd The RDD of BlockPartitions (local matrices) that form this matrix
+ * @param partitioner A partitioner that specifies how BlockPartitions are stored in the cluster
+ */
 class BlockMatrix(
     val numRowBlocks: Int,
     val numColBlocks: Int,
     val rdd: RDD[BlockPartition],
     val partitioner: BlockMatrixPartitioner) extends DistributedMatrix with Logging {
 
-  // We need a key-value pair RDD to partition properly
-  private var matrixRDD = rdd.map { block =>
-    partitioner match {
-      case r: RowBasedPartitioner => (block.blockIdRow, block)
-      case c: ColumnBasedPartitioner => (block.blockIdCol, block)
-      case g: GridPartitioner => (block.blockIdRow + numRowBlocks * block.blockIdCol, block)
-      case _ => throw new IllegalArgumentException("Unrecognized partitioner")
-    }
-  }
+  // A key-value pair RDD is required to partition properly
+  private var matrixRDD: RDD[(Int, BlockPartition)] = keyBy()
 
   @transient var blockInfo_ : Map[(Int, Int), BlockPartitionInfo] = null
 
-  lazy val dims: (Long, Long) = getDim
+  private lazy val dims: (Long, Long) = getDim
 
   override def numRows(): Long = dims._1
   override def numCols(): Long = dims._2
 
   if (partitioner.name.equals("column")) {
-    require(numColBlocks == partitioner.numPartitions)
+    require(numColBlocks == partitioner.numPartitions, "The number of column blocks should match" +
+      " the number of partitions of the column partitioner.")
   } else if (partitioner.name.equals("row")) {
-    require(numRowBlocks == partitioner.numPartitions)
+    require(numRowBlocks == partitioner.numPartitions, "The number of row blocks should match" +
+      " the number of partitions of the row partitioner.")
   } else if (partitioner.name.equals("grid")) {
-    require(numRowBlocks * numColBlocks == partitioner.numPartitions)
+    require(numRowBlocks * numColBlocks == partitioner.numPartitions, "The number of blocks " +
+      "should match the number of partitions of the grid partitioner.")
   } else {
     throw new IllegalArgumentException("Unrecognized partitioner.")
   }
 
+  /* Returns the dimensions of the matrix. */
   def getDim: (Long, Long) = {
     val bi = getBlockInfo
     val xDim = bi.map { x =>
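
A hedged construction sketch exercising the new require messages: a 2 x 2 grid of 3 x 3 zero blocks, assuming an existing SparkContext `sc`, that GridPartitioner takes the four documented parameters in this order, and that it derives numPartitions as numRowBlocks * numColBlocks:

    import org.apache.spark.mllib.linalg.DenseMatrix
    import org.apache.spark.rdd.RDD

    val blocks: RDD[BlockPartition] = sc.parallelize(
      for (row <- 0 until 2; col <- 0 until 2)
        yield BlockPartition(row, col, new DenseMatrix(3, 3, new Array[Double](9))),
      numSlices = 4)

    val gridPart = new GridPartitioner(numRowBlocks = 2, numColBlocks = 2,
      rowPerBlock = 3, colPerBlock = 3)

    // Passes the "grid" check: numRowBlocks * numColBlocks == 4 partitions.
    val matrix = new BlockMatrix(2, 2, blocks, gridPart)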
@@ -194,18 +244,20 @@ class BlockMatrix(
     }.toMap
 
     blockInfo_ = blockStartRowCols.map{ case ((rowId, colId), (partId, numRow, numCol)) =>
-      ((rowId, colId), new BlockPartitionInfo(partId, rowId, colId, cumulativeRowSum(rowId), numRow,
-        cumulativeColSum(colId), numCol))
+      ((rowId, colId), new BlockPartitionInfo(partId, rowId, colId, cumulativeRowSum(rowId),
+        numRow, cumulativeColSum(colId), numCol))
     }.toMap
   }
 
+  /* Returns a map of the information of the blocks that form the distributed matrix. */
   def getBlockInfo: Map[(Int, Int), BlockPartitionInfo] = {
     if (blockInfo_ == null) {
       calculateBlockInfo()
     }
     blockInfo_
   }
 
+  /* Returns the Frobenius Norm of the matrix */
   def normFro(): Double = {
     math.sqrt(rdd.map(lm => lm.mat.values.map(x => math.pow(x, 2)).sum).reduce(_ + _))
   }
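
The Frobenius norm documented above, checked locally: it is the square root of the sum of squares of every entry, which is why normFro can square-and-sum within each block independently and then reduce across blocks.

    // Entries of two blocks; the grouping into blocks does not matter.
    val blockValues = Seq(Array(1.0, 2.0), Array(3.0, 4.0))
    val normFro = math.sqrt(blockValues.map(_.map(x => x * x).sum).sum)
    assert(math.abs(normFro - math.sqrt(30.0)) < 1e-12) // 1 + 4 + 9 + 16 = 30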
@@ -222,8 +274,19 @@ class BlockMatrix(
     this
   }
 
+  private def keyBy(part: BlockMatrixPartitioner = partitioner): RDD[(Int, BlockPartition)] = {
+    rdd.map { block =>
+      part match {
+        case r: RowBasedPartitioner => (block.blockIdRow, block)
+        case c: ColumnBasedPartitioner => (block.blockIdCol, block)
+        case g: GridPartitioner => (block.blockIdRow + numRowBlocks * block.blockIdCol, block)
+        case _ => throw new IllegalArgumentException("Unrecognized partitioner")
+      }
+    }
+  }
+
   def repartition(part: BlockMatrixPartitioner = partitioner): DistributedMatrix = {
-    matrixRDD = matrixRDD.partitionBy(part)
+    matrixRDD = keyBy(part)
     this
   }
 
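Usage note on this refactor: keyBy now centralizes the partitioner-to-key mapping, and repartition re-keys the blocks under the new partitioner (as written here, the old partitionBy call is dropped, so only the keys change). A sketch reusing `matrix` from the construction example above, assuming RowBasedPartitioner takes the three documented parameters in order:

    // Re-key for a row-based layout: one partition per row of blocks, and
    // numRowBlocks == numPartitions (2 == 2) satisfies the "row" require.
    val rowPart = new RowBasedPartitioner(numPartitions = 2, rowPerBlock = 3,
      colPerBlock = 3)
    matrix.repartition(rowPart)
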
@@ -259,80 +322,4 @@ class BlockMatrix(
     val localMat = collect()
     new BDM[Double](localMat.numRows, localMat.numCols, localMat.values)
   }
-
-  def add(other: DistributedMatrix): DistributedMatrix = {
-    other match {
-      // We really need a function to check if two matrices are partitioned similarly
-      case otherBlocked: BlockMatrix =>
-        if (checkPartitioning(otherBlocked, OperationNames.add)){
-          val addedBlocks = rdd.zip(otherBlocked.rdd).map{ case (a, b) =>
-            val result = a.mat.toBreeze + b.mat.toBreeze
-            new BlockPartition(a.blockIdRow, a.blockIdCol,
-              Matrices.fromBreeze(result).asInstanceOf[DenseMatrix])
-          }
-          new BlockMatrix(numRowBlocks, numColBlocks, addedBlocks, partitioner)
-        } else {
-          throw new SparkException(
-            "Cannot add matrices with non-matching partitioners")
-        }
-      case _ =>
-        throw new IllegalArgumentException("Cannot add matrices of different types")
-    }
-  }
-
-  def multiply(other: DistributedMatrix): BlockMatrix = {
-    other match {
-      case otherBlocked: BlockMatrix =>
-        if (checkPartitioning(otherBlocked, OperationNames.multiply)){
-
-          val resultPartitioner = new GridPartitioner(numRowBlocks, otherBlocked.numColBlocks,
-            partitioner.rowPerBlock, otherBlocked.partitioner.colPerBlock)
-
-          val multiplyBlocks = matrixRDD.join(otherBlocked.matrixRDD, partitioner).
-            map { case (key, (mat1, mat2)) =>
-              val C = mat1.mat multiply mat2.mat
-              (mat1.blockIdRow + numRowBlocks * mat2.blockIdCol, C.toBreeze)
-            }.reduceByKey(resultPartitioner, (a, b) => a + b)
-
-          val newBlocks = multiplyBlocks.map{ case (index, mat) =>
-            val colId = index / numRowBlocks
-            val rowId = index - colId * numRowBlocks
-            new BlockPartition(rowId, colId, Matrices.fromBreeze(mat).asInstanceOf[DenseMatrix])
-          }
-          new BlockMatrix(numRowBlocks, otherBlocked.numColBlocks, newBlocks, resultPartitioner)
-        } else {
-          throw new SparkException(
-            "Cannot multiply matrices with non-matching partitioners")
-        }
-      case _ =>
-        throw new IllegalArgumentException("Cannot add matrices of different types")
-    }
-  }
-
-  private def checkPartitioning(other: BlockMatrix, operation: Int): Boolean = {
-    val otherPartitioner = other.partitioner
-    operation match {
-      case OperationNames.add =>
-        partitioner.equals(otherPartitioner)
-      case OperationNames.multiply =>
-        partitioner.name == "column" && otherPartitioner.name == "row" &&
-          partitioner.numPartitions == otherPartitioner.numPartitions &&
-          partitioner.colPerBlock == otherPartitioner.rowPerBlock &&
-          numColBlocks == other.numRowBlocks
-      case _ =>
-        throw new IllegalArgumentException("Unsupported operation")
-    }
-  }
-}
-
-/**
- * Maintains supported and default block matrix operation names.
- *
- * Currently supported operations: `add`, `multiply`.
- */
-private object OperationNames {
-
-  val add: Int = 1
-  val multiply: Int = 2
-
 }
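
The removed multiply keyed each partial product by the same column-major index used for the grid layout, then inverted that index when rebuilding BlockPartitions. A self-contained round-trip check of the arithmetic it used:

    // index = rowId + numRowBlocks * colId, inverted by division and remainder.
    val numRowBlocks = 3
    for (rowId <- 0 until numRowBlocks; colId <- 0 until 4) {
      val index = rowId + numRowBlocks * colId
      val colBack = index / numRowBlocks
      val rowBack = index - colBack * numRowBlocks
      assert(rowBack == rowId && colBack == colId)
    }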

0 commit comments