Skip to content

Commit f53d2da

Browse files
committed
IGNITE-13713 Add target encoding preprocessor
1 parent 2c3d19c commit f53d2da

File tree

11 files changed

+848
-4
lines changed

11 files changed

+848
-4
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.ignite.examples.ml.preprocessing.encoding;
19+
20+
import java.io.FileNotFoundException;
21+
import java.util.Arrays;
22+
import java.util.HashSet;
23+
import java.util.Set;
24+
import org.apache.ignite.Ignite;
25+
import org.apache.ignite.IgniteCache;
26+
import org.apache.ignite.Ignition;
27+
import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
28+
import org.apache.ignite.examples.ml.util.SandboxMLCache;
29+
import org.apache.ignite.ml.composition.ModelsComposition;
30+
import org.apache.ignite.ml.composition.boosting.GDBTrainer;
31+
import org.apache.ignite.ml.composition.boosting.convergence.median.MedianOfMedianConvergenceCheckerFactory;
32+
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
33+
import org.apache.ignite.ml.dataset.feature.extractor.impl.ObjectArrayVectorizer;
34+
import org.apache.ignite.ml.preprocessing.Preprocessor;
35+
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
36+
import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
37+
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
38+
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
39+
import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
40+
41+
/**
42+
* Example that shows how to use Target Encoder preprocessor to encode labels presented as a mean target value.
43+
* <p>
44+
* Code in this example launches Ignite grid and fills the cache with test data (based on mushrooms dataset).</p>
45+
* <p>
46+
* After that it defines preprocessors that extract features from an upstream data and encode category with avarage
47+
* target value (categories). </p>
48+
* <p>
49+
* Then, it trains the model based on the processed data using gradient boosing decision tree classification.</p>
50+
* <p>
51+
* Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p>
52+
*
53+
* <p>Daniele Miccii-Barreca (2001). A Preprocessing Scheme for High-Cardinality Categorical
54+
* Attributes in Classification and Prediction Problems. SIGKDD Explor. Newsl. 3, 1.
55+
* From http://dx.doi.org/10.1145/507533.507538</p>
56+
*/
57+
public class TargetEncoderExample {
58+
/**
59+
* Run example.
60+
*/
61+
public static void main(String[] args) {
62+
System.out.println();
63+
System.out.println(">>> Train Gradient Boosing Decision Tree model on amazon-employee-access-challenge_train.csv dataset.");
64+
65+
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
66+
try {
67+
IgniteCache<Integer, Object[]> dataCache = new SandboxMLCache(ignite)
68+
.fillObjectCacheWithCategoricalData(MLSandboxDatasets.AMAZON_EMPLOYEE_ACCESS);
69+
70+
Set<Integer> featuresIndexies = new HashSet<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
71+
Set<Integer> targetEncodedfeaturesIndexies = new HashSet<>(Arrays.asList(1, 5, 6));
72+
Integer targetIndex = 0;
73+
74+
final Vectorizer<Integer, Object[], Integer, Object> vectorizer = new ObjectArrayVectorizer<Integer>(featuresIndexies.toArray(new Integer[0]))
75+
.labeled(targetIndex);
76+
77+
Preprocessor<Integer, Object[]> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>()
78+
.withEncoderType(EncoderType.STRING_ENCODER)
79+
.withEncodedFeature(0)
80+
.withEncodedFeatures(featuresIndexies)
81+
.fit(ignite,
82+
dataCache,
83+
vectorizer
84+
);
85+
86+
Preprocessor<Integer, Object[]> targetEncoderProcessor = new EncoderTrainer<Integer, Object[]>()
87+
.withEncoderType(EncoderType.TARGET_ENCODER)
88+
.labeled(0)
89+
.withEncodedFeatures(targetEncodedfeaturesIndexies)
90+
.minSamplesLeaf(1)
91+
.minCategorySize(1L)
92+
.smoothing(1d)
93+
.fit(ignite,
94+
dataCache,
95+
strEncoderPreprocessor
96+
);
97+
98+
Preprocessor<Integer, Object[]> lbEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>()
99+
.withEncoderType(EncoderType.LABEL_ENCODER)
100+
.fit(ignite,
101+
dataCache,
102+
targetEncoderProcessor
103+
);
104+
105+
GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(0.5, 500, 4, 0.)
106+
.withCheckConvergenceStgyFactory(new MedianOfMedianConvergenceCheckerFactory(0.1));
107+
108+
// Train model.
109+
ModelsComposition mdl = trainer.fit(
110+
ignite,
111+
dataCache,
112+
lbEncoderPreprocessor
113+
);
114+
115+
System.out.println("\n>>> Trained model: " + mdl);
116+
117+
double accuracy = Evaluator.evaluate(
118+
dataCache,
119+
mdl,
120+
lbEncoderPreprocessor,
121+
new Accuracy()
122+
);
123+
124+
System.out.println("\n>>> Accuracy " + accuracy);
125+
System.out.println("\n>>> Test Error " + (1 - accuracy));
126+
127+
System.out.println(">>> Train Gradient Boosing Decision Tree model on amazon-employee-access-challenge_train.csv dataset.");
128+
129+
}
130+
catch (FileNotFoundException e) {
131+
e.printStackTrace();
132+
}
133+
}
134+
finally {
135+
System.out.flush();
136+
}
137+
}
138+
}

