-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCSTreeCostSplitCrit.java
120 lines (98 loc) · 3.64 KB
/
CSTreeCostSplitCrit.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* CSTreeCostSplitCrit.java
*
*/
package weka.classifiers.trees.j48;
import java.util.logging.Level;
import java.util.logging.Logger;
import weka.classifiers.CostMatrix;
import weka.classifiers.trees.CSForest;
import weka.core.Attribute;
import weka.core.ContingencyTables;
import weka.core.RevisionUtils;
import weka.core.Utils;
/**
* Class for computing the cost-matrix-based split criterion for a given distribution.
*
* @author modified by Michael Furner (mfurner@csu.edu.au) (originally by Eibe Frank (eibe@cs.waikato.ac.nz))
* @version $Revision: 10169 $
*/
public final class CSTreeCostSplitCrit extends SplitCriterion {

  /**
   * for serialization
   */
  private static final long serialVersionUID = -433336694718670930L;

  /**
   * The cost matrix whose entries weight the per-class counts of every bag.
   */
  protected CostMatrix m_CostMatrix;

  /**
   * Make the split criteria with a cost matrix
   *
   * @param cm - the cost matrix
   */
  public CSTreeCostSplitCrit(CostMatrix cm) {
    m_CostMatrix = cm;
  }

  /**
   * This method is an implementation of the cost criterion for the given
   * distribution. The value returned is the sum, over all non-empty bags, of
   * the per-bag cost computed by {@link #costOfBag(Distribution, int)}.
   *
   * @param bags - the distribution in the split
   * @param classAttr - not used but required
   * @return the cost of the split
   * @throws java.lang.Exception - if there is a problem with the cost matrix
   */
  public final double splitCritValue(Distribution bags, Attribute classAttr) throws Exception {
    double costForSplit = 0;
    for (int i = 0; i < bags.actualNumBags(); i++) {
      // don't count the cost of empty bags
      if (bags.perBag(i) == 0) {
        continue;
      }
      costForSplit += costOfBag(bags, i);
    }
    return costForSplit;
  }

  /**
   * Computes the cost of a single bag as
   * {@code 2 * product(classCosts) / sum(classCosts)}, where the cost for a
   * class is the cost-matrix-weighted sum of the bag's per-class counts.
   *
   * @param bags - the distribution in the split
   * @param bagIndex - index of the bag to evaluate
   * @return the cost of the bag; 0 when the product of the per-class costs is 0
   * @throws java.lang.Exception - if there is a problem with the cost matrix
   */
  private double costOfBag(Distribution bags, int bagIndex) throws Exception {
    int numClasses = bags.numClasses();
    double costProduct = 1;
    double costSum = 0;
    for (int classifiedAs = 0; classifiedAs < numClasses; classifiedAs++) {
      // NOTE(review): the class being predicted is passed as the first index of
      // getElement and the actual class as the second -- confirm this matches
      // the row/column convention of CostMatrix (the original code carried a
      // "check this" remark here).
      double classCost = 0;
      for (int actual = 0; actual < numClasses; actual++) {
        classCost += m_CostMatrix.getElement(classifiedAs, actual)
            * bags.perClassPerBag(bagIndex, actual);
      }
      costProduct *= classCost;
      costSum += classCost;
    }
    // A zero product makes the quotient zero whenever the sum is non-zero; when
    // the sum is zero as well (every per-class cost is 0) the division would be
    // 0/0 = NaN and would poison the total split cost, so return 0 directly.
    if (costProduct == 0) {
      return 0;
    }
    return 2 * costProduct / costSum;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 10169 $");
  }
}