Skip to content

Commit 2e00cba

Browse files
author
Feynman Liang
committed
Depth first projections
1 parent 70b93e3 commit 2e00cba

File tree

2 files changed

+35
-46
lines changed

2 files changed

+35
-46
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

Lines changed: 34 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,6 @@ package org.apache.spark.mllib.fpm
2020
import org.apache.spark.Logging
2121
import org.apache.spark.annotation.Experimental
2222

23-
import scala.collection.mutable.ArrayBuffer
24-
2523
/**
2624
*
2725
* :: Experimental ::
@@ -36,80 +34,71 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
3634
* @param minCount minimum count
3735
* @param maxPatternLength maximum pattern length
3836
* @param prefix prefix
39-
* @param projectedDatabase the projected database
37+
* @param database the projected database
4038
* @return a set of sequential pattern pairs,
4139
* the key of pair is sequential pattern (a list of items),
4240
* the value of pair is the pattern's count.
4341
*/
4442
def run(
4543
minCount: Long,
4644
maxPatternLength: Int,
47-
prefix: ArrayBuffer[Int],
48-
projectedDatabase: Array[Array[Int]]): Iterator[(Array[Int], Long)] = {
49-
val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
50-
val frequentPatternAndCounts = frequentPrefixAndCounts
51-
.map(x => ((prefix :+ x._1).toArray, x._2))
52-
val prefixProjectedDatabases = getPatternAndProjectedDatabase(
53-
prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
45+
prefix: List[Int],
46+
database: Iterable[Array[Int]]): Iterator[(Array[Int], Long)] = {
47+
48+
if (database.isEmpty) return Iterator.empty
49+
50+
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
51+
val frequentItems = frequentItemAndCounts.map(_._1)
52+
val frequentPatternAndCounts = frequentItemAndCounts
53+
.map { case (item, count) => ((item :: prefix).reverse.toArray, count) }
5454

55-
if (prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength) {
56-
frequentPatternAndCounts.iterator ++ prefixProjectedDatabases.flatMap {
57-
case (nextPrefix, projDB) => run(minCount, maxPatternLength, nextPrefix, projDB)
55+
val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_)))
56+
57+
if (prefix.length + 1 < maxPatternLength) {
58+
frequentPatternAndCounts ++ frequentItems.flatMap { item =>
59+
val nextProjected = project(filteredProjectedDatabase, item)
60+
run(minCount, maxPatternLength, item :: prefix, nextProjected)
5861
}
5962
} else {
60-
frequentPatternAndCounts.iterator
63+
frequentPatternAndCounts
6164
}
6265
}
6366

6467
/**
65-
* calculate suffix sequence following a prefix in a sequence
66-
* @param prefix prefix
67-
* @param sequence sequence
68+
* Calculate suffix sequence immediately after the first occurrence of an item.
69+
* @param item item to get suffix after
70+
* @param sequence sequence to extract suffix from
6871
* @return suffix sequence
6972
*/
70-
def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
71-
val index = sequence.indexOf(prefix)
73+
def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
74+
val index = sequence.indexOf(item)
7275
if (index == -1) {
7376
Array()
7477
} else {
7578
sequence.drop(index + 1)
7679
}
7780
}
7881

82+
def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
83+
database
84+
.map(candidateSeq => getSuffix(prefix, candidateSeq))
85+
.filter(_.nonEmpty)
86+
}
87+
7988
/**
8089
* Generates frequent items by filtering the input data using minimal count level.
81-
* @param minCount the absolute minimum count
82-
* @param sequences sequences data
83-
* @return array of item and count pair
90+
* @param minCount the minimum count for an item to be frequent
91+
* @param database database of sequences
92+
* @return item and count pairs
8493
*/
8594
private def getFreqItemAndCounts(
8695
minCount: Long,
87-
sequences: Array[Array[Int]]): Array[(Int, Long)] = {
88-
sequences.flatMap(_.distinct)
96+
database: Iterable[Array[Int]]): Iterator[(Int, Long)] = {
97+
database.flatMap(_.distinct)
8998
.foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) =>
9099
ctr + (item -> (ctr(item) + 1))
91100
}
92101
.filter(_._2 >= minCount)
93-
.toArray
94-
}
95-
96-
/**
97-
* Get the frequent prefixes' projected database.
98-
* @param prefix the frequent prefixes' prefix
99-
* @param frequentPrefixes frequent next prefixes
100-
* @param projDB projected database for given prefix
101-
* @return extensions of prefix by one item and corresponding projected databases
102-
*/
103-
private def getPatternAndProjectedDatabase(
104-
prefix: ArrayBuffer[Int],
105-
frequentPrefixes: Array[Int],
106-
projDB: Array[Array[Int]]): Array[(ArrayBuffer[Int], Array[Array[Int]])] = {
107-
val filteredProjectedDatabase = projDB.map(x => x.filter(frequentPrefixes.contains(_)))
108-
frequentPrefixes.map { nextItem =>
109-
val nextProjDB = filteredProjectedDatabase
110-
.map(candidateSeq => getSuffix(nextItem, candidateSeq))
111-
.filter(_.nonEmpty)
112-
(prefix :+ nextItem, nextProjDB)
113-
}.filter(x => x._2.nonEmpty)
102+
.iterator
114103
}
115104
}

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -153,7 +153,7 @@ class PrefixSpan private (
153153
minCount: Long,
154154
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
155155
data.flatMap { case (prefix, projDB) =>
156-
LocalPrefixSpan.run(minCount, maxPatternLength, prefix.to[ArrayBuffer], projDB)
156+
LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
157157
}
158158
}
159159
}

0 commit comments

Comments
 (0)