Skip to content

Commit ebde0c6

Browse files
authored
Make num_top_classes parameter's default value equal to 2 (#48119) (#48202)
1 parent 7fba568 commit ebde0c6

File tree

9 files changed

+96
-27
lines changed

9 files changed

+96
-27
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public static Builder builder(String dependentVariable) {
4848
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
4949
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
5050
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
51+
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
5152

5253
private static final ConstructingObjectParser<Classification, Void> PARSER =
5354
new ConstructingObjectParser<>(
@@ -61,7 +62,8 @@ public static Builder builder(String dependentVariable) {
6162
(Integer) a[4],
6263
(Double) a[5],
6364
(String) a[6],
64-
(Double) a[7]));
65+
(Double) a[7],
66+
(Integer) a[8]));
6567

6668
static {
6769
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@@ -72,6 +74,7 @@ public static Builder builder(String dependentVariable) {
7274
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
7375
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
7476
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
77+
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
7578
}
7679

7780
private final String dependentVariable;
@@ -82,10 +85,11 @@ public static Builder builder(String dependentVariable) {
8285
private final Double featureBagFraction;
8386
private final String predictionFieldName;
8487
private final Double trainingPercent;
88+
private final Integer numTopClasses;
8589

8690
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
8791
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
88-
@Nullable Double trainingPercent) {
92+
@Nullable Double trainingPercent, @Nullable Integer numTopClasses) {
8993
this.dependentVariable = Objects.requireNonNull(dependentVariable);
9094
this.lambda = lambda;
9195
this.gamma = gamma;
@@ -94,6 +98,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla
9498
this.featureBagFraction = featureBagFraction;
9599
this.predictionFieldName = predictionFieldName;
96100
this.trainingPercent = trainingPercent;
101+
this.numTopClasses = numTopClasses;
97102
}
98103

99104
@Override
@@ -133,6 +138,10 @@ public Double getTrainingPercent() {
133138
return trainingPercent;
134139
}
135140

141+
public Integer getNumTopClasses() {
142+
return numTopClasses;
143+
}
144+
136145
@Override
137146
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
138147
builder.startObject();
@@ -158,14 +167,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
158167
if (trainingPercent != null) {
159168
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
160169
}
170+
if (numTopClasses != null) {
171+
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
172+
}
161173
builder.endObject();
162174
return builder;
163175
}
164176

165177
@Override
166178
public int hashCode() {
167179
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
168-
trainingPercent);
180+
trainingPercent, numTopClasses);
169181
}
170182

171183
@Override
@@ -180,7 +192,8 @@ public boolean equals(Object o) {
180192
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
181193
&& Objects.equals(featureBagFraction, that.featureBagFraction)
182194
&& Objects.equals(predictionFieldName, that.predictionFieldName)
183-
&& Objects.equals(trainingPercent, that.trainingPercent);
195+
&& Objects.equals(trainingPercent, that.trainingPercent)
196+
&& Objects.equals(numTopClasses, that.numTopClasses);
184197
}
185198

186199
@Override
@@ -197,6 +210,7 @@ public static class Builder {
197210
private Double featureBagFraction;
198211
private String predictionFieldName;
199212
private Double trainingPercent;
213+
private Integer numTopClasses;
200214

201215
private Builder(String dependentVariable) {
202216
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@@ -237,9 +251,14 @@ public Builder setTrainingPercent(Double trainingPercent) {
237251
return this;
238252
}
239253

254+
public Builder setNumTopClasses(Integer numTopClasses) {
255+
this.numTopClasses = numTopClasses;
256+
return this;
257+
}
258+
240259
public Classification build() {
241260
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
242-
trainingPercent);
261+
trainingPercent, numTopClasses);
243262
}
244263
}
245264
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,8 +1296,7 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
12961296
.setDest(DataFrameAnalyticsDest.builder()
12971297
.setIndex("put-test-dest-index")
12981298
.build())
1299-
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression
1300-
.builder("my_dependent_variable")
1299+
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
13011300
.setTrainingPercent(80.0)
13021301
.build())
13031302
.setDescription("this is a regression")
@@ -1331,9 +1330,9 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti
13311330
.setDest(DataFrameAnalyticsDest.builder()
13321331
.setIndex("put-test-dest-index")
13331332
.build())
1334-
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification
1335-
.builder("my_dependent_variable")
1333+
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
13361334
.setTrainingPercent(80.0)
1335+
.setNumTopClasses(1)
13371336
.build())
13381337
.setDescription("this is a classification")
13391338
.build();

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2951,6 +2951,7 @@ public void testPutDataFrameAnalytics() throws Exception {
29512951
.setFeatureBagFraction(0.4) // <6>
29522952
.setPredictionFieldName("my_prediction_field_name") // <7>
29532953
.setTrainingPercent(50.0) // <8>
2954+
.setNumTopClasses(1) // <9>
29542955
.build();
29552956
// end::put-data-frame-analytics-classification
29562957

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public static Classification randomClassification() {
3434
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
3535
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
3636
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
37+
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
3738
.build();
3839
}
3940

docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ include-tagged::{doc-tests-file}[{api}-classification]
118118
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
119119
<7> The name of the prediction field in the results object.
120120
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
121+
<9> The number of top classes to be reported in the results. Defaults to 2.
121122

122123
===== Regression
123124

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
6767
.flatMap(Set::stream)
6868
.collect(Collectors.toSet()));
6969

