@@ -42,7 +42,8 @@ private[mllib] class GridPartitioner(
42
42
override val numPartitions : Int ) extends Partitioner {
43
43
44
44
/**
45
- * Returns the index of the partition the SubMatrix belongs to.
45
+ * Returns the index of the partition the SubMatrix belongs to. Tries to achieve block wise
46
+ * partitioning.
46
47
*
47
48
* @param key The key for the SubMatrix. Can be its position in the grid (its column major index)
48
49
* or a tuple of three integers that are the final row index after the multiplication,
@@ -51,22 +52,25 @@ private[mllib] class GridPartitioner(
51
52
* @return The index of the partition, which the SubMatrix belongs to.
52
53
*/
53
54
override def getPartition (key : Any ): Int = {
55
+ val sqrtPartition = math.round(math.sqrt(numPartitions)).toInt
56
+ // numPartitions may not be the square of a number, it can even be a prime number
57
+
54
58
key match {
55
- case (rowIndex : Int , colIndex : Int ) =>
56
- Utils .nonNegativeMod(rowIndex + colIndex * numRowBlocks, numPartitions)
57
- case (rowIndex : Int , innerIndex : Int , colIndex : Int ) =>
58
- Utils .nonNegativeMod(rowIndex + colIndex * numRowBlocks, numPartitions)
59
+ case (blockRowIndex : Int , blockColIndex : Int ) =>
60
+ Utils .nonNegativeMod(blockRowIndex + blockColIndex * numRowBlocks, numPartitions)
61
+ case (blockRowIndex : Int , innerIndex : Int , blockColIndex : Int ) =>
62
+ Utils .nonNegativeMod(blockRowIndex + blockColIndex * numRowBlocks, numPartitions)
59
63
case _ =>
60
- throw new IllegalArgumentException (" Unrecognized key" )
64
+ throw new IllegalArgumentException (s " Unrecognized key. key: $ key" )
61
65
}
62
66
}
63
67
64
68
/** Checks whether the partitioners have the same characteristics */
65
69
override def equals (obj : Any ): Boolean = {
66
70
obj match {
67
71
case r : GridPartitioner =>
68
- (this .numPartitions == r.numPartitions ) && (this .rowsPerBlock == r.rowsPerBlock) &&
69
- (this .colsPerBlock == r.colsPerBlock)
72
+ (this .numRowBlocks == r.numRowBlocks ) && (this .numColBlocks == r.numColBlocks)
73
+ (this .rowsPerBlock == r.rowsPerBlock) && ( this . colsPerBlock == r.colsPerBlock)
70
74
case _ =>
71
75
false
72
76
}
@@ -85,7 +89,7 @@ class BlockMatrix(
85
89
val numColBlocks : Int ,
86
90
val rdd : RDD [((Int , Int ), Matrix )]) extends DistributedMatrix with Logging {
87
91
88
- type SubMatrix = ((Int , Int ), Matrix ) // ((blockRowIndex, blockColIndex), matrix)
92
+ private type SubMatrix = ((Int , Int ), Matrix ) // ((blockRowIndex, blockColIndex), matrix)
89
93
90
94
/**
91
95
* Alternate constructor for BlockMatrix without the input of a partitioner. Will use a Grid
0 commit comments