Skip to content

Commit 1809226

Browse files
author
Boyce.Zhang
committed
添加关联规则生成
1 parent 494e908 commit 1809226

File tree

6 files changed

+264
-14
lines changed

6 files changed

+264
-14
lines changed

src/apriori/AssociationRule.java

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package apriori;
2+
3+
import common.utils.AssertUtils;
4+
5+
import java.util.ArrayList;
6+
import java.util.Collections;
7+
import java.util.List;
8+
9+
/**
10+
* User: Boyce
11+
* Date: 11/12/15
12+
* Time: 15:07
13+
*/
14+
// 关联规则对象
15+
public class AssociationRule {
16+
17+
// 关联规则前件
18+
private List<SingleItem> frontItem;
19+
// 关联规则后建
20+
private List<SingleItem> afterItem;
21+
22+
public AssociationRule(List<SingleItem> frontItem, List<SingleItem> afterItem) {
23+
AssertUtils.assertNotEmpty(frontItem);
24+
AssertUtils.assertNotEmpty(afterItem);
25+
26+
this.frontItem = new ArrayList<SingleItem>(frontItem);
27+
this.afterItem = new ArrayList<SingleItem>(afterItem);
28+
}
29+
30+
public AssociationRule(SingleItem frontItem, List<SingleItem> afterItem) {
31+
this(asList(frontItem), afterItem);
32+
}
33+
34+
public AssociationRule(List<SingleItem> frontItem, SingleItem afterItem) {
35+
this(frontItem, asList(afterItem));
36+
}
37+
38+
public AssociationRule(SingleItem frontItem, SingleItem afterItem) {
39+
this(asList(frontItem), asList(afterItem));
40+
}
41+
42+
private static List<SingleItem> asList(SingleItem singleItem) {
43+
List<SingleItem> list = new ArrayList<SingleItem>();
44+
list.add(singleItem);
45+
return list;
46+
}
47+
48+
public List<SingleItem> getFrontItem() {
49+
return frontItem;
50+
}
51+
52+
public List<SingleItem> getAfterItem() {
53+
return afterItem;
54+
}
55+
56+
@Override
57+
public boolean equals(Object o) {
58+
if (this == o) return true;
59+
if (!(o instanceof AssociationRule)) return false;
60+
61+
AssociationRule that = (AssociationRule) o;
62+
63+
if (!afterItem.equals(that.afterItem)) return false;
64+
if (!frontItem.equals(that.frontItem)) return false;
65+
66+
return true;
67+
}
68+
69+
@Override
70+
public int hashCode() {
71+
int result = frontItem.hashCode();
72+
result = 31 * result + afterItem.hashCode();
73+
return result;
74+
}
75+
76+
@Override
77+
public String toString() {
78+
final StringBuilder sb = new StringBuilder("AssociationRule{");
79+
sb.append(frontItem);
80+
sb.append(" > ");
81+
sb.append(afterItem);
82+
sb.append('}');
83+
return sb.toString();
84+
}
85+
}

