package org.apache.spark.mllib.fpm

+ import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
@@ -43,28 +45,45 @@ class PrefixSpan private (
    private var minSupport: Double,
    private var maxPatternLength: Int) extends Logging with Serializable {

+   /**
+    * The maximum number of items allowed in a projected database before local processing. If a
+    * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
+    */
+   // TODO: make configurable with a better default value, 10000 may be too small
+   private val maxLocalProjDBSize: Long = 10000
+
  /**
   * Constructs a default instance with default parameters
   * {minSupport: `0.1`, maxPatternLength: `10`}.
   */
  def this() = this(0.1, 10)

+   /**
+    * Gets the minimal support (i.e. the minimum frequency of occurrence before a pattern is
+    * considered frequent).
+    */
+   def getMinSupport: Double = this.minSupport
+
  /**
   * Sets the minimal support level (default: `0.1`).
   */
  def setMinSupport(minSupport: Double): this.type = {
-     require(minSupport >= 0 && minSupport <= 1,
-       "The minimum support value must be between 0 and 1, including 0 and 1.")
+     require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
    this.minSupport = minSupport
    this
  }

+   /**
+    * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to
+    * consider).
+    */
+   def getMaxPatternLength: Int = this.maxPatternLength
+
  /**
   * Sets maximal pattern length (default: `10`).
   */
  def setMaxPatternLength(maxPatternLength: Int): this.type = {
-     require(maxPatternLength >= 1,
-       "The maximum pattern length value must be greater than 0.")
+     // TODO: support unbounded pattern length when maxPatternLength = 0
+     require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
    this.maxPatternLength = maxPatternLength
    this
  }
@@ -78,81 +97,153 @@ class PrefixSpan private (
   * the value of pair is the pattern's count.
   */
  def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+     val sc = sequences.sparkContext
+
    if (sequences.getStorageLevel == StorageLevel.NONE) {
      logWarning("Input data is not cached.")
    }
-     val minCount = getMinCount(sequences)
-     val lengthOnePatternsAndCounts =
-       getFreqItemAndCounts(minCount, sequences).collect()
-     val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
-       lengthOnePatternsAndCounts.map(_._1), sequences)
-     val groupedProjectedDatabase = prefixAndProjectedDatabase
-       .map(x => (x._1.toSeq, x._2))
-       .groupByKey()
-       .map(x => (x._1.toArray, x._2.toArray))
-     val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
-     val lengthOnePatternsAndCountsRdd =
-       sequences.sparkContext.parallelize(
-         lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
-     val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
-     allPatterns
+
+     // Convert min support to a min number of transactions for this dataset
+     val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+
+     // (Frequent item -> number of occurrences) pairs; all items here satisfy the `minSupport`
+     // threshold
+     val freqItemCounts = sequences
+       .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+       .reduceByKey(_ + _)
+       .filter(_._2 >= minCount)
+       .collect()
+
+     // Pairs of (length 1 prefix, suffix consisting of frequent items)
+     val itemSuffixPairs = {
+       val freqItems = freqItemCounts.map(_._1).toSet
+       sequences.flatMap { seq =>
+         val filteredSeq = seq.filter(freqItems.contains(_))
+         freqItems.flatMap { item =>
+           val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
+           candidateSuffix match {
+             case suffix if suffix.nonEmpty => Some((List(item), suffix))
+             case _ => None
+           }
+         }
+       }
+     }
+
+     // Accumulator for the computed results to be returned, initialized to the frequent items
+     // (i.e. frequent length-one prefixes)
+     var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+
+     // Remaining work to be processed locally and distributed, respectively
+     var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+
+     // Continue processing until no pairs for distributed processing remain (i.e. all prefixes
+     // have projected database sizes <= `maxLocalProjDBSize`)
+     while (pairsForDistributed.count() != 0) {
+       val (nextPatternAndCounts, nextPrefixSuffixPairs) =
+         extendPrefixes(minCount, pairsForDistributed)
+       pairsForDistributed.unpersist()
+       val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
+       pairsForDistributed = largerPairsPart
+       pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
+       pairsForLocal ++= smallerPairsPart
+       resultsAccumulator ++= nextPatternAndCounts.collect()
+     }
+
+     // Process the small projected databases locally
+     val remainingResults = getPatternsInLocal(
+       minCount, sc.parallelize(pairsForLocal, 1).groupByKey())
+
+     (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
+       .map { case (pattern, count) => (pattern.toArray, count) }
  }

+
  /**
-    * Get the minimum count (sequences count * minSupport).
-    * @param sequences input data set, contains a set of sequences,
-    * @return minimum count,
+    * Partitions the prefix-suffix pairs by projected database size.
+    * @param prefixSuffixPairs prefix (length n) and suffix pairs
+    * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
+    *         greater than [[maxLocalProjDBSize]]
   */
-   private def getMinCount(sequences: RDD[Array[Int]]): Long = {
-     if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+   private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+     : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+     val prefixToSuffixSize = prefixSuffixPairs
+       .aggregateByKey(0)(
+         seqOp = { case (count, suffix) => count + suffix.length },
+         combOp = { _ + _ })
+     val smallPrefixes = prefixToSuffixSize
+       .filter(_._2 <= maxLocalProjDBSize)
+       .keys
+       .collect()
+       .toSet
+     val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
+     val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
+     (small.collect(), large)
  }

  /**
-    * Generates frequent items by filtering the input data using minimal count level.
-    * @param minCount the absolute minimum count
-    * @param sequences original sequences data
-    * @return array of item and count pair
+    * Extends all prefixes by one item from their suffix and computes the resulting frequent
+    * prefixes and remaining work.
+    * @param minCount minimum count
+    * @param prefixSuffixPairs prefix (length N) and suffix pairs
+    * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1
+    *         extended prefix, corresponding suffix) pairs
   */
-   private def getFreqItemAndCounts(
+   private def extendPrefixes(
      minCount: Long,
-       sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
-     sequences.flatMap(_.distinct.map((_, 1L)))
+       prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+     : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+
+     // (length N prefix, item from suffix) pairs and their corresponding number of occurrences;
+     // every (prefix, item) pair kept after the filter meets the `minCount` threshold
+     val prefixItemPairAndCounts = prefixSuffixPairs
+       .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)
-   }

-   /**
-    * Get the frequent prefixes' projected database.
-    * @param frequentPrefixes frequent prefixes
-    * @param sequences sequences data
-    * @return prefixes and projected database
-    */
-   private def getPrefixAndProjectedDatabase(
-       frequentPrefixes: Array[Int],
-       sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
-     val filteredSequences = sequences.map { p =>
-       p.filter(frequentPrefixes.contains(_))
-     }
-     filteredSequences.flatMap { x =>
-       frequentPrefixes.map { y =>
-         val sub = LocalPrefixSpan.getSuffix(y, x)
-         (Array(y), sub)
-       }.filter(_._2.nonEmpty)
-     }
+
+     // Map from prefix to set of possible next items from suffix
+     val prefixToNextItems = prefixItemPairAndCounts
+       .keys
+       .groupByKey()
+       .mapValues(_.toSet)
+       .collect()
+       .toMap
+
+     // Frequent patterns with length N+1 and their corresponding counts
+     val extendedPrefixAndCounts = prefixItemPairAndCounts
+       .map { case ((prefix, item), count) => (item :: prefix, count) }
+
+     // Remaining work; all prefixes will have length N+1
+     val extendedPrefixAndSuffix = prefixSuffixPairs
+       .filter(x => prefixToNextItems.contains(x._1))
+       .flatMap { case (prefix, suffix) =>
+         val frequentNextItems = prefixToNextItems(prefix)
+         val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
+         frequentNextItems.flatMap { item =>
+           LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
+             case suffix if suffix.nonEmpty => Some((item :: prefix, suffix))
+             case _ => None
+           }
+         }
+       }
+
+     (extendedPrefixAndCounts, extendedPrefixAndSuffix)
  }

  /**
-    * calculate the patterns in local.
+    * Calculates the patterns locally.
   * @param minCount the absolute minimum count
-    * @param data patterns and projected sequences data data
+    * @param data prefixes and projected sequences data
   * @return patterns
   */
  private def getPatternsInLocal(
      minCount: Long,
-       data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
-     data.flatMap { case (prefix, projDB) =>
-       LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
-         .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
+       data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
+     data.flatMap { case (prefix, projDB) =>
+       LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
+         .map { case (pattern: List[Int], count: Long) => (pattern.reverse, count) }
    }
  }
}
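
A minimal usage sketch of the API this patch settles on, for review context. The data below is illustrative: `sc` is assumed to be an existing SparkContext, and each sequence is an Array[Int] of item ids, matching run()'s input type.

import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.rdd.RDD

// Toy database of four sequences, cached because run() warns on uncached input.
val sequences: RDD[Array[Int]] = sc.parallelize(Seq(
  Array(1, 3, 4, 5),
  Array(2, 3, 1),
  Array(1, 3, 4, 5, 5),
  Array(6, 5, 3))).cache()

// With minSupport = 0.5 and 4 sequences, minCount = ceil(4 * 0.5) = 2, so a
// pattern must occur in at least two sequences to be reported.
val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)

// Each result pair is (sequential pattern, number of sequences supporting it).
val patterns: RDD[(Array[Int], Long)] = prefixSpan.run(sequences)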
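
Both the distributed loop and the local phase project suffixes via LocalPrefixSpan.getSuffix, which this diff does not show. A rough stand-in for the projection it is assumed to perform, for intuition only (not the actual implementation): the suffix of a sequence with respect to a single item is everything after that item's first occurrence, and an empty suffix means the extended prefix gains no support from that sequence.

// Hypothetical sketch of the assumed suffix projection.
def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
  val i = sequence.indexOf(item)
  if (i == -1) Array.empty[Int] else sequence.drop(i + 1)
}

getSuffix(3, Array(1, 3, 4, 5)) // Array(4, 5): prefix [3] can extend to [3, 4] or [3, 5]
getSuffix(2, Array(1, 3, 4, 5)) // Array():     prefix [2] cannot be extended here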