Commit ad23aa9

Merge pull request #1 from feynmanliang/SPARK-8998-collectBeforeLocal
[SPARK-8998] Collect Enough Prefixes Improvements
2 parents: 64271b3 + 87fa021

3 files changed: +136 -138 lines

mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

Lines changed: 3 additions & 3 deletions

@@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       maxPatternLength: Int,
       prefixes: List[Int],
-      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+      database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
     if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
     val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
     val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
@@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     }
   }

-  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
     database
       .map(getSuffix(prefix, _))
       .filter(_.nonEmpty)
@@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+      database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
     // TODO: use PrimitiveKeyOpenHashMap
     val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
     database.foreach { sequence =>
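
The single substantive change in this file is widening `database` from `Array[Array[Int]]` to `Iterable[Array[Int]]`, which lets the local runner consume the grouped suffixes that `PrefixSpan.run` now produces via `groupByKey()` without copying them into arrays first. A minimal standalone sketch of the projection step under that signature; the `getSuffix` body here is a simplified stand-in, since this diff only references it:

    object ProjectionSketch {
      // Simplified stand-in for LocalPrefixSpan.getSuffix: the part of `sequence`
      // after the first occurrence of `prefix`, or empty if `prefix` is absent.
      def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
        val i = sequence.indexOf(prefix)
        if (i == -1) Array.empty[Int] else sequence.drop(i + 1)
      }

      // Same shape as the patched `project`: any Iterable of sequences works.
      def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] =
        database
          .map(getSuffix(prefix, _))
          .filter(_.nonEmpty)

      def main(args: Array[String]): Unit = {
        val db: Iterable[Array[Int]] = Seq(Array(1, 2, 3), Array(2, 3), Array(3))
        // Projecting on item 2 keeps only the non-empty suffixes: [3] and [3]
        project(db, 2).foreach(s => println(s.mkString(" ")))
      }
    }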

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 122 additions & 125 deletions

@@ -45,30 +45,44 @@ class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {

-  private val maxProjectedDBSizeBeforeLocalProcessing: Long = 10000
+  /**
+   * The maximum number of items allowed in a projected database before local processing. If a
+   * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
+   */
+  private val maxLocalProjDBSize: Long = 10000

   /**
    * Constructs a default instance with default parameters
    * {minSupport: `0.1`, maxPatternLength: `10`}.
    */
   def this() = this(0.1, 10)

+  /**
+   * Gets the minimal support (i.e. the frequency of occurrence before a pattern is considered
+   * frequent).
+   */
+  def getMinSupport(): Double = this.minSupport
+
   /**
    * Sets the minimal support level (default: `0.1`).
    */
   def setMinSupport(minSupport: Double): this.type = {
-    require(minSupport >= 0 && minSupport <= 1,
-      "The minimum support value must be between 0 and 1, including 0 and 1.")
+    require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
     this.minSupport = minSupport
     this
   }

+  /**
+   * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to
+   * consider).
+   */
+  def getMaxPatternLength(): Double = this.maxPatternLength
+
   /**
    * Sets maximal pattern length (default: `10`).
    */
   def setMaxPatternLength(maxPatternLength: Int): this.type = {
-    require(maxPatternLength >= 1,
-      "The maximum pattern length value must be greater than 0.")
+    // TODO: support unbounded pattern length when maxPatternLength = 0
+    require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
    this.maxPatternLength = maxPatternLength
    this
  }
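
For orientation, a minimal usage sketch of the builder-style API touched in this hunk. The `sequences` RDD is a hypothetical placeholder; `run` is shown with the input and return types used later in this diff:

    import org.apache.spark.mllib.fpm.PrefixSpan
    import org.apache.spark.rdd.RDD

    // `sequences: RDD[Array[Int]]` is assumed to be defined and cached elsewhere.
    def example(sequences: RDD[Array[Int]]): Unit = {
      val prefixSpan = new PrefixSpan()
        .setMinSupport(0.1)      // must lie in [0, 1], per the tightened require message
        .setMaxPatternLength(10) // must be >= 1
      // The new getters simply read back the configured values.
      require(prefixSpan.getMinSupport() == 0.1)
      val patterns: RDD[(Array[Int], Long)] = prefixSpan.run(sequences)
      patterns.collect().foreach { case (pattern, count) =>
        println(s"${pattern.mkString("[", ", ", "]")}: $count")
      }
    }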
@@ -85,162 +99,145 @@ class PrefixSpan private (
     if (sequences.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
-    val minCount = getMinCount(sequences)
-    val lengthOnePatternsAndCounts = getFreqItemAndCounts(minCount, sequences)
-    val prefixSuffixPairs = getPrefixSuffixPairs(
-      lengthOnePatternsAndCounts.map(_._1).collect(), sequences)
-    prefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK)
-    var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer(x._1), x._2))
-    var (smallPrefixSuffixPairs, largePrefixSuffixPairs) =
-      splitPrefixSuffixPairs(prefixSuffixPairs)
-    while (largePrefixSuffixPairs.count() != 0) {
-      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
-        getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs)
-      largePrefixSuffixPairs.unpersist()
-      val (smallerPairsPart, largerPairsPart) = splitPrefixSuffixPairs(nextPrefixSuffixPairs)
-      largePrefixSuffixPairs = largerPairsPart
-      largePrefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK)
-      smallPrefixSuffixPairs ++= smallerPairsPart
-      allPatternAndCounts ++= nextPatternAndCounts
+
+    // Convert min support to a min number of transactions for this dataset
+    val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+
+    // (Frequent items -> number of occurrences; all items here satisfy the `minSupport` threshold)
+    val freqItemCounts = sequences
+      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+      .reduceByKey(_ + _)
+      .filter(_._2 >= minCount)
+
+    // Pairs of (length 1 prefix, suffix consisting of frequent items)
+    val itemSuffixPairs = {
+      val freqItems = freqItemCounts.keys.collect().toSet
+      sequences.flatMap { seq =>
+        val filteredSeq = seq.filter(freqItems.contains(_))
+        freqItems.flatMap { item =>
+          val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
+          candidateSuffix match {
+            case suffix if !suffix.isEmpty => Some((List(item), suffix))
+            case _ => None
+          }
+        }
+      }
     }
-    if (smallPrefixSuffixPairs.count() > 0) {
-      val projectedDatabase = smallPrefixSuffixPairs
-        .map(x => (x._1.toSeq, x._2))
-        .groupByKey()
-        .map(x => (x._1.toArray, x._2.toArray))
-      val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
-      allPatternAndCounts ++= nextPatternAndCounts
+
+    // Accumulator for the computed results to be returned, initialized to the frequent items
+    // (i.e. frequent length-one prefixes)
+    var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+
+    // Remaining work to be locally and distributively processed, respectively
+    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+
+    // Continue processing until no pairs for distributed processing remain (i.e. all prefixes
+    // have projected database sizes <= `maxLocalProjDBSize`)
+    while (pairsForDistributed.count() != 0) {
+      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
+        extendPrefixes(minCount, pairsForDistributed)
+      pairsForDistributed.unpersist()
+      val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
+      pairsForDistributed = largerPairsPart
+      pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
+      pairsForLocal ++= smallerPairsPart
+      resultsAccumulator ++= nextPatternAndCounts
     }
-    allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) }
+
+    // Process the small projected databases locally
+    resultsAccumulator ++= getPatternsInLocal(minCount, pairsForLocal.groupByKey())
+
+    resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) }
   }


   /**
-   * Split prefix suffix pairs to two parts:
-   * Prefixes with projected databases smaller than maxSuffixesBeforeLocalProcessing and
-   * Prefixes with projected databases larger than maxSuffixesBeforeLocalProcessing
+   * Partitions the prefix-suffix pairs by projected database size.
    * @param prefixSuffixPairs prefix (length n) and suffix pairs,
-   * @return small size prefix suffix pairs and big size prefix suffix pairs
-   *         (RDD[prefix, suffix], RDD[prefix, suffix])
+   * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
+   *         greater than [[maxLocalProjDBSize]]
    */
-  private def splitPrefixSuffixPairs(
-      prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]):
-    (RDD[(ArrayBuffer[Int], Array[Int])], RDD[(ArrayBuffer[Int], Array[Int])]) = {
-    val suffixSizeMap = prefixSuffixPairs
-      .map(x => (x._1, x._2.length))
-      .reduceByKey(_ + _)
-      .map(x => (x._2 <= maxProjectedDBSizeBeforeLocalProcessing, Set(x._1)))
-      .reduceByKey(_ ++ _)
-      .collect
-      .toMap
-    val small = if (suffixSizeMap.contains(true)) {
-      prefixSuffixPairs.filter(x => suffixSizeMap(true).contains(x._1))
-    } else {
-      prefixSuffixPairs.filter(x => false)
-    }
-    val large = if (suffixSizeMap.contains(false)) {
-      prefixSuffixPairs.filter(x => suffixSizeMap(false).contains(x._1))
-    } else {
-      prefixSuffixPairs.filter(x => false)
-    }
+  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+    val prefixToSuffixSize = prefixSuffixPairs
+      .aggregateByKey(0)(
+        seqOp = { case (count, suffix) => count + suffix.length },
+        combOp = { _ + _ })
+    val smallPrefixes = prefixToSuffixSize
+      .filter(_._2 <= maxLocalProjDBSize)
+      .keys
+      .collect()
+      .toSet
+    val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
+    val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
     (small, large)
   }

   /**
-   * Get the pattern and counts, and prefix suffix pairs
+   * Extends all prefixes by one item from their suffix and computes the resulting frequent
+   * prefixes and remaining work.
    * @param minCount minimum count
-   * @param prefixSuffixPairs prefix (length n) and suffix pairs,
-   * @return pattern (length n+1) and counts, and prefix (length n+1) and suffix pairs
-   *         (RDD[pattern, count], RDD[prefix, suffix])
+   * @param prefixSuffixPairs prefix (length N) and suffix pairs,
+   * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
+   *         prefix, corresponding suffix) pairs.
    */
-  private def getPatternCountsAndPrefixSuffixPairs(
+  private def extendPrefixes(
       minCount: Long,
-      prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]):
-    (RDD[(ArrayBuffer[Int], Long)], RDD[(ArrayBuffer[Int], Array[Int])]) = {
-    val prefixAndFrequentItemAndCounts = prefixSuffixPairs
+      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+
+    // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
+    // Every (prefix :+ item) is guaranteed to have support exceeding `minSupport`
+    val prefixItemPairAndCounts = prefixSuffixPairs
       .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
-    val patternAndCounts = prefixAndFrequentItemAndCounts
-      .map { case ((prefix, item), count) => (prefix :+ item, count) }
-    val prefixToFrequentNextItemsMap = prefixAndFrequentItemAndCounts
+
+    // Map from prefix to set of possible next items from suffix
+    val prefixToNextItems = prefixItemPairAndCounts
       .keys
       .groupByKey()
       .mapValues(_.toSet)
       .collect()
       .toMap
-    val nextPrefixSuffixPairs = prefixSuffixPairs
-      .filter(x => prefixToFrequentNextItemsMap.contains(x._1))
-      .flatMap { case (prefix, suffix) =>
-        val frequentNextItems = prefixToFrequentNextItemsMap(prefix)
-        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
-        frequentNextItems.flatMap { item =>
-          val suffix = LocalPrefixSpan.getSuffix(item, filteredSuffix)
-          if (suffix.isEmpty) None
-          else Some(prefix :+ item, suffix)
-        }
-      }
-    (patternAndCounts, nextPrefixSuffixPairs)
-  }

-  /**
-   * Get the minimum count (sequences count * minSupport).
-   * @param sequences input data set, contains a set of sequences,
-   * @return minimum count,
-   */
-  private def getMinCount(sequences: RDD[Array[Int]]): Long = {
-    if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
-  }

-  /**
-   * Generates frequent items by filtering the input data using minimal count level.
-   * @param minCount the absolute minimum count
-   * @param sequences original sequences data
-   * @return array of item and count pair
-   */
-  private def getFreqItemAndCounts(
-      minCount: Long,
-      sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
-    sequences.flatMap(_.distinct.map((_, 1L)))
-      .reduceByKey(_ + _)
-      .filter(_._2 >= minCount)
-  }
+    // Frequent patterns with length N+1 and their corresponding counts
+    val extendedPrefixAndCounts = prefixItemPairAndCounts
+      .map { case ((prefix, item), count) => (item :: prefix, count) }

-  /**
-   * Get the frequent prefixes and suffix pairs.
-   * @param frequentPrefixes frequent prefixes
-   * @param sequences sequences data
-   * @return prefixes and suffix pairs.
-   */
-  private def getPrefixSuffixPairs(
-      frequentPrefixes: Array[Int],
-      sequences: RDD[Array[Int]]): RDD[(ArrayBuffer[Int], Array[Int])] = {
-    val filteredSequences = sequences.map { p =>
-      p.filter(frequentPrefixes.contains(_))
-    }
-    filteredSequences.flatMap { x =>
-      frequentPrefixes.map { y =>
-        val sub = LocalPrefixSpan.getSuffix(y, x)
-        (ArrayBuffer(y), sub)
-      }.filter(_._2.nonEmpty)
-    }
+    // Remaining work, all prefixes will have length N+1
+    val extendedPrefixAndSuffix = prefixSuffixPairs
+      .filter(x => prefixToNextItems.contains(x._1))
+      .flatMap { case (prefix, suffix) =>
+        val frequentNextItems = prefixToNextItems(prefix)
+        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
+        frequentNextItems.flatMap { item =>
+          LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
+            case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
+            case _ => None
+          }
+        }
+      }
+
+    (extendedPrefixAndCounts, extendedPrefixAndSuffix)
   }

   /**
-   * calculate the patterns in local.
+   * Calculate the patterns in local.
    * @param minCount the absolute minimum count
    * @param data prefixes and projected sequences data
    * @return patterns
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(ArrayBuffer[Int], Long)] = {
+      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
     data.flatMap {
-      case (prefix, projDB) =>
-        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
-          .map { case (pattern: List[Int], count: Long) =>
-            (pattern.toArray.reverse.to[ArrayBuffer], count)
-          }
+      case (prefix, projDB) =>
+        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
+          .map { case (pattern: List[Int], count: Long) =>
+            (pattern.reverse, count)
+          }
     }
   }
 }
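
To make the split criterion concrete, here is a minimal local (non-Spark) sketch of what `partitionByProjDBSize` computes: the total suffix length per prefix, then a partition against the size threshold. The plain collections and the small threshold are illustrative stand-ins for the RDD version above:

    object PartitionByProjDBSizeSketch {
      // Stand-in for maxLocalProjDBSize (10000 in the patch), shrunk so the
      // example exercises both branches.
      val maxLocalProjDBSize = 4

      def main(args: Array[String]): Unit = {
        val prefixSuffixPairs: Seq[(List[Int], Array[Int])] = Seq(
          (List(1), Array(2, 3)),    // prefix [1] total projected size: 2 + 3 = 5 -> large
          (List(1), Array(2, 3, 4)),
          (List(2), Array(3))        // prefix [2] total projected size: 1 -> small
        )
        // Total suffix length per prefix, mirroring the aggregateByKey above
        val prefixToSuffixSize = prefixSuffixPairs
          .groupBy(_._1)
          .map { case (prefix, pairs) => (prefix, pairs.map(_._2.length).sum) }
        val smallPrefixes = prefixToSuffixSize.filter(_._2 <= maxLocalProjDBSize).keySet
        val (small, large) = prefixSuffixPairs.partition {
          case (prefix, _) => smallPrefixes.contains(prefix)
        }
        println(s"small prefixes: ${small.map(_._1).distinct}") // List(List(2))
        println(s"large prefixes: ${large.map(_._1).distinct}") // List(List(1))
      }
    }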

mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala

Lines changed: 11 additions & 10 deletions

@@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {

     val rdd = sc.parallelize(sequences, 2).cache()

-    def compareResult(
-        expectedValue: Array[(Array[Int], Long)],
-        actualValue: Array[(Array[Int], Long)]): Boolean = {
-      expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
-        actualValue.map(x => (x._1.toSeq, x._2)).toSet
-    }
-
     val prefixspan = new PrefixSpan()
       .setMinSupport(0.33)
       .setMaxPatternLength(50)
@@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue1, result1.collect()))
+    assert(compareResults(expectedValue1, result1.collect()))

     prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
     val result2 = prefixspan.run(rdd)
@@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4), 4L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue2, result2.collect()))
+    assert(compareResults(expectedValue2, result2.collect()))

     prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
     val result3 = prefixspan.run(rdd)
@@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue3, result3.collect()))
+    assert(compareResults(expectedValue3, result3.collect()))
+  }
+
+  private def compareResults(
+      expectedValue: Array[(Array[Int], Long)],
+      actualValue: Array[(Array[Int], Long)]): Boolean = {
+    expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
+      actualValue.map(x => (x._1.toSeq, x._2)).toSet
+  }
+
 }
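
The helper moves out of the test body and is renamed `compareResults`; it compares result sets order-insensitively, converting each `Array[Int]` pattern to a `Seq` so the comparison is structural rather than by array reference. A small self-contained illustration of why that conversion matters (not part of the diff):

    object ArrayEqualitySketch {
      def main(args: Array[String]): Unit = {
        val a = (Array(1, 2), 2L)
        val b = (Array(1, 2), 2L)
        // Tuples compare elements with ==, and Array == is reference equality,
        // so two equal-looking results are not equal as-is:
        println(a == b) // false
        // Converting the pattern to a Seq restores structural equality,
        // which is exactly what compareResults does before the set comparison:
        println((a._1.toSeq, a._2) == (b._1.toSeq, b._2)) // true
      }
    }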