src/apriori/Item.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ public class Item {
1717
// the item contains transactions, transactions size is the item count.
1818
// 每个item保存其所在的transaction 的名字集合
1919
protected Set<String> ownerTransactionNames;
20+
protected int count;
2021

2122
// item length, example: a :length=1, a,b: length=2
2223
protected int length;
@@ -42,6 +43,10 @@ public void addTransactions(Set<String> transactionNames) {
4243
// count(X) = |{t(i)|X属于t(i), t(i)属于ts}|
4344
// item 的count 是 item所属于的t的数量
4445
public int count() {
46+
if (this.ownerTransactionNames == null ||
47+
this.ownerTransactionNames.isEmpty())
48+
return count;
49+
4550
return this.ownerTransactionNames.size();
4651
}
4752

@@ -54,6 +59,8 @@ public int length() {
5459
}
5560

5661
public void clear() {
62+
// 保存下count
63+
this.count = this.ownerTransactionNames.size();
5764
this.ownerTransactionNames.clear();
5865
}
5966
}

src/apriori/Itemset.java

Lines changed: 98 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22

33
import common.utils.AssertUtils;
44

5-
import java.util.ArrayList;
6-
import java.util.HashSet;
7-
import java.util.List;
8-
import java.util.Set;
5+
import java.util.*;
96

107
/**
118
* User: Boyce
@@ -14,6 +11,10 @@
1411
*/
1512
public class Itemset implements Cloneable {
1613
protected List<Item> items;
14+
15+
// 生成该k项集的(k-1)项集
16+
protected Itemset parentItemset;
17+
1718
// TODO comment
1819
protected int cardinal_number;
1920
protected int length;
@@ -48,6 +49,7 @@ public int length() {
4849
public Itemset frequent_gen(Transactions transactions) {
4950
int c_cardinal_number = this.cardinal_number + 1;
5051
Itemset frequentItemset = new Itemset(c_cardinal_number, transactions);
52+
frequentItemset.parentItemset = this;
5153

5254
MultiItem multiItem;
5355
for(int i=0; i<length; i++) {
@@ -133,18 +135,106 @@ private void addItemIfFrequent(Item item) {
133135
}
134136

135137
/**
 * Builds the initial itemset of cardinal number 1 for the given transactions.
 *
 * Every single item returned by {@code transactions.allSingleItems()} is
 * passed to {@code addItemIfFrequent} on the new itemset, and that itemset
 * is returned.
 *
 * @param transactions the transactions the itemset is created for
 * @return the newly created itemset
 */
public static Itemset init(Transactions transactions) {
136-
Itemset candidateItemset = new Itemset(1, transactions);
138+
Itemset frequentItemset = new Itemset(1, transactions);
137139
List<SingleItem> singleItems = transactions.allSingleItems();
138-
for (Item item: singleItems) {
139-
candidateItemset.addItemIfFrequent(item);
140+
for (SingleItem item: singleItems) {
141+
frequentItemset.addItemIfFrequent(item);
140142
}
141-
return candidateItemset;
143+
return frequentItemset;
142144
}
143145

144146
public boolean isEmpty() {
145147
return this.items.isEmpty();
146148
}
147149

150+
// 该频繁项集生成关联规则
151+
public List<AssociationRule> rule_gen() {
152+
153+
if (1 == this.cardinal_number)
154+
return Collections.EMPTY_LIST;
155+
156+
List<AssociationRule> f_rules = new ArrayList<AssociationRule>();
157+
List<Item> items = this.items;
158+
159+
// 计算频繁项集中每个频繁项的关联规则
160+
for (int i=0; i<items.size(); i++) {
161+
MultiItem multiItem = (MultiItem)items.get(i);
162+
163+
// TODO 如果一条关联规则的后件为a,那么所有以a的非空子集作为后件的候选规则
164+
// TODO 都是关联规则,换句话说,如果以a为后件的候选规则不是关联规则,则
165+
// TODO 任何以a为子项的父集作为后件的候选规则都不是关联规则。
166+
167+
// 获取该频繁项的所有 1-候选规则
168+
List<AssociationRule> c1_rules = multiItem.a_1AssociationRules();
169+
this.recurse(f_rules, c1_rules, multiItem);
170+
}
171+
return f_rules;
172+
}
173+
174+
private void recurse(List<AssociationRule> f_rules, List<AssociationRule> c_rules, MultiItem multiItem) {
175+
for (AssociationRule rule: c_rules) {
176+
if (isFrequentRule(rule, multiItem)) {
177+
f_rules.add(rule);
178+
179+
List<AssociationRule> ck_rules = multiItem.a_kAssociationRules(rule.getAfterItem());
180+
this.recurse(f_rules, ck_rules, multiItem);
181+
}
182+
}
183+
}
184+
185+
private boolean isFrequentRule(AssociationRule rule, MultiItem item) {
186+
List<SingleItem> frontItem = rule.getFrontItem();
187+
188+
int n = frontItem.size();
189+
int n_parent = this.cardinal_number - n;
190+
191+
Itemset n_parent_itemset = this;
192+
for (int i=0; i<n_parent; i++)
193+
n_parent_itemset = n_parent_itemset.parentItemset;
194+
195+
int front_count = 0;
196+
if (n == 1)
197+
front_count = n_parent_itemset.count(frontItem.get(0));
198+
else
199+
front_count = n_parent_itemset.count(new MultiItem(frontItem));
200+
201+
double conf = (double)item.count/front_count;
202+
boolean isFrequentRule = conf >= this.transactions.minconf();
203+
204+
if (isFrequentRule) {
205+
System.out.println("> add a rule: " + rule + ", conf=" + item.count + "/" + front_count + " = " + (double)item.count/front_count);
206+
}
207+
return isFrequentRule;
208+
}
209+
210+
public int count(Item item) {
211+
int count = 0;
212+
for (int i=0; i<this.items.size(); i++) {
213+
Item it = this.items.get(i);
214+
if (it.equals(item)) {
215+
count = it.count();
216+
break;
217+
}
218+
}
219+
return count;
220+
}
221+
222+
public int count(SingleItem item) {
223+
int count = 0;
224+
for (int i=0; i<this.items.size(); i++) {
225+
SingleItem it = (SingleItem)this.items.get(i);
226+
if (it.value.equals(((SingleItem)(item.value)).value)) {
227+
count = it.count();
228+
break;
229+
}
230+
}
231+
return count;
232+
}
233+
234+
public Itemset getParentItemset() {
235+
return parentItemset;
236+
}
237+
148238
@Override
149239
public String toString() {
150240
final StringBuilder sb = new StringBuilder("Itemset{");

src/apriori/Main.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package apriori;
22

3+
import java.util.*;
4+
35
/**
46
* Created with IntelliJ IDEA.
57
* User: Boyce
@@ -9,7 +11,7 @@
911
*/
1012
public class Main {
1113
public static void main(String[] args) {
12-
Transactions transactions = new Transactions(0.25, 0.4);
14+
Transactions transactions = new Transactions(0.25, 0.7);
1315

1416
Transaction t1 = new Transaction("t1", transactions);
1517
t1.addItem("a");
@@ -57,5 +59,11 @@ public static void main(String[] args) {
5759
f = f.frequent_gen(transactions);
5860
System.out.println(f);
5961
}
62+
63+
f = f.parentItemset;
64+
System.out.println(f);
65+
66+
List<AssociationRule> rules = f.rule_gen();
67+
System.out.println("rules: " + rules);
6068
}
6169
}

src/apriori/MultiItem.java

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,27 @@ public MultiItem(List<SingleItem> items) {
2828
this.values.addAll(items);
2929
}
3030

31-
public void addValue(Object obj) {
32-
AssertUtils.assertNotNull(obj, "cannot add a null value into item.");
31+
/**
 * Adds a single item to this multi-item unless an equal one is already present.
 *
 * The item is checked for null and the current number of values is checked
 * against {@code length} before the duplicate check runs.
 *
 * @param singleItem the item to add
 */
public void addItem(SingleItem singleItem) {
32+
AssertUtils.assertNotNull(singleItem, "cannot add a null item into item.");
3333
AssertUtils.assertTrue(this.values.size() < this.length, "out of the item length.");
3434

35-
SingleItem singleItem = new SingleItem(obj);
36-

3735
// TODO make sure the elements in the list are not duplicated
3836
if (!this.values.contains(singleItem))
3937
this.values.add(singleItem);
4038
}

40+
public void addItems(List<SingleItem> singleItems) {
41+
for (SingleItem item: singleItems)
42+
this.addItem(item);
43+
}
44+
45+
public void addValue(Object obj) {
46+
AssertUtils.assertNotNull(obj, "cannot add a null value into item.");
47+
SingleItem singleItem = new SingleItem(obj);
48+
49+
this.addItem(singleItem);
50+
}
51+
4252
public void addValues(List<Object> objs) {
4353
AssertUtils.assertNotEmpty(objs, "cannot add a empty list into item.");
4454
for(Object obj: objs) {
@@ -97,6 +107,53 @@ public List<MultiItem> k_1Subset() {
97107
return subset;
98108
}
99109

110+
// 称以c个单值作为后件的候选规则为 c-候选规则,该方法生成该MultiItem的所有的 1-候选规则
111+
public List<AssociationRule> a_1AssociationRules() {
112+
List<AssociationRule> rules = new ArrayList<AssociationRule>(this.values.size());
113+
for (int i=0; i<this.values.size(); i++) {
114+
SingleItem afterItem = this.values.get(i);
115+
116+
List<SingleItem> frontItem = new ArrayList<SingleItem>(this.length-1);
117+
for (int j=0; j<this.values.size(); j++) {
118+
if (j != i) {
119+
frontItem.add(this.values.get(j));
120+
}
121+
}
122+
AssociationRule rule = new AssociationRule(frontItem, afterItem);
123+
rules.add(rule);
124+
}
125+
return rules;
126+
}
127+
128+
// 该方法生成后件包含 items 的 (items.size()+1)-候选规则
129+
public List<AssociationRule> a_kAssociationRules(List<SingleItem> items) {
130+
if (null == items || items.isEmpty() || items.size() >= this.values.size()
131+
|| !this.values.containsAll(items))
132+
return Collections.EMPTY_LIST;
133+
134+
List<SingleItem> otherItems = new ArrayList<SingleItem>(this.values);
135+
otherItems.removeAll(items);
136+
137+
List<AssociationRule> rules = new ArrayList<AssociationRule>(this.values.size()-items.size());
138+
for (int i=0; i<otherItems.size(); i++) {
139+
140+
List<SingleItem> afterItem = new ArrayList<SingleItem>(items.size()+1);
141+
afterItem.addAll(items);
142+
afterItem.add(otherItems.get(i));
143+
144+
List<SingleItem> frontItem = new ArrayList(otherItems.size()-1);
145+
for (int j=0; j<otherItems.size(); j++) {
146+
if (j != i) {
147+
frontItem.add(otherItems.get(j));
148+
}
149+
}
150+
AssociationRule rule = new AssociationRule(frontItem, afterItem);
151+
rules.add(rule);
152+
}
153+
154+
return rules;
155+
}
156+
100157
@Override
101158
public boolean equals(Object o) {
102159
if (this == o) return true;

src/apriori/example/Calculator.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public class Calculator {
2020
private final static String BASE_PATH = "/Users/Boyce/GitProjects/Algorithm/src/apriori/example/";
2121

2222
public static void main(String[] args) {
23-
Transactions transactions = new Transactions(0.5, 0.4);
23+
Transactions transactions = new Transactions(0.9, 0.4);
2424

2525
File file = new File(BASE_PATH + "data/accidents.dat");
2626
List<String> lines = FileUtils.readLines(file);
@@ -56,5 +56,8 @@ public static void main(String[] args) {
5656
}
5757

5858
System.out.println("calculate take: " + (System.currentTimeMillis()-start) + "ms");
59+
60+
System.out.println("rule: " + f.getParentItemset().rule_gen());
61+
5962
}
6063
}

0 commit comments

Comments
 (0)