
Commit d212a31

zhangjiajin authored and mengxr committed
[SPARK-8998] [MLLIB] Distribute PrefixSpan computation for large projected databases
Continuation of work by zhangjiajin. Closes apache#7412.

Author: zhangjiajin <zhangjiajin@huawei.com>
Author: Feynman Liang <fliang@databricks.com>
Author: zhang jiajin <zhangjiajin@huawei.com>

Closes apache#7783 from feynmanliang/SPARK-8998-improve-distributed and squashes the following commits:

a61943d [Feynman Liang] Collect small patterns to local
4ddf479 [Feynman Liang] Parallelize freqItemCounts
ad23aa9 [zhang jiajin] Merge pull request #1 from feynmanliang/SPARK-8998-collectBeforeLocal
87fa021 [Feynman Liang] Improve extend prefix readability
c2caa5c [Feynman Liang] Readability improvements and comments
1235cfc [Feynman Liang] Use Iterable[Array[_]] over Array[Array[_]] for database
da0091b [Feynman Liang] Use lists for prefixes to reuse data
cb2a4fc [Feynman Liang] Inline code for readability
01c9ae9 [Feynman Liang] Add getters
6e149fa [Feynman Liang] Fix splitPrefixSuffixPairs
64271b3 [zhangjiajin] Modified codes according to comments.
d2250b7 [zhangjiajin] remove minPatternsBeforeLocalProcessing, add maxSuffixesBeforeLocalProcessing.
b07e20c [zhangjiajin] Merge branch 'master' of https://github.com/apache/spark into CollectEnoughPrefixes
095aa3a [zhangjiajin] Modified the code according to the review comments.
baa2885 [zhangjiajin] Modified the code according to the review comments.
6560c69 [zhangjiajin] Add feature: Collect enough frequent prefixes before projection in PrefixSpan
a8fde87 [zhangjiajin] Merge branch 'master' of https://github.com/apache/spark
4dd1c8a [zhangjiajin] initialize file before rebase.
078d410 [zhangjiajin] fix a scala style error.
22b0ef4 [zhangjiajin] Add feature: Collect enough frequent prefixes before projection in PrefixSpan.
ca9c4c8 [zhangjiajin] Modified the code according to the review comments.
574e56c [zhangjiajin] Add new object LocalPrefixSpan, and do some optimization.
ba5df34 [zhangjiajin] Fix a Scala style error.
4c60fb3 [zhangjiajin] Fix some Scala style errors.
1dd33ad [zhangjiajin] Modified the code according to the review comments.
89bc368 [zhangjiajin] Fixed a Scala style error.
a2eb14c [zhang jiajin] Delete PrefixspanSuite.scala
951fd42 [zhang jiajin] Delete Prefixspan.scala
575995f [zhangjiajin] Modified the code according to the review comments.
91fd7e6 [zhangjiajin] Add new algorithm PrefixSpan and test file.
1 parent: c581593

File tree: 3 files changed, +161 −69 lines changed
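For orientation, here is how the API exercised by this patch is used (a minimal sketch based on the test suite below; the sequence data and the `sc` SparkContext in scope are illustrative assumptions, not part of the patch):

import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.rdd.RDD

// Each element is one sequence of integer-coded items (hypothetical data).
val sequences: RDD[Array[Int]] = sc.parallelize(Seq(
  Array(1, 3, 4, 5),
  Array(2, 3, 1),
  Array(2, 4, 1)), 2).cache()  // run() warns if the input RDD is not cached

val prefixspan = new PrefixSpan()
  .setMinSupport(0.33)      // a pattern must occur in >= 33% of sequences
  .setMaxPatternLength(50)  // upper bound on pattern length
val patterns: RDD[(Array[Int], Long)] = prefixspan.run(sequences)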

mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

Lines changed: 3 additions & 3 deletions

@@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       maxPatternLength: Int,
       prefixes: List[Int],
-      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+      database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
     if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
     val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
     val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
@@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     }
   }

-  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
     database
       .map(getSuffix(prefix, _))
       .filter(_.nonEmpty)
@@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+      database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
     // TODO: use PrimitiveKeyOpenHashMap
     val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
     database.foreach { sequence =>
mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 147 additions & 56 deletions

@@ -17,6 +17,8 @@

 package org.apache.spark.mllib.fpm

+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
@@ -43,28 +45,45 @@ class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {

+  /**
+   * The maximum number of items allowed in a projected database before local processing. If a
+   * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
+   */
+  // TODO: make configurable with a better default value, 10000 may be too small
+  private val maxLocalProjDBSize: Long = 10000
+
   /**
    * Constructs a default instance with default parameters
    * {minSupport: `0.1`, maxPatternLength: `10`}.
    */
   def this() = this(0.1, 10)

+  /**
+   * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
+   * frequent).
+   */
+  def getMinSupport: Double = this.minSupport
+
   /**
    * Sets the minimal support level (default: `0.1`).
    */
   def setMinSupport(minSupport: Double): this.type = {
-    require(minSupport >= 0 && minSupport <= 1,
-      "The minimum support value must be between 0 and 1, including 0 and 1.")
+    require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
     this.minSupport = minSupport
     this
   }

+  /**
+   * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider).
+   */
+  def getMaxPatternLength: Int = this.maxPatternLength
+
   /**
    * Sets maximal pattern length (default: `10`).
    */
   def setMaxPatternLength(maxPatternLength: Int): this.type = {
-    require(maxPatternLength >= 1,
-      "The maximum pattern length value must be greater than 0.")
+    // TODO: support unbounded pattern length when maxPatternLength = 0
+    require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
     this.maxPatternLength = maxPatternLength
     this
   }
@@ -78,81 +97,153 @@ class PrefixSpan private (
    *         the value of pair is the pattern's count.
    */
   def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+    val sc = sequences.sparkContext
+
     if (sequences.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
-    val minCount = getMinCount(sequences)
-    val lengthOnePatternsAndCounts =
-      getFreqItemAndCounts(minCount, sequences).collect()
-    val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
-      lengthOnePatternsAndCounts.map(_._1), sequences)
-    val groupedProjectedDatabase = prefixAndProjectedDatabase
-      .map(x => (x._1.toSeq, x._2))
-      .groupByKey()
-      .map(x => (x._1.toArray, x._2.toArray))
-    val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
-    val lengthOnePatternsAndCountsRdd =
-      sequences.sparkContext.parallelize(
-        lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
-    val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
-    allPatterns
+
+    // Convert min support to a min number of transactions for this dataset
+    val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+
+    // Frequent items -> number of occurrences; all items here satisfy the `minSupport` threshold
+    val freqItemCounts = sequences
+      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+      .reduceByKey(_ + _)
+      .filter(_._2 >= minCount)
+      .collect()
+
+    // Pairs of (length 1 prefix, suffix consisting of frequent items)
+    val itemSuffixPairs = {
+      val freqItems = freqItemCounts.map(_._1).toSet
+      sequences.flatMap { seq =>
+        val filteredSeq = seq.filter(freqItems.contains(_))
+        freqItems.flatMap { item =>
+          val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
+          candidateSuffix match {
+            case suffix if !suffix.isEmpty => Some((List(item), suffix))
+            case _ => None
+          }
+        }
+      }
+    }
+
+    // Accumulator for the computed results to be returned, initialized to the frequent items
+    // (i.e. frequent length-one prefixes)
+    var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+
+    // Remaining work to be locally and distributively processed, respectively
+    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+
+    // Continue processing until no pairs for distributed processing remain (i.e. all prefixes
+    // have projected database sizes <= `maxLocalProjDBSize`)
+    while (pairsForDistributed.count() != 0) {
+      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
+        extendPrefixes(minCount, pairsForDistributed)
+      pairsForDistributed.unpersist()
+      val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
+      pairsForDistributed = largerPairsPart
+      pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
+      pairsForLocal ++= smallerPairsPart
+      resultsAccumulator ++= nextPatternAndCounts.collect()
+    }
+
+    // Process the small projected databases locally
+    val remainingResults = getPatternsInLocal(
+      minCount, sc.parallelize(pairsForLocal, 1).groupByKey())
+
+    (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
+      .map { case (pattern, count) => (pattern.toArray, count) }
   }

   /**
-   * Get the minimum count (sequences count * minSupport).
-   * @param sequences input data set, contains a set of sequences,
-   * @return minimum count,
+   * Partitions the prefix-suffix pairs by projected database size.
+   * @param prefixSuffixPairs prefix (length n) and suffix pairs
+   * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
+   *         greater than [[maxLocalProjDBSize]]
    */
-  private def getMinCount(sequences: RDD[Array[Int]]): Long = {
-    if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+    val prefixToSuffixSize = prefixSuffixPairs
+      .aggregateByKey(0)(
+        seqOp = { case (count, suffix) => count + suffix.length },
+        combOp = { _ + _ })
+    val smallPrefixes = prefixToSuffixSize
+      .filter(_._2 <= maxLocalProjDBSize)
+      .keys
+      .collect()
+      .toSet
+    val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
+    val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
+    (small.collect(), large)
   }

   /**
-   * Generates frequent items by filtering the input data using minimal count level.
-   * @param minCount the absolute minimum count
-   * @param sequences original sequences data
-   * @return array of item and count pair
+   * Extends all prefixes by one item from their suffix and computes the resulting frequent
+   * prefixes and remaining work.
+   * @param minCount minimum count
+   * @param prefixSuffixPairs prefix (length N) and suffix pairs
+   * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1
+   *         extended prefix, corresponding suffix) pairs.
    */
-  private def getFreqItemAndCounts(
+  private def extendPrefixes(
       minCount: Long,
-      sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
-    sequences.flatMap(_.distinct.map((_, 1L)))
+      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+
+    // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
+    // Every extension (item :: prefix) kept here satisfies the `minSupport` threshold
+    val prefixItemPairAndCounts = prefixSuffixPairs
+      .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
-  }

-  /**
-   * Get the frequent prefixes' projected database.
-   * @param frequentPrefixes frequent prefixes
-   * @param sequences sequences data
-   * @return prefixes and projected database
-   */
-  private def getPrefixAndProjectedDatabase(
-      frequentPrefixes: Array[Int],
-      sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
-    val filteredSequences = sequences.map { p =>
-      p.filter(frequentPrefixes.contains(_))
-    }
-    filteredSequences.flatMap { x =>
-      frequentPrefixes.map { y =>
-        val sub = LocalPrefixSpan.getSuffix(y, x)
-        (Array(y), sub)
-      }.filter(_._2.nonEmpty)
-    }
+    // Map from prefix to set of possible next items from suffix
+    val prefixToNextItems = prefixItemPairAndCounts
+      .keys
+      .groupByKey()
+      .mapValues(_.toSet)
+      .collect()
+      .toMap
+
+    // Frequent patterns with length N+1 and their corresponding counts
+    val extendedPrefixAndCounts = prefixItemPairAndCounts
+      .map { case ((prefix, item), count) => (item :: prefix, count) }
+
+    // Remaining work, all prefixes will have length N+1
+    val extendedPrefixAndSuffix = prefixSuffixPairs
+      .filter(x => prefixToNextItems.contains(x._1))
+      .flatMap { case (prefix, suffix) =>
+        val frequentNextItems = prefixToNextItems(prefix)
+        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
+        frequentNextItems.flatMap { item =>
+          LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
+            case suffix if !suffix.isEmpty => Some((item :: prefix, suffix))
+            case _ => None
+          }
+        }
+      }
+
+    (extendedPrefixAndCounts, extendedPrefixAndSuffix)
   }

   /**
-   * calculate the patterns in local.
+   * Calculate the patterns in local.
    * @param minCount the absolute minimum count
-   * @param data patterns and projected sequences data data
+   * @param data prefixes and projected sequences data
    * @return patterns
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
-    data.flatMap { case (prefix, projDB) =>
-      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
-        .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
+      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
+    data.flatMap {
+      case (prefix, projDB) =>
+        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
+          .map { case (pattern: List[Int], count: Long) =>
+            (pattern.reverse, count)
+          }
     }
   }
 }
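The heart of the change is the driver loop in `run()`: after each distributed extension round, `partitionByProjDBSize` sends prefixes whose projected databases total at most `maxLocalProjDBSize` items to local processing and keeps the rest distributed. A local-collection sketch of that split (it mirrors `partitionByProjDBSize` above minus the RDDs; the data and threshold are illustrative):

val maxLocalProjDBSize = 10000L

def partitionBySize(pairs: Seq[(List[Int], Array[Int])])
  : (Seq[(List[Int], Array[Int])], Seq[(List[Int], Array[Int])]) = {
  // Total item count of each prefix's projected database
  val projDBSizes: Map[List[Int], Long] = pairs
    .groupBy(_._1)
    .map { case (prefix, ps) => (prefix, ps.map(_._2.length.toLong).sum) }
  // Prefixes whose projected databases are small enough to finish locally
  val smallPrefixes = projDBSizes.filter(_._2 <= maxLocalProjDBSize).keySet
  pairs.partition { case (prefix, _) => smallPrefixes.contains(prefix) }
}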

mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala

Lines changed: 11 additions & 10 deletions

@@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {

     val rdd = sc.parallelize(sequences, 2).cache()

-    def compareResult(
-        expectedValue: Array[(Array[Int], Long)],
-        actualValue: Array[(Array[Int], Long)]): Boolean = {
-      expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
-        actualValue.map(x => (x._1.toSeq, x._2)).toSet
-    }
-
     val prefixspan = new PrefixSpan()
       .setMinSupport(0.33)
       .setMaxPatternLength(50)
@@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue1, result1.collect()))
+    assert(compareResults(expectedValue1, result1.collect()))

     prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
     val result2 = prefixspan.run(rdd)
@@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4), 4L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue2, result2.collect()))
+    assert(compareResults(expectedValue2, result2.collect()))

     prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
     val result3 = prefixspan.run(rdd)
@@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue3, result3.collect()))
+    assert(compareResults(expectedValue3, result3.collect()))
+  }
+
+  private def compareResults(
+      expectedValue: Array[(Array[Int], Long)],
+      actualValue: Array[(Array[Int], Long)]): Boolean = {
+    expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
+      actualValue.map(x => (x._1.toSeq, x._2)).toSet
   }
+
 }
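A note on why `compareResults` maps each `Array[Int]` to a `Seq` before building sets: Scala arrays are Java arrays and compare by reference, so set membership would fail even for equal contents. A quick illustration (REPL-style, my own example):

Array(1, 2) == Array(1, 2)              // false: arrays compare by reference
Array(1, 2).toSeq == Array(1, 2).toSeq  // true: Seqs compare element-wise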
