@@ -45,30 +45,44 @@ class PrefixSpan private (
45
45
private var minSupport : Double ,
46
46
private var maxPatternLength : Int ) extends Logging with Serializable {
47
47
48
- private val maxProjectedDBSizeBeforeLocalProcessing : Long = 10000
48
+ /**
49
+ * The maximum number of items allowed in a projected database before local processing. If a
50
+ * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
51
+ */
52
+ private val maxLocalProjDBSize : Long = 10000
49
53
50
54
/**
51
55
* Constructs a default instance with default parameters
52
56
* {minSupport: `0.1`, maxPatternLength: `10`}.
53
57
*/
54
58
def this () = this (0.1 , 10 )
55
59
60
+ /**
61
+ * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
62
+ * frequent).
63
+ */
64
+ def getMinSupport (): Double = this .minSupport
65
+
56
66
/**
57
67
* Sets the minimal support level (default: `0.1`).
58
68
*/
59
69
def setMinSupport (minSupport : Double ): this .type = {
60
- require(minSupport >= 0 && minSupport <= 1 ,
61
- " The minimum support value must be between 0 and 1, including 0 and 1." )
70
+ require(minSupport >= 0 && minSupport <= 1 , " The minimum support value must be in [0, 1]." )
62
71
this .minSupport = minSupport
63
72
this
64
73
}
65
74
75
+ /**
76
+ * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider.
77
+ */
78
+ def getMaxPatternLength (): Double = this .maxPatternLength
79
+
66
80
/**
67
81
* Sets maximal pattern length (default: `10`).
68
82
*/
69
83
def setMaxPatternLength (maxPatternLength : Int ): this .type = {
70
- require( maxPatternLength >= 1 ,
71
- " The maximum pattern length value must be greater than 0." )
84
+ // TODO: support unbounded pattern length when maxPatternLength = 0
85
+ require(maxPatternLength >= 1 , " The maximum pattern length value must be greater than 0." )
72
86
this .maxPatternLength = maxPatternLength
73
87
this
74
88
}
@@ -85,162 +99,145 @@ class PrefixSpan private (
85
99
if (sequences.getStorageLevel == StorageLevel .NONE ) {
86
100
logWarning(" Input data is not cached." )
87
101
}
88
- val minCount = getMinCount(sequences)
89
- val lengthOnePatternsAndCounts = getFreqItemAndCounts(minCount, sequences)
90
- val prefixSuffixPairs = getPrefixSuffixPairs(
91
- lengthOnePatternsAndCounts.map(_._1).collect(), sequences)
92
- prefixSuffixPairs.persist(StorageLevel .MEMORY_AND_DISK )
93
- var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer (x._1), x._2))
94
- var (smallPrefixSuffixPairs, largePrefixSuffixPairs) =
95
- splitPrefixSuffixPairs(prefixSuffixPairs)
96
- while (largePrefixSuffixPairs.count() != 0 ) {
97
- val (nextPatternAndCounts, nextPrefixSuffixPairs) =
98
- getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs)
99
- largePrefixSuffixPairs.unpersist()
100
- val (smallerPairsPart, largerPairsPart) = splitPrefixSuffixPairs(nextPrefixSuffixPairs)
101
- largePrefixSuffixPairs = largerPairsPart
102
- largePrefixSuffixPairs.persist(StorageLevel .MEMORY_AND_DISK )
103
- smallPrefixSuffixPairs ++= smallerPairsPart
104
- allPatternAndCounts ++= nextPatternAndCounts
102
+
103
+ // Convert min support to a min number of transactions for this dataset
104
+ val minCount = if (minSupport == 0 ) 0L else math.ceil(sequences.count() * minSupport).toLong
105
+
106
+ // (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold
107
+ val freqItemCounts = sequences
108
+ .flatMap(seq => seq.distinct.map(item => (item, 1L )))
109
+ .reduceByKey(_ + _)
110
+ .filter(_._2 >= minCount)
111
+
112
+ // Pairs of (length 1 prefix, suffix consisting of frequent items)
113
+ val itemSuffixPairs = {
114
+ val freqItems = freqItemCounts.keys.collect().toSet
115
+ sequences.flatMap { seq =>
116
+ val filteredSeq = seq.filter(freqItems.contains(_))
117
+ freqItems.flatMap { item =>
118
+ val candidateSuffix = LocalPrefixSpan .getSuffix(item, filteredSeq)
119
+ candidateSuffix match {
120
+ case suffix if ! suffix.isEmpty => Some ((List (item), suffix))
121
+ case _ => None
122
+ }
123
+ }
124
+ }
105
125
}
106
- if (smallPrefixSuffixPairs.count() > 0 ) {
107
- val projectedDatabase = smallPrefixSuffixPairs
108
- .map(x => (x._1.toSeq, x._2))
109
- .groupByKey()
110
- .map(x => (x._1.toArray, x._2.toArray))
111
- val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
112
- allPatternAndCounts ++= nextPatternAndCounts
126
+
127
+ // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
128
+ // frequent length-one prefixes)
129
+ var resultsAccumulator = freqItemCounts.map(x => (List (x._1), x._2))
130
+
131
+ // Remaining work to be locally and distributively processed respectfully
132
+ var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
133
+
134
+ // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
135
+ // projected database sizes <= `maxLocalProjDBSize`)
136
+ while (pairsForDistributed.count() != 0 ) {
137
+ val (nextPatternAndCounts, nextPrefixSuffixPairs) =
138
+ extendPrefixes(minCount, pairsForDistributed)
139
+ pairsForDistributed.unpersist()
140
+ val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
141
+ pairsForDistributed = largerPairsPart
142
+ pairsForDistributed.persist(StorageLevel .MEMORY_AND_DISK )
143
+ pairsForLocal ++= smallerPairsPart
144
+ resultsAccumulator ++= nextPatternAndCounts
113
145
}
114
- allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) }
146
+
147
+ // Process the small projected databases locally
148
+ resultsAccumulator ++= getPatternsInLocal(minCount, pairsForLocal.groupByKey())
149
+
150
+ resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) }
115
151
}
116
152
117
153
118
154
/**
119
- * Split prefix suffix pairs to two parts:
120
- * Prefixes with projected databases smaller than maxSuffixesBeforeLocalProcessing and
121
- * Prefixes with projected databases larger than maxSuffixesBeforeLocalProcessing
155
+ * Partitions the prefix-suffix pairs by projected database size.
122
156
* @param prefixSuffixPairs prefix (length n) and suffix pairs,
123
- * @return small size prefix suffix pairs and big size prefix suffix pairs
124
- * (RDD[prefix, suffix], RDD[prefix, suffix ])
157
+ * @return prefix- suffix pairs partitioned by whether their projected database size is <= or
158
+ * greater than [[ maxLocalProjDBSize ]]
125
159
*/
126
- private def splitPrefixSuffixPairs (
127
- prefixSuffixPairs : RDD [(ArrayBuffer [Int ], Array [Int ])]):
128
- (RDD [(ArrayBuffer [Int ], Array [Int ])], RDD [(ArrayBuffer [Int ], Array [Int ])]) = {
129
- val suffixSizeMap = prefixSuffixPairs
130
- .map(x => (x._1, x._2.length))
131
- .reduceByKey(_ + _)
132
- .map(x => (x._2 <= maxProjectedDBSizeBeforeLocalProcessing, Set (x._1)))
133
- .reduceByKey(_ ++ _)
134
- .collect
135
- .toMap
136
- val small = if (suffixSizeMap.contains(true )) {
137
- prefixSuffixPairs.filter(x => suffixSizeMap(true ).contains(x._1))
138
- } else {
139
- prefixSuffixPairs.filter(x => false )
140
- }
141
- val large = if (suffixSizeMap.contains(false )) {
142
- prefixSuffixPairs.filter(x => suffixSizeMap(false ).contains(x._1))
143
- } else {
144
- prefixSuffixPairs.filter(x => false )
145
- }
160
+ private def partitionByProjDBSize (prefixSuffixPairs : RDD [(List [Int ], Array [Int ])])
161
+ : (RDD [(List [Int ], Array [Int ])], RDD [(List [Int ], Array [Int ])]) = {
162
+ val prefixToSuffixSize = prefixSuffixPairs
163
+ .aggregateByKey(0 )(
164
+ seqOp = { case (count, suffix) => count + suffix.length },
165
+ combOp = { _ + _ })
166
+ val smallPrefixes = prefixToSuffixSize
167
+ .filter(_._2 <= maxLocalProjDBSize)
168
+ .keys
169
+ .collect()
170
+ .toSet
171
+ val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
172
+ val large = prefixSuffixPairs.filter { case (prefix, _) => ! smallPrefixes.contains(prefix) }
146
173
(small, large)
147
174
}
148
175
149
176
/**
150
- * Get the pattern and counts, and prefix suffix pairs
177
+ * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
178
+ * and remaining work.
151
179
* @param minCount minimum count
152
- * @param prefixSuffixPairs prefix (length n ) and suffix pairs,
153
- * @return pattern ( length n+1) and counts, and prefix ( length n+1) and suffix pairs
154
- * (RDD[pattern, count], RDD[ prefix, suffix ])
180
+ * @param prefixSuffixPairs prefix (length N ) and suffix pairs,
181
+ * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
182
+ * prefix, corresponding suffix) pairs.
155
183
*/
156
- private def getPatternCountsAndPrefixSuffixPairs (
184
+ private def extendPrefixes (
157
185
minCount : Long ,
158
- prefixSuffixPairs : RDD [(ArrayBuffer [Int ], Array [Int ])]):
159
- (RDD [(ArrayBuffer [Int ], Long )], RDD [(ArrayBuffer [Int ], Array [Int ])]) = {
160
- val prefixAndFrequentItemAndCounts = prefixSuffixPairs
186
+ prefixSuffixPairs : RDD [(List [Int ], Array [Int ])])
187
+ : (RDD [(List [Int ], Long )], RDD [(List [Int ], Array [Int ])]) = {
188
+
189
+ // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
190
+ // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
191
+ val prefixItemPairAndCounts = prefixSuffixPairs
161
192
.flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L )) }
162
193
.reduceByKey(_ + _)
163
194
.filter(_._2 >= minCount)
164
- val patternAndCounts = prefixAndFrequentItemAndCounts
165
- .map { case (( prefix, item), count) => (prefix :+ item, count) }
166
- val prefixToFrequentNextItemsMap = prefixAndFrequentItemAndCounts
195
+
196
+ // Map from prefix to set of possible next items from suffix
197
+ val prefixToNextItems = prefixItemPairAndCounts
167
198
.keys
168
199
.groupByKey()
169
200
.mapValues(_.toSet)
170
201
.collect()
171
202
.toMap
172
- val nextPrefixSuffixPairs = prefixSuffixPairs
173
- .filter(x => prefixToFrequentNextItemsMap.contains(x._1))
174
- .flatMap { case (prefix, suffix) =>
175
- val frequentNextItems = prefixToFrequentNextItemsMap(prefix)
176
- val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
177
- frequentNextItems.flatMap { item =>
178
- val suffix = LocalPrefixSpan .getSuffix(item, filteredSuffix)
179
- if (suffix.isEmpty) None
180
- else Some (prefix :+ item, suffix)
181
- }
182
- }
183
- (patternAndCounts, nextPrefixSuffixPairs)
184
- }
185
203
186
- /**
187
- * Get the minimum count (sequences count * minSupport).
188
- * @param sequences input data set, contains a set of sequences,
189
- * @return minimum count,
190
- */
191
- private def getMinCount (sequences : RDD [Array [Int ]]): Long = {
192
- if (minSupport == 0 ) 0L else math.ceil(sequences.count() * minSupport).toLong
193
- }
194
204
195
- /**
196
- * Generates frequent items by filtering the input data using minimal count level.
197
- * @param minCount the absolute minimum count
198
- * @param sequences original sequences data
199
- * @return array of item and count pair
200
- */
201
- private def getFreqItemAndCounts (
202
- minCount : Long ,
203
- sequences : RDD [Array [Int ]]): RDD [(Int , Long )] = {
204
- sequences.flatMap(_.distinct.map((_, 1L )))
205
- .reduceByKey(_ + _)
206
- .filter(_._2 >= minCount)
207
- }
205
+ // Frequent patterns with length N+1 and their corresponding counts
206
+ val extendedPrefixAndCounts = prefixItemPairAndCounts
207
+ .map { case ((prefix, item), count) => (item :: prefix, count) }
208
208
209
- /**
210
- * Get the frequent prefixes and suffix pairs.
211
- * @param frequentPrefixes frequent prefixes
212
- * @param sequences sequences data
213
- * @return prefixes and suffix pairs.
214
- */
215
- private def getPrefixSuffixPairs (
216
- frequentPrefixes : Array [Int ],
217
- sequences : RDD [Array [Int ]]): RDD [(ArrayBuffer [Int ], Array [Int ])] = {
218
- val filteredSequences = sequences.map { p =>
219
- p.filter (frequentPrefixes.contains(_) )
220
- }
221
- filteredSequences.flatMap { x =>
222
- frequentPrefixes.map { y =>
223
- val sub = LocalPrefixSpan .getSuffix(y, x)
224
- (ArrayBuffer (y), sub)
225
- }.filter(_._2.nonEmpty)
226
- }
209
+ // Remaining work, all prefixes will have length N+1
210
+ val extendedPrefixAndSuffix = prefixSuffixPairs
211
+ .filter(x => prefixToNextItems.contains(x._1))
212
+ .flatMap { case (prefix, suffix) =>
213
+ val frequentNextItems = prefixToNextItems(prefix)
214
+ val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
215
+ frequentNextItems.flatMap { item =>
216
+ LocalPrefixSpan .getSuffix(item, filteredSuffix) match {
217
+ case suffix if ! suffix.isEmpty => Some (item :: prefix, suffix)
218
+ case _ => None
219
+ }
220
+ }
221
+ }
222
+
223
+ (extendedPrefixAndCounts, extendedPrefixAndSuffix)
227
224
}
228
225
229
226
/**
230
- * calculate the patterns in local.
227
+ * Calculate the patterns in local.
231
228
* @param minCount the absolute minimum count
232
229
* @param data prefixes and projected sequences data data
233
230
* @return patterns
234
231
*/
235
232
private def getPatternsInLocal (
236
233
minCount : Long ,
237
- data : RDD [(Array [Int ], Array [Array [Int ]])]): RDD [(ArrayBuffer [Int ], Long )] = {
234
+ data : RDD [(List [Int ], Iterable [Array [Int ]])]): RDD [(List [Int ], Long )] = {
238
235
data.flatMap {
239
- case (prefix, projDB) =>
240
- LocalPrefixSpan .run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
241
- .map { case (pattern : List [Int ], count : Long ) =>
242
- (pattern.toArray. reverse.to[ ArrayBuffer ] , count)
243
- }
236
+ case (prefix, projDB) =>
237
+ LocalPrefixSpan .run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
238
+ .map { case (pattern : List [Int ], count : Long ) =>
239
+ (pattern.reverse, count)
240
+ }
244
241
}
245
242
}
246
243
}
0 commit comments