
Commit eb3e4ca

Author: Jacky Li
Commit message: add FPGrowth
1 parent 03df2b6 commit eb3e4ca

File tree

3 files changed: +304 -0 lines
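This commit introduces a new FPGrowth API under org.apache.spark.mllib.fpm. A minimal usage sketch of the API added below; the SparkContext `sc` and the sample baskets are illustrative assumptions, not part of the commit:

import org.apache.spark.mllib.fpm.FPGrowth

// Each basket is an array of item strings; `sc` is an existing SparkContext.
val baskets = sc.parallelize(Seq(
  Array("r", "z", "h", "k", "p"),
  Array("z", "y", "x", "w", "v", "u", "t", "s"),
  Array("z")))

// Mine frequent patterns with a minimal support level of 0.5.
val model = FPGrowth.train(baskets, minSupport = 0.5)

// Each element of frequentPattern is a (pattern, count) pair.
model.frequentPattern.foreach { case (pattern, count) =>
  println(s"$pattern: $count")
}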
Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

import scala.collection.mutable.{ArrayBuffer, Map}

import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.broadcast._
import org.apache.spark.rdd.RDD

/**
 * This class implements the Parallel FP-Growth algorithm to do frequent pattern mining on
 * input data. Parallel FP-Growth (PFP) partitions computation in such a way that each machine
 * executes an independent group of mining tasks. More details on this algorithm can be found at
 * http://infolab.stanford.edu/~echang/recsys08-69.pdf
 */
class FPGrowth private(private var minSupport: Double) extends Logging with Serializable {

  /**
   * Constructs an FPGrowth instance with default parameters:
   * {minSupport: 0.5}
   */
  def this() = this(0.5)

  /**
   * Set the minimal support level, default is 0.5.
   * @param minSupport minimal support level
   */
  def setMinSupport(minSupport: Double): this.type = {
    this.minSupport = minSupport
    this
  }

  /**
   * Compute an FPGrowthModel that contains the frequent pattern result.
   * @param data input data set
   * @return an FPGrowthModel
   */
  def run(data: RDD[Array[String]]): FPGrowthModel = {
    runAlgorithm(data)
  }

  /**
   * Implementation of PFP.
   */
  private def runAlgorithm(data: RDD[Array[String]]): FPGrowthModel = {
    val count = data.count()
    val minCount = minSupport * count
    val single = generateSingleItem(data, minCount)
    val combinations = generateCombinations(data, minCount, single)
    new FPGrowthModel(single ++ combinations)
  }
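
  // Note: the two helpers below make two distributed passes: generateSingleItem
  // counts individual items in parallel, and generateCombinations mines longer
  // patterns from per-item FP-Tree groups.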

  /**
   * Generate single-item patterns by filtering the input data using the minimal support level.
   */
  private def generateSingleItem(
      data: RDD[Array[String]],
      minCount: Double): Array[(String, Int)] = {
    data.flatMap(v => v)
      .map(v => (v, 1))
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)
      .collect()
      .distinct
      .sortWith(_._2 > _._2)
  }
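
  // For example, on the six baskets in FPGrowthSuite with minSupport = 0.5
  // (minCount = 3), the surviving single items sorted by descending count are
  // z:5 and x:4, followed by y, s, t and r with count 3 each (ties in
  // arbitrary order).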

  /**
   * Generate combinations of items by computing on the FP-Tree;
   * the computation is done on each FP-Tree partition.
   */
  private def generateCombinations(
      data: RDD[Array[String]],
      minCount: Double,
      singleItem: Array[(String, Int)]): Array[(String, Int)] = {
    val single = data.context.broadcast(singleItem)
    data.flatMap(basket => createFPTree(basket, single))
      .groupByKey()
      .flatMap(partition => runFPTree(partition, minCount))
      .collect()
  }

  /**
   * Create the FP-Tree partitions for the given basket.
   */
  private def createFPTree(
      basket: Array[String],
      singleItem: Broadcast[Array[(String, Int)]]): Array[(String, Array[String])] = {
    val output = ArrayBuffer[(String, Array[String])]()
    val combination = ArrayBuffer[String]()
    val single = singleItem.value
    val items = ArrayBuffer[(String, Int)]()

    // Filter the basket by the single-item patterns.
    val iterator = basket.iterator
    while (iterator.hasNext) {
      val item = iterator.next()
      val opt = single.find(_._1.equals(item))
      if (opt.isDefined) {
        items ++= opt
      }
    }

    // Sort the retained items and create the item combinations.
    val sortedItems = items.sortWith(_._1 > _._1).sortWith(_._2 > _._2).toArray
    val itemIterator = sortedItems.iterator
    while (itemIterator.hasNext) {
      combination.clear()
      val item = itemIterator.next()
      val firstNItems = sortedItems.take(sortedItems.indexOf(item))
      if (firstNItems.length > 0) {
        val iterator = firstNItems.iterator
        while (iterator.hasNext) {
          val elem = iterator.next()
          combination += elem._1
        }
        output += ((item._1, combination.toArray))
      }
    }
    output.toArray
  }
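
  // Worked example: for the basket "x z y r q t p" and the frequent singles
  // above, the retained items sorted by descending count (ties broken by
  // descending item name) are [z, x, y, t, r]. One (item, prefix) pair is
  // emitted per non-leading item: ("x", [z]), ("y", [z, x]), ("t", [z, x, y])
  // and ("r", [z, x, y, t]); groupByKey later gathers all prefixes per item.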

  /**
   * Generate frequent patterns by walking through the FP-Tree.
   */
  private def runFPTree(
      partition: (String, Iterable[Array[String]]),
      minCount: Double): Array[(String, Int)] = {
    val key = partition._1
    val value = partition._2
    val output = ArrayBuffer[(String, Int)]()
    val map = Map[String, Int]()

    // Walk through the FP-Tree partition to generate all combinations that satisfy
    // the minimal support level.
    var k = 1
    while (k > 0) {
      map.clear()
      val iterator = value.iterator
      while (iterator.hasNext) {
        val pattern = iterator.next()
        if (pattern.length >= k) {
          val combination = pattern.toList.combinations(k).toList
          val itemIterator = combination.iterator
          while (itemIterator.hasNext) {
            val item = itemIterator.next()
            val list2key: List[String] = (item :+ key).sortWith(_ > _)
            val newKey = list2key.mkString(" ")
            if (!map.contains(newKey)) {
              map(newKey) = 1
            } else {
              map(newKey) = map(newKey) + 1
            }
          }
        }
      }
      var eligible: Array[(String, Int)] = null
      if (map.nonEmpty) {
        val candidate = map.filter(_._2 >= minCount)
        if (candidate.nonEmpty) {
          eligible = candidate.toArray
          output ++= eligible
        }
      }
      if ((eligible == null) || (eligible.length == 0)) {
        k = 0
      } else {
        k = k + 1
      }
    }
    output.toArray
  }
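
  // Worked example: for the group keyed by "t" that contains the prefix
  // [z, x, y], the k = 1 pass counts candidates "z t", "x t" and "y t":
  // each size-k combination is appended to the key and the items are joined
  // in descending order, giving one canonical string per pattern.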
}

/**
 * Top-level methods for calling FPGrowth.
 */
object FPGrowth {

  /**
   * Generate an FPGrowthModel using the given minimal support level.
   *
   * @param data input baskets stored as `RDD[Array[String]]`
   * @param minSupport minimal support level, for example 0.5
   */
  def train(data: RDD[Array[String]], minSupport: Double): FPGrowthModel = {
    new FPGrowth().setMinSupport(minSupport).run(data)
  }
}
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

/**
 * Model produced by FPGrowth: each element of frequentPattern is a frequent pattern
 * with its count.
 */
class FPGrowthModel(val frequentPattern: Array[(String, Int)]) extends Serializable {
}
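
The model is a plain serializable container: multi-item patterns are encoded as space-delimited strings (joined by mkString(" ") in runFPTree), while single-item patterns are the raw item strings. A small consumption sketch, assuming a `model` obtained from FPGrowth.train as above:

// Keep only patterns with at least two items.
val multiItemPatterns = model.frequentPattern.filter { case (pattern, _) =>
  pattern.split(" ").length >= 2
}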
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.mllib.fpm

import org.scalatest.FunSuite
import org.apache.spark.mllib.util.LocalSparkContext

class FPGrowthSuite extends FunSuite with LocalSparkContext {

  test("test FPGrowth algorithm") {
    val arr = FPGrowthSuite.createTestData()

    assert(arr.length === 6)
    val dataSet = sc.parallelize(arr)
    assert(dataSet.count() == 6)
    val rdd = dataSet.map(line => line.split(" "))
    assert(rdd.count() == 6)

    val algorithm = new FPGrowth()
    algorithm.setMinSupport(0.9)
    assert(algorithm.run(rdd).frequentPattern.length == 0)
    algorithm.setMinSupport(0.8)
    assert(algorithm.run(rdd).frequentPattern.length == 1)
    algorithm.setMinSupport(0.7)
    assert(algorithm.run(rdd).frequentPattern.length == 1)
    algorithm.setMinSupport(0.6)
    assert(algorithm.run(rdd).frequentPattern.length == 2)
    algorithm.setMinSupport(0.5)
    assert(algorithm.run(rdd).frequentPattern.length == 18)
    algorithm.setMinSupport(0.4)
    assert(algorithm.run(rdd).frequentPattern.length == 18)
    algorithm.setMinSupport(0.3)
    assert(algorithm.run(rdd).frequentPattern.length == 54)
    algorithm.setMinSupport(0.2)
    assert(algorithm.run(rdd).frequentPattern.length == 54)
    algorithm.setMinSupport(0.1)
    assert(algorithm.run(rdd).frequentPattern.length == 625)
  }
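
  // Spot-checking the expected lengths: with 6 baskets, minSupport 0.9 gives
  // minCount = 5.4 and no item occurs 6 times, so nothing qualifies. At 0.8
  // (minCount = 4.8) only "z" (count 5) survives; at 0.6 (minCount = 3.6)
  // "z" (5) and "x" (4) survive, matching the lengths asserted above.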
}

object FPGrowthSuite {

  /**
   * Create the test data set.
   */
  def createTestData(): Array[String] = {
    Array[String](
      "r z h k p",
      "z y x w v u t s",
      "s x o n r",
      "x z y m t s q e",
      "z",
      "x z y r q t p")
  }
}
