
Commit 298ef5b

Jacky Li authored and mengxr committed
[SPARK-5520][MLlib] Make FP-Growth implementation take generic item types (WIP)

Make the FPGrowth.run API take generic item types:

    def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item]

so that users can invoke it as run[String, Seq[String]], run[Int, Seq[Int]], run[Int, List[Int]], etc. The Scala part is done, while the Java part is still in progress.

Author: Jacky Li <jacky.likun@huawei.com>
Author: Jacky Li <jackylk@users.noreply.github.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #4340 from jackylk/SPARK-5520-WIP and squashes the following commits:

f5acf84 [Jacky Li] Merge pull request #2 from mengxr/SPARK-5520
63073d0 [Xiangrui Meng] update to make generic FPGrowth Java-friendly
737d8bb [Jacky Li] fix scalastyle
793f85c [Jacky Li] add Java test case
7783351 [Jacky Li] add generic support in FPGrowth

(cherry picked from commit e380d2d)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent 4640623 commit 298ef5b
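
As a usage illustration (not part of the commit), here is a minimal Scala sketch of the generic API this change enables. The SparkContext setup and the transactions are hypothetical; the call matches the merged signature run[Item: ClassTag](data: RDD[Array[Item]]) shown in the diff below:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}

val sc = new SparkContext("local", "FPGrowthSketch")

// Hypothetical string transactions; Item = String is inferred from the element type.
val transactions = sc.parallelize(Seq(
  Array("a", "b", "c"),
  Array("a", "b"),
  Array("b", "c")), 2).cache()

val model: FPGrowthModel[String] = new FPGrowth()
  .setMinSupport(0.5)
  .setNumPartitions(2)
  .run(transactions)

// Each frequent itemset comes back as an (Array[Item], Long) pair.
model.freqItemsets.collect().foreach { case (itemset, count) =>
  println(itemset.mkString("{", ", ", "}") + ": " + count)
}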

File tree

3 files changed: +170 −15 lines changed


mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala

Lines changed: 36 additions & 14 deletions
@@ -18,14 +18,31 @@
 package org.apache.spark.mllib.fpm
 
 import java.{util => ju}
+import java.lang.{Iterable => JavaIterable}
 
 import scala.collection.mutable
+import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
 
-import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner}
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 
-class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable
+/**
+ * Model trained by [[FPGrowth]], which holds frequent itemsets.
+ * @param freqItemsets frequent itemset, which is an RDD of (itemset, frequency) pairs
+ * @tparam Item item type
+ */
+class FPGrowthModel[Item: ClassTag](
+    val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
+
+  /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */
+  def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = {
+    JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]]
+  }
+}
 
 /**
  * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data.
@@ -69,7 +86,7 @@ class FPGrowth private (
   * @param data input data set, each element contains a transaction
   * @return an [[FPGrowthModel]]
   */
-  def run(data: RDD[Array[String]]): FPGrowthModel = {
+  def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
@@ -82,19 +99,24 @@
     new FPGrowthModel(freqItemsets)
   }
 
+  def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
+    implicit val tag = fakeClassTag[Item]
+    run(data.rdd.map(_.asScala.toArray))
+  }
+
   /**
   * Generates frequent items by filtering the input data using minimal support level.
   * @param minCount minimum count for frequent itemsets
   * @param partitioner partitioner used to distribute items
   * @return array of frequent pattern ordered by their frequencies
   */
-  private def genFreqItems(
-      data: RDD[Array[String]],
+  private def genFreqItems[Item: ClassTag](
+      data: RDD[Array[Item]],
       minCount: Long,
-      partitioner: Partitioner): Array[String] = {
+      partitioner: Partitioner): Array[Item] = {
     data.flatMap { t =>
       val uniq = t.toSet
-      if (t.length != uniq.size) {
+      if (t.size != uniq.size) {
         throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
       }
       t
@@ -114,11 +136,11 @@
   * @param partitioner partitioner used to distribute transactions
   * @return an RDD of (frequent itemset, count)
   */
-  private def genFreqItemsets(
-      data: RDD[Array[String]],
+  private def genFreqItemsets[Item: ClassTag](
+      data: RDD[Array[Item]],
       minCount: Long,
-      freqItems: Array[String],
-      partitioner: Partitioner): RDD[(Array[String], Long)] = {
+      freqItems: Array[Item],
+      partitioner: Partitioner): RDD[(Array[Item], Long)] = {
     val itemToRank = freqItems.zipWithIndex.toMap
     data.flatMap { transaction =>
       genCondTransactions(transaction, itemToRank, partitioner)
@@ -139,9 +161,9 @@
   * @param partitioner partitioner used to distribute transactions
   * @return a map of (target partition, conditional transaction)
   */
-  private def genCondTransactions(
-      transaction: Array[String],
-      itemToRank: Map[String, Int],
+  private def genCondTransactions[Item: ClassTag](
+      transaction: Array[Item],
+      itemToRank: Map[Item, Int],
       partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
     val output = mutable.Map.empty[Int, Array[Int]]
     // Filter the basket by frequent items pattern and sort their ranks.
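
The new JavaRDD overload above is what makes the generic API Java-friendly: it supplies the ClassTag via JavaSparkContext.fakeClassTag and converts each java.lang.Iterable basket into an array before delegating to the Scala run. A minimal Scala sketch of that entry point, with hypothetical data and context setup (the explicit type parameters are supplied for clarity):

import java.util.{Arrays => JArrays, List => JList}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.fpm.FPGrowth

val jsc = new JavaSparkContext("local", "FPGrowthJavaApiSketch")

// Hypothetical baskets as java.util.Lists, which are java.lang.Iterables.
val baskets: JavaRDD[JList[String]] = jsc.parallelize(JArrays.asList(
  JArrays.asList("r", "z", "h"),
  JArrays.asList("z", "y", "x"),
  JArrays.asList("z")))

// No ClassTag is needed from the caller; the overload fakes it internally.
val model = new FPGrowth().setMinSupport(0.5).run[String, JList[String]](baskets)

// javaFreqItemsets() exposes the result as a JavaPairRDD for Java callers.
println(model.javaFreqItemsets().count())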
mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import com.google.common.collect.Lists;
+import static org.junit.Assert.*;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+public class JavaFPGrowthSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaFPGrowth");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void runFPGrowth() {
+
+    @SuppressWarnings("unchecked")
+    JavaRDD<ArrayList<String>> rdd = sc.parallelize(Lists.newArrayList(
+      Lists.newArrayList("r z h k p".split(" ")),
+      Lists.newArrayList("z y x w v u t s".split(" ")),
+      Lists.newArrayList("s x o n r".split(" ")),
+      Lists.newArrayList("x z y m t s q e".split(" ")),
+      Lists.newArrayList("z".split(" ")),
+      Lists.newArrayList("x z y r q t p".split(" "))), 2);
+
+    FPGrowth fpg = new FPGrowth();
+
+    FPGrowthModel<String> model6 = fpg
+      .setMinSupport(0.9)
+      .setNumPartitions(1)
+      .run(rdd);
+    assertEquals(0, model6.javaFreqItemsets().count());
+
+    FPGrowthModel<String> model3 = fpg
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd);
+    assertEquals(18, model3.javaFreqItemsets().count());
+
+    FPGrowthModel<String> model2 = fpg
+      .setMinSupport(0.3)
+      .setNumPartitions(4)
+      .run(rdd);
+    assertEquals(54, model2.javaFreqItemsets().count());
+
+    FPGrowthModel<String> model1 = fpg
+      .setMinSupport(0.1)
+      .setNumPartitions(8)
+      .run(rdd);
+    assertEquals(625, model1.javaFreqItemsets().count());
+  }
+}

mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala

Lines changed: 50 additions & 1 deletion
@@ -22,7 +22,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
 
-  test("FP-Growth") {
+
+  test("FP-Growth using String type") {
     val transactions = Seq(
       "r z h k p",
       "z y x w v u t s",
@@ -70,4 +71,52 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
       .run(rdd)
     assert(model1.freqItemsets.count() === 625)
   }
+
+  test("FP-Growth using Int type") {
+    val transactions = Seq(
+      "1 2 3",
+      "1 2 3 4",
+      "5 4 3 2 1",
+      "6 5 4 3 2 1",
+      "2 4",
+      "1 3",
+      "1 7")
+      .map(_.split(" ").map(_.toInt).toArray)
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    val fpg = new FPGrowth()
+
+    val model6 = fpg
+      .setMinSupport(0.9)
+      .setNumPartitions(1)
+      .run(rdd)
+    assert(model6.freqItemsets.count() === 0)
+
+    val model3 = fpg
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd)
+    assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass,
+      "frequent itemsets should use primitive arrays")
+    val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
+      (items.toSet, count)
+    }
+    val expected = Set(
+      (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
+      (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L),
+      (Set(2, 4), 4L), (Set(1, 2, 3), 4L))
+    assert(freqItemsets3.toSet === expected)
+
+    val model2 = fpg
+      .setMinSupport(0.3)
+      .setNumPartitions(4)
+      .run(rdd)
+    assert(model2.freqItemsets.count() === 15)
+
+    val model1 = fpg
+      .setMinSupport(0.1)
+      .setNumPartitions(8)
+      .run(rdd)
+    assert(model1.freqItemsets.count() === 65)
+  }
 }
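
A quick way to sanity-check the expected set in the new Int-type test: with 7 transactions and minSupport 0.5, the implied minimum count is 4, assuming the implementation rounds up via ceil(minSupport × numTransactions), so singletons 5, 6, and 7 are pruned while 1, 2, 3, and 4 survive. A small standalone Scala sketch of that singleton count (illustrative only, not part of the patch):

// Recompute singleton supports for the Int-type test data.
val transactions = Seq(
  "1 2 3", "1 2 3 4", "5 4 3 2 1", "6 5 4 3 2 1", "2 4", "1 3", "1 7")
  .map(_.split(" ").map(_.toInt).toSet)

val minCount = math.ceil(0.5 * transactions.size).toLong  // ceil(3.5) = 4

// Count the transactions containing each item and keep the frequent ones.
val freqSingletons = transactions.flatten
  .groupBy(identity)
  .map { case (item, occurrences) => item -> occurrences.size.toLong }
  .filter { case (_, count) => count >= minCount }

// Yields 1 -> 6, 2 -> 5, 3 -> 5, 4 -> 4, matching the singleton
// entries in `expected` above.
println(freqSingletons)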
