Commit 1235cfc

Author: Feynman Liang (committed)
Commit message: Use Iterable[Array[_]] over Array[Array[_]] for database
1 parent da0091b · commit 1235cfc

2 files changed: +21 −22 lines

mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

Lines changed: 3 additions & 3 deletions
@@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       maxPatternLength: Int,
       prefixes: List[Int],
-      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+      database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
     if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
     val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
     val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
@@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     }
   }
 
-  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
     database
       .map(getSuffix(prefix, _))
       .filter(_.nonEmpty)
@@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+      database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
     // TODO: use PrimitiveKeyOpenHashMap
     val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
     database.foreach { sequence =>
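
The widened parameter type works because LocalPrefixSpan only iterates, maps, and filters the database, and all of those operations are available on Iterable. The snippet below is an illustrative sketch only, not the MLlib source: the object name, the freqItemCounts helper, and the toy data are made up for the example. It shows that both an Array-backed and a generic Iterable-backed database satisfy an Iterable[Array[Int]] signature, so callers no longer need to materialize an Array[Array[Int]].

// Illustrative sketch (not the MLlib code): the operations used by LocalPrefixSpan
// (isEmpty, map, filter, foreach) all exist on Iterable, so accepting
// Iterable[Array[Int]] instead of Array[Array[Int]] loses nothing.
object IterableDatabaseSketch {

  // Hypothetical helper mirroring the idea of getFreqItemAndCounts.
  def freqItemCounts(minCount: Long, database: Iterable[Array[Int]]): Map[Int, Long] =
    database
      .flatMap(_.distinct)                              // count each item once per sequence
      .groupBy(identity)
      .map { case (item, occs) => (item, occs.size.toLong) }
      .filter { case (_, count) => count >= minCount }

  def main(args: Array[String]): Unit = {
    val asArrays: Array[Array[Int]] = Array(Array(1, 2, 3), Array(1, 3), Array(2, 3))
    val asIterable: Iterable[Array[Int]] = asArrays.toSeq  // e.g. what a groupByKey hands back

    // Both call sites compile against the Iterable[Array[Int]] signature.
    println(freqItemCounts(2L, asArrays).toSeq.sorted)    // Array is implicitly wrapped
    println(freqItemCounts(2L, asIterable).toSeq.sorted)  // no .toArray copy required
  }
}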

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 18 additions & 19 deletions
@@ -45,7 +45,11 @@ class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {
 
-  private val maxProjectedDBSizeBeforeLocalProcessing: Long = 10000
+  /**
+   * The maximum number of items allowed in a projected database before local processing. If a
+   * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
+   */
+  private val maxLocalProjDBSize: Long = 10000
 
   /**
    * Constructs a default instance with default parameters
@@ -63,8 +67,7 @@ class PrefixSpan private (
    * Sets the minimal support level (default: `0.1`).
    */
   def setMinSupport(minSupport: Double): this.type = {
-    require(minSupport >= 0 && minSupport <= 1,
-      "The minimum support value must be in [0, 1].")
+    require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
     this.minSupport = minSupport
     this
   }
@@ -79,8 +82,7 @@ class PrefixSpan private (
    */
   def setMaxPatternLength(maxPatternLength: Int): this.type = {
     // TODO: support unbounded pattern length when maxPatternLength = 0
-    require(maxPatternLength >= 1,
-      "The maximum pattern length value must be greater than 0.")
+    require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
     this.maxPatternLength = maxPatternLength
     this
   }
@@ -119,13 +121,13 @@ class PrefixSpan private (
         }.filter(_._2.nonEmpty)
       }
     }
-    var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = splitPrefixSuffixPairs(prefixSuffixPairs)
+    var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = partitionByProjDBSize(prefixSuffixPairs)
 
     while (largePrefixSuffixPairs.count() != 0) {
       val (nextPatternAndCounts, nextPrefixSuffixPairs) =
         getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs)
       largePrefixSuffixPairs.unpersist()
-      val (smallerPairsPart, largerPairsPart) = splitPrefixSuffixPairs(nextPrefixSuffixPairs)
+      val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
       largePrefixSuffixPairs = largerPairsPart
       largePrefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK)
       smallPrefixSuffixPairs ++= smallerPairsPart
@@ -136,7 +138,6 @@ class PrefixSpan private (
       val projectedDatabase = smallPrefixSuffixPairs
         // TODO aggregateByKey
         .groupByKey()
-        .mapValues(_.toArray)
       val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
       allPatternAndCounts ++= nextPatternAndCounts
     }
@@ -145,23 +146,21 @@ class PrefixSpan private (
 
 
   /**
-   * Split prefix suffix pairs to two parts:
-   * Prefixes with projected databases smaller than maxSuffixesBeforeLocalProcessing and
-   * Prefixes with projected databases larger than maxSuffixesBeforeLocalProcessing
+   * Partitions the prefix-suffix pairs by projected database size.
+   *
    * @param prefixSuffixPairs prefix (length n) and suffix pairs,
-   * @return small size prefix suffix pairs and big size prefix suffix pairs
-   *         (RDD[prefix, suffix], RDD[prefix, suffix ])
+   * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
+   *         greater than [[maxLocalProjDBSize]]
    */
-  private def splitPrefixSuffixPairs(
-      prefixSuffixPairs: RDD[(List[Int], Array[Int])]):
-    (RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
     val prefixToSuffixSize = prefixSuffixPairs
       .aggregateByKey(0)(
         seqOp = { case (count, suffix) => count + suffix.length },
         combOp = { _ + _ })
     val smallPrefixes = prefixToSuffixSize
-      .filter(_._2 <= maxProjectedDBSizeBeforeLocalProcessing)
-      .map(_._1)
+      .filter(_._2 <= maxLocalProjDBSize)
+      .keys
      .collect()
      .toSet
     val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
@@ -214,7 +213,7 @@ class PrefixSpan private (
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(List[Int], Array[Array[Int]])]): RDD[(List[Int], Long)] = {
+      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
     data.flatMap {
       case (prefix, projDB) =>
         LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
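
The type change in getPatternsInLocal dovetails with dropping `.mapValues(_.toArray)` above: `groupByKey()` already yields `RDD[(K, Iterable[V])]`, which now matches the signature directly. The snippet below is a standalone, hedged sketch rather than the commit's code; the SparkContext setup, object name, and toy prefix-suffix pairs are assumptions made for illustration.

import org.apache.spark.{SparkConf, SparkContext}

// Standalone sketch with assumed setup and toy data (not the commit's code):
// groupByKey() produces RDD[(List[Int], Iterable[Array[Int]])], the exact shape the
// new getPatternsInLocal signature expects, so no .mapValues(_.toArray) copy is needed.
object GroupByKeyIterableSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))

    // Hypothetical prefix-suffix pairs: (prefix, one suffix of a sequence containing it).
    val prefixSuffixPairs = sc.parallelize(Seq(
      (List(1), Array(2, 3)),
      (List(1), Array(3)),
      (List(2), Array(1, 3))))

    // RDD[(List[Int], Iterable[Array[Int]])]
    val projectedDatabases = prefixSuffixPairs.groupByKey()

    projectedDatabases.collect().foreach { case (prefix, projDB) =>
      // projDB is consumed directly as an Iterable: total item count per projected database.
      println(s"prefix=$prefix, projected DB size=${projDB.map(_.length).sum}")
    }

    sc.stop()
  }
}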
