Skip to content

Commit 30d7f9b

Browse files
committed
[ML] adjusting feature importance mapping for multi-class support (elastic#53821)
Feature importance storage format is changing to encompass multi-class. Feature importance objects are now mapped as follows (logistic) Regression: ``` { "feature_name": "feature_0", "importance": -1.3 } ``` Multi-class [class names are `foo`, `bar`, `baz`] ``` { “feature_name”: “feature_0”, “importance”: 2.0, // sum(abs()) of class importances “foo”: 1.0, “bar”: 0.5, “baz”: -0.5 }, ``` This change adjusts the mapping creation for analytics so that the field is mapped as a `nested` type. Native side change: elastic/ml-cpp#1071
1 parent 88c5d52 commit 30d7f9b

File tree

7 files changed

+89
-33
lines changed

7 files changed

+89
-33
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,11 @@ public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
288288
@SuppressWarnings("unchecked")
289289
@Override
290290
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
291+
Map<String, Object> additionalProperties = new HashMap<>();
292+
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
291293
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
292294
if ((dependentVariableMapping instanceof Map) == false) {
293-
return Collections.emptyMap();
295+
return additionalProperties;
294296
}
295297
Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
296298
// If the source field is an alias, fetch the concrete field that the alias points to.
@@ -301,9 +303,8 @@ public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mapping
301303
// We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
302304
// Hence, we need to check the "instanceof" condition again.
303305
if ((dependentVariableMapping instanceof Map) == false) {
304-
return Collections.emptyMap();
306+
return additionalProperties;
305307
}
306-
Map<String, Object> additionalProperties = new HashMap<>();
307308
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
308309
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
309310
return additionalProperties;
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*//*
6+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
7+
* or more contributor license agreements. Licensed under the Elastic License;
8+
* you may not use this file except in compliance with the Elastic License.
9+
*/
10+
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
11+
12+
import org.elasticsearch.index.mapper.KeywordFieldMapper;
13+
import org.elasticsearch.index.mapper.NumberFieldMapper;
14+
15+
import java.util.Collections;
16+
import java.util.HashMap;
17+
import java.util.Map;
18+
19+
final class MapUtils {
20+
21+
private static final Map<String, Object> FEATURE_IMPORTANCE_MAPPING;
22+
static {
23+
Map<String, Object> featureImportanceMappingProperties = new HashMap<>();
24+
featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
25+
featureImportanceMappingProperties.put("importance",
26+
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
27+
Map<String, Object> featureImportanceMapping = new HashMap<>();
28+
// TODO sorted indices don't support nested types
29+
//featureImportanceMapping.put("dynamic", true);
30+
//featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
31+
featureImportanceMapping.put("properties", featureImportanceMappingProperties);
32+
FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(featureImportanceMapping);
33+
}
34+
35+
static Map<String, Object> featureImportanceMapping() {
36+
return FEATURE_IMPORTANCE_MAPPING;
37+
}
38+
39+
private MapUtils() {}
40+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1515
import org.elasticsearch.common.xcontent.XContentBuilder;
1616
import org.elasticsearch.common.xcontent.XContentParser;
17+
import org.elasticsearch.index.mapper.NumberFieldMapper;
1718
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1819

1920
import java.io.IOException;
@@ -187,9 +188,13 @@ public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
187188

188189
@Override
189190
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
191+
Map<String, Object> additionalProperties = new HashMap<>();
192+
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
190193
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
191194
// high (over 10M) values of dependent variable.
192-
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double"));
195+
additionalProperties.put(resultsFieldName + "." + predictionFieldName,
196+
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
197+
return additionalProperties;
193198
}
194199

195200
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import java.util.Set;
2828

2929
import static org.hamcrest.Matchers.allOf;
30-
import static org.hamcrest.Matchers.anEmptyMap;
3130
import static org.hamcrest.Matchers.containsString;
3231
import static org.hamcrest.Matchers.empty;
3332
import static org.hamcrest.Matchers.equalTo;
@@ -244,39 +243,45 @@ public void testFieldCardinalityLimitsIsNonEmpty() {
244243
}
245244

246245
public void testGetExplicitlyMappedFields() {
247-
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), is(anEmptyMap()));
248-
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), is(anEmptyMap()));
246+
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"),
247+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
248+
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"),
249+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
249250
assertThat(
250251
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
251-
is(anEmptyMap()));
252-
assertThat(
253-
new Classification("foo").getExplicitlyMappedFields(
254-
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
255-
"results"),
252+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
253+
Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
254+
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
255+
"results");
256+
assertThat(explicitlyMappedFields,
256257
allOf(
257258
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
258259
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
259-
assertThat(
260-
new Classification("foo").getExplicitlyMappedFields(
261-
new HashMap<String, Object>() {{
262-
put("foo", new HashMap<String, String>() {{
263-
put("type", "alias");
264-
put("path", "bar");
265-
}});
266-
put("bar", Collections.singletonMap("type", "long"));
267-
}},
268-
"results"),
260+
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
261+
262+
explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
263+
new HashMap<String, Object>() {{
264+
put("foo", new HashMap<String, String>() {{
265+
put("type", "alias");
266+
put("path", "bar");
267+
}});
268+
put("bar", Collections.singletonMap("type", "long"));
269+
}},
270+
"results");
271+
assertThat(explicitlyMappedFields,
269272
allOf(
270273
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
271274
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
275+
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
276+
272277
assertThat(
273278
new Classification("foo").getExplicitlyMappedFields(
274279
Collections.singletonMap("foo", new HashMap<String, String>() {{
275280
put("type", "alias");
276281
put("path", "missing");
277282
}}),
278283
"results"),
279-
is(anEmptyMap()));
284+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
280285
}
281286

282287
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
1717

1818
import java.io.IOException;
19+
import java.util.Map;
1920
import java.util.Collections;
2021

2122
import static org.hamcrest.Matchers.allOf;
@@ -143,9 +144,9 @@ public void testFieldCardinalityLimitsIsEmpty() {
143144
}
144145

145146
public void testGetExplicitlyMappedFields() {
146-
assertThat(
147-
new Regression("foo").getExplicitlyMappedFields(null, "results"),
148-
hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
147+
Map<String, Object> explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results");
148+
assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
149+
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
149150
}
150151

151152
public void testGetStateDocId() {

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ public void cleanup() {
7777
cleanUp();
7878
}
7979

80-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/53236")
8180
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
8281
initialize("classification_single_numeric_feature_and_mixed_data_set");
8382
String predictedClassField = KEYWORD_FIELD + "_prediction";
@@ -109,7 +108,9 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
109108
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
110109
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
111110
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
112-
assertThat(resultsObject.keySet().stream().filter(k -> k.startsWith("feature_importance.")).findAny().isPresent(), is(true));
111+
@SuppressWarnings("unchecked")
112+
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
113+
assertThat(importanceArray, hasSize(greaterThan(0)));
113114
}
114115

115116
assertProgress(jobId, 100, 100, 100, 100);

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
import java.util.Map;
2828
import java.util.Set;
2929

30+
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
3031
import static org.hamcrest.Matchers.anyOf;
3132
import static org.hamcrest.Matchers.equalTo;
3233
import static org.hamcrest.Matchers.greaterThan;
34+
import static org.hamcrest.Matchers.hasSize;
3335
import static org.hamcrest.Matchers.is;
3436

3537
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
@@ -50,7 +52,6 @@ public void cleanup() {
5052
cleanUp();
5153
}
5254

53-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/53236")
5455
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
5556
initialize("regression_single_numeric_feature_and_mixed_data_set");
5657
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
@@ -88,11 +89,13 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
8889
assertThat(resultsObject.containsKey(predictedClassField), is(true));
8990
assertThat(resultsObject.containsKey("is_training"), is(true));
9091
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
92+
@SuppressWarnings("unchecked")
93+
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
94+
assertThat(importanceArray, hasSize(greaterThan(0)));
9195
assertThat(
92-
resultsObject.toString(),
93-
resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD)
94-
|| resultsObject.containsKey("feature_importance." + DISCRETE_NUMERICAL_FEATURE_FIELD),
95-
is(true));
96+
importanceArray.stream().filter(m -> NUMERICAL_FEATURE_FIELD.equals(m.get("feature_name"))
97+
|| DISCRETE_NUMERICAL_FEATURE_FIELD.equals(m.get("feature_name"))).findAny(),
98+
isPresent());
9699
}
97100

98101
assertProgress(jobId, 100, 100, 100, 100);

0 commit comments

Comments
 (0)