examples/src/main/java/org/apache/ignite/examples/ml/util/MLSandboxDatasets.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ public enum MLSandboxDatasets {
6868
MIXED_DATASET("examples/src/main/resources/datasets/mixed_dataset.csv", true, ","),
6969

7070
/** A dataset with categorical features and labels. */
71-
MUSHROOMS("examples/src/main/resources/datasets/mushrooms.csv", true, ",");
71+
MUSHROOMS("examples/src/main/resources/datasets/mushrooms.csv", true, ","),
72+
73+
/** A dataset with categorical features and labels. */
74+
AMAZON_EMPLOYEE_ACCESS("examples/src/main/resources/datasets/amazon-employee-access-challenge_train.csv", true, ",");
7275

7376
/** Filename. */
7477
private final String filename;
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
ACTION,RESOURCE,MGR_ID,ROLE_ROLLUP_1,ROLE_ROLLUP_2,ROLE_DEPTNAME,ROLE_TITLE,ROLE_FAMILY_DESC,ROLE_FAMILY,ROLE_CODE
2+
1,39353,85475,117961,118300,123472,117905,117906,290919,117908
3+
1,17183,1540,117961,118343,123125,118536,118536,308574,118539
4+
1,36724,14457,118219,118220,117884,117879,267952,19721,117880
5+
1,36135,5396,117961,118343,119993,118321,240983,290919,118322
6+
1,42680,5905,117929,117930,119569,119323,123932,19793,119325
7+
0,45333,14561,117951,117952,118008,118568,118568,19721,118570
8+
1,25993,17227,117961,118343,123476,118980,301534,118295,118982
9+
1,19666,4209,117961,117969,118910,126820,269034,118638,126822
10+
1,31246,783,117961,118413,120584,128230,302830,4673,128231
11+
1,78766,56683,118079,118080,117878,117879,304519,19721,117880
12+
1,4675,3005,117961,118413,118481,118784,117906,290919,118786
13+
1,15030,94005,117902,118041,119238,119093,138522,119095,119096
14+
1,79954,46608,118315,118463,122636,120773,123148,118960,120774
15+
1,4675,50997,91261,118026,118202,119962,168365,118205,119964
16+
1,95836,18181,117961,118343,118514,118321,117906,290919,118322
17+
1,19484,6657,118219,118220,118221,117885,117886,117887,117888
18+
1,114267,23136,117961,118052,119742,118321,117906,290919,118322
19+
1,35197,57715,117961,118446,118701,118702,118703,118704,118705
20+
1,86316,7002,117961,118343,123125,118278,132715,290919,118279
21+
1,27785,5636,117961,118413,122007,118321,117906,290919,118322
22+
1,37427,5220,117961,118300,118458,120006,303717,118424,120008
23+
1,15672,111936,117961,118300,118783,117905,240983,290919,117908
24+
1,92885,744,117961,118300,119181,118777,279443,308574,118779
25+
1,1020,85475,117961,118300,120410,118321,117906,290919,118322
26+
1,4675,7551,117961,118052,118867,118259,117906,290919,118261
27+
1,41334,28253,118315,118463,123089,118259,128796,290919,118261
28+
1,77385,14829,117961,118052,119986,117905,117906,290919,117908
29+
1,20273,11506,118216,118587,118846,179731,128361,117887,117973
30+
1,78098,46556,118090,118091,117884,118568,165015,19721,118570
31+
1,79328,4219,117961,118300,120312,120313,144958,118424,120315
32+
1,23921,4953,117961,118343,119598,120344,310997,118424,120346
33+
1,34687,815,117961,118300,123719,117905,117906,290919,117908
34+
1,43452,169112,117902,118041,119781,118563,121024,270488,118565
35+
1,33248,4929,117961,118300,118825,118826,226343,118424,118828
36+
1,78282,7445,117961,118343,122299,118054,121350,117887,118055
37+
1,17183,794,118752,119070,117945,280788,152940,292795,119082
38+
1,38658,1912,119134,119135,118042,120097,174445,270488,120099
39+
1,14354,50368,117926,118266,117884,118568,281735,19721,118570
40+
1,45019,1080,117961,118327,118378,120952,120953,118453,120954
41+
1,13878,1541,117961,118225,123173,120812,123174,118638,120814
42+
1,14570,46805,117929,117930,117920,118568,281735,19721,118570
43+
0,74310,49521,117961,118300,118301,119849,235245,118638,119851
44+
1,6977,1398,117961,118300,120722,118784,130735,290919,118786
45+
1,31613,5899,117961,118327,120318,118777,296252,308574,118779
46+
1,1020,21127,117961,118052,119408,118777,279443,308574,118779
47+
1,32270,3887,117961,118343,120347,120348,265969,118295,120350
48+
1,19629,19645,117961,118413,118481,118784,240983,290919,118786
49+
1,15702,1938,117961,118300,118066,120560,304465,118643,120562
50+
1,113037,5396,117961,118343,119993,120773,118959,118960,120774
51+
1,20279,17695,117890,117891,117878,117879,117879,19721,117880
52+
1,80746,16690,117961,118446,119064,122022,131302,119221,122024
53+
1,80263,36145,117961,118052,120304,307024,311622,118331,118332
54+
1,73753,70062,117961,118386,118746,117905,117906,290919,117908
55+
1,39883,7551,117961,118052,118867,117905,172635,290919,117908
56+
1,25993,7023,117961,117962,119223,118259,118260,290919,118261
57+
0,78106,50613,117916,118150,118810,118568,159905,19721,118570
58+
1,33150,1915,117961,118300,119181,118784,117906,290919,118786
59+
1,34817,5899,117961,118327,120318,118641,240982,118643,118644
60+
1,28354,3860,117961,118446,120317,118321,117906,290919,118322
61+
1,33642,13196,117951,117952,117941,117879,117897,19721,117880
62+
1,26430,56310,118212,118580,117895,117896,117913,117887,117898
63+
1,28149,50120,91261,118026,119507,118321,117906,290919,118322
64+
1,40867,6736,117961,117969,6725,122290,268766,6725,122292
65+
1,20293,273476,117926,118266,117920,118568,310732,19721,118570
66+
1,36020,2163,118219,118220,120694,118777,130218,308574,118779
67+
1,60006,16821,117961,118225,120535,118396,269406,118398,118399
68+
0,35043,14800,117961,117962,118352,118784,117906,290919,118786
69+
1,17308,4088,117961,118300,118458,118728,223125,118295,118730
70+
0,15716,18073,118256,118257,118623,118995,286106,292795,118997
71+
1,39883,55956,118555,118178,119262,117946,119727,292795,117948
72+
1,42031,88387,118315,118463,118522,119172,121927,118467,119174
73+
1,27124,2318,117961,118327,118933,117905,117906,290919,117908
74+
1,35498,18454,117961,118343,119598,125171,257115,118424,125173
75+
1,79168,58465,118602,118603,117941,117885,119621,117887,117888
76+
1,2252,782,117961,118413,127522,118784,240983,290919,118786
77+
1,45652,7338,117961,118225,119924,118321,118448,290919,118322
78+
1,23921,4145,117961,118300,120026,307024,303717,118331,118332
79+
1,95247,50690,118269,118270,117878,118568,118568,19721,118570
80+
1,78844,15645,117961,118052,122392,128903,160695,292795,128905
81+
1,19481,10627,118106,118107,119565,179731,155780,117887,117973
82+
1,18380,44022,117961,117962,122215,127782,130085,290919,127783
83+
1,37734,58406,117975,117976,117884,117885,117913,117887,117888
84+
1,3853,17550,117961,118446,118684,118321,117906,290919,118322
85+
1,278393,7076,117961,118225,120323,119093,136840,119095,119096
86+
1,35625,6454,117961,118343,118856,117905,240983,290919,117908
87+
1,35066,17465,91261,118026,118202,118278,118260,290919,118279
88+
1,3853,5043,117961,118300,118458,120006,310997,118424,120008
89+
1,41569,16671,117961,118052,118706,118523,310608,118331,118525
90+
1,25862,46224,117961,118327,118378,120952,143223,118453,120954
91+
1,75078,45963,117961,118386,118896,122645,309858,119221,122647
92+
1,1020,1483,117961,117962,118840,118641,306399,118643,118644
93+
0,22956,3967,117961,118052,118706,118321,117906,290919,118322
94+
1,20364,2612,117961,118386,123901,117905,117906,290919,117908
95+
1,28943,7547,117961,118052,118933,118784,213944,290919,118786
96+
1,75329,17414,118752,119070,118042,118043,151099,270488,118046
97+
1,41569,70066,91261,118026,118202,117905,117906,290919,117908
98+
1,4684,50806,117961,118446,119961,118259,118260,290919,118261
99+
1,77943,4478,117961,118386,118692,118321,117906,290919,118322
100+
1,38860,15541,118573,118574,118556,280788,127423,292795,119082

modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderPartitionData.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.ignite.ml.preprocessing.encoding;
1919

2020
import java.util.Map;
21+
import org.apache.ignite.ml.preprocessing.encoding.target.TargetCounter;
2122

2223
/**
2324
* Partition data used in Encoder preprocessor.
@@ -29,6 +30,9 @@ public class EncoderPartitionData implements AutoCloseable {
2930
/** Frequencies of categories for label presented as strings. */
3031
private Map<String, Integer> labelFrequencies;
3132

33+
/** Target encoding meta of categories for label presented as strings. */
34+
private TargetCounter[] targetCounters;
35+
3236
/**
3337
* Constructs a new instance of String Encoder partition data.
3438
*/
@@ -53,6 +57,15 @@ public Map<String, Integer> labelFrequencies() {
5357
return labelFrequencies;
5458
}
5559

60+
/**
61+
* Gets the map of target encoding meta by value in partition for label.
62+
*
63+
* @return The target encoding meta.
64+
*/
65+
public TargetCounter[] targetCounters() {
66+
return targetCounters;
67+
}
68+
5669
/**
5770
* Sets the array of maps of frequencies by value in partition for each feature in the dataset.
5871
*
@@ -75,6 +88,12 @@ public EncoderPartitionData withLabelFrequencies(Map<String, Integer> labelFrequ
7588
return this;
7689
}
7790

91+
/** */
92+
public EncoderPartitionData withTargetCounters(TargetCounter[] targetCounters) {
93+
this.targetCounters = targetCounters;
94+
return this;
95+
}
96+
7897
/** */
7998
@Override public void close() {
8099
// Do nothing, GC will clean up.

0 commit comments

Comments
 (0)