Skip to content

Commit eebbdf7

Browse files
committed
preliminary changes addressing code review
1 parent 1a63b20 commit eebbdf7

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ private[mllib] class GridPartitioner(
4242
override val numPartitions: Int) extends Partitioner {
4343

4444
/**
45-
* Returns the index of the partition the SubMatrix belongs to.
45+
* Returns the index of the partition the SubMatrix belongs to. Tries to achieve block wise
46+
* partitioning.
4647
*
4748
* @param key The key for the SubMatrix. Can be its position in the grid (its column major index)
4849
* or a tuple of three integers that are the final row index after the multiplication,
@@ -51,22 +52,25 @@ private[mllib] class GridPartitioner(
5152
* @return The index of the partition, which the SubMatrix belongs to.
5253
*/
5354
override def getPartition(key: Any): Int = {
55+
val sqrtPartition = math.round(math.sqrt(numPartitions)).toInt
56+
// numPartitions may not be the square of a number, it can even be a prime number
57+
5458
key match {
55-
case (rowIndex: Int, colIndex: Int) =>
56-
Utils.nonNegativeMod(rowIndex + colIndex * numRowBlocks, numPartitions)
57-
case (rowIndex: Int, innerIndex: Int, colIndex: Int) =>
58-
Utils.nonNegativeMod(rowIndex + colIndex * numRowBlocks, numPartitions)
59+
case (blockRowIndex: Int, blockColIndex: Int) =>
60+
Utils.nonNegativeMod(blockRowIndex + blockColIndex * numRowBlocks, numPartitions)
61+
case (blockRowIndex: Int, innerIndex: Int, blockColIndex: Int) =>
62+
Utils.nonNegativeMod(blockRowIndex + blockColIndex * numRowBlocks, numPartitions)
5963
case _ =>
60-
throw new IllegalArgumentException("Unrecognized key")
64+
throw new IllegalArgumentException(s"Unrecognized key. key: $key")
6165
}
6266
}
6367

6468
/** Checks whether the partitioners have the same characteristics */
6569
override def equals(obj: Any): Boolean = {
6670
obj match {
6771
case r: GridPartitioner =>
68-
(this.numPartitions == r.numPartitions) && (this.rowsPerBlock == r.rowsPerBlock) &&
69-
(this.colsPerBlock == r.colsPerBlock)
72+
(this.numRowBlocks == r.numRowBlocks) && (this.numColBlocks == r.numColBlocks)
73+
(this.rowsPerBlock == r.rowsPerBlock) && (this.colsPerBlock == r.colsPerBlock)
7074
case _ =>
7175
false
7276
}
@@ -85,7 +89,7 @@ class BlockMatrix(
8589
val numColBlocks: Int,
8690
val rdd: RDD[((Int, Int), Matrix)]) extends DistributedMatrix with Logging {
8791

88-
type SubMatrix = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), matrix)
92+
private type SubMatrix = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), matrix)
8993

9094
/**
9195
* Alternate constructor for BlockMatrix without the input of a partitioner. Will use a Grid

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.mllib.linalg.distributed
1919

2020
import org.scalatest.FunSuite
21-
2221
import breeze.linalg.{DenseMatrix => BDM}
2322

2423
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix}

0 commit comments

Comments
 (0)