70+
/**
71+
* As long as we only support binary classification, it makes sense to always report both classes with their probabilities.
72+
* This way the user can see whether the prediction was made with the confidence they need.
73+
*/
74+
private static final int DEFAULT_NUM_TOP_CLASSES = 2;
75+
7076
private final String dependentVariable;
7177
private final BoostedTreeParams boostedTreeParams;
7278
private final String predictionFieldName;
@@ -87,7 +93,7 @@ public Classification(String dependentVariable,
8793
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
8894
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
8995
this.predictionFieldName = predictionFieldName;
90-
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
96+
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
9197
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
9298
}
9399

@@ -107,6 +113,10 @@ public String getDependentVariable() {
107113
return dependentVariable;
108114
}
109115

116+
public int getNumTopClasses() {
117+
return numTopClasses;
118+
}
119+
110120
public double getTrainingPercent() {
111121
return trainingPercent;
112122
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
2121

22+
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
23+
2224
@Override
2325
protected Classification doParseInstance(XContentParser parser) throws IOException {
2426
return Classification.fromXContent(parser, false);
@@ -43,32 +45,68 @@ protected Writeable.Reader<Classification> instanceReader() {
4345
return Classification::new;
4446
}
4547

46-
public void testConstructor_GivenTrainingPercentIsNull() {
47-
Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, null);
48-
assertThat(classification.getTrainingPercent(), equalTo(100.0));
49-
}
50-
51-
public void testConstructor_GivenTrainingPercentIsBoundary() {
52-
Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 1.0);
53-
assertThat(classification.getTrainingPercent(), equalTo(1.0));
54-
classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0);
55-
assertThat(classification.getTrainingPercent(), equalTo(100.0));
56-
}
57-
5848
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
5949
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
60-
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 0.999));
50+
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999));
6151

6252
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
6353
}
6454

6555
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
6656
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
67-
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0001));
57+
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001));
6858

6959
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
7060
}
7161

62+
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
63+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
64+
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0));
65+
66+
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
67+
}
68+
69+
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
70+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
71+
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0));
72+
73+
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
74+
}
75+
76+
public void testGetNumTopClasses() {
77+
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0);
78+
assertThat(classification.getNumTopClasses(), equalTo(7));
79+
80+
// Boundary condition: num_top_classes == 0
81+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0);
82+
assertThat(classification.getNumTopClasses(), equalTo(0));
83+
84+
// Boundary condition: num_top_classes == 1000
85+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0);
86+
assertThat(classification.getNumTopClasses(), equalTo(1000));
87+
88+
// num_top_classes == null, default applied
89+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0);
90+
assertThat(classification.getNumTopClasses(), equalTo(2));
91+
}
92+
93+
public void testGetTrainingPercent() {
94+
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0);
95+
assertThat(classification.getTrainingPercent(), equalTo(50.0));
96+
97+
// Boundary condition: training_percent == 1.0
98+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0);
99+
assertThat(classification.getTrainingPercent(), equalTo(1.0));
100+
101+
// Boundary condition: training_percent == 100.0
102+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0);
103+
assertThat(classification.getTrainingPercent(), equalTo(100.0));
104+
105+
// training_percent == null, default applied
106+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null);
107+
assertThat(classification.getTrainingPercent(), equalTo(100.0));
108+
}
109+
72110
public void testFieldCardinalityLimitsIsNonNull() {
73111
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
74112
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
8181
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
8282
assertThat(resultsObject.containsKey("is_training"), is(true));
8383
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
84-
assertThat(resultsObject.containsKey("top_classes"), is(false));
84+
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
8585
}
8686

8787
assertProgress(jobId, 100, 100, 100, 100);
@@ -118,7 +118,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
118118
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
119119
assertThat(resultsObject.containsKey("is_training"), is(true));
120120
assertThat(resultsObject.get("is_training"), is(true));
121-
assertThat(resultsObject.containsKey("top_classes"), is(false));
121+
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
122122
}
123123

124124
assertProgress(jobId, 100, 100, 100, 100);

x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1810,7 +1810,7 @@ setup:
18101810
"maximum_number_trees": 400,
18111811
"feature_bag_fraction": 0.3,
18121812
"training_percent": 60.3,
1813-
"num_top_classes": 0
1813+
"num_top_classes": 2
18141814
}
18151815
}}
18161816
- is_true: create_time

0 commit comments

Comments
 (0)