Skip to content

Commit cc4bc79

Browse files
authored
[7.x] Implement precision and recall metrics for classification evaluation (#49671) (#50378)
1 parent 9033052 commit cc4bc79

File tree

54 files changed

+2487
-382
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+2487
-382
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import java.util.Objects;
3636
import java.util.stream.Collectors;
3737

38+
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
3839
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
3940

4041
public class EvaluateDataFrameResponse implements ToXContentObject {
@@ -47,7 +48,7 @@ public static EvaluateDataFrameResponse fromXContent(XContentParser parser) thro
4748
ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
4849
String evaluationName = parser.currentName();
4950
parser.nextToken();
50-
Map<String, EvaluationMetric.Result> metrics = parser.map(LinkedHashMap::new, EvaluateDataFrameResponse::parseMetric);
51+
Map<String, EvaluationMetric.Result> metrics = parser.map(LinkedHashMap::new, p -> parseMetric(evaluationName, p));
5152
List<EvaluationMetric.Result> knownMetrics =
5253
metrics.values().stream()
5354
.filter(Objects::nonNull) // Filter out null values returned by {@link EvaluateDataFrameResponse::parseMetric}.
@@ -56,10 +57,10 @@ public static EvaluateDataFrameResponse fromXContent(XContentParser parser) thro
5657
return new EvaluateDataFrameResponse(evaluationName, knownMetrics);
5758
}
5859

59-
private static EvaluationMetric.Result parseMetric(XContentParser parser) throws IOException {
60+
private static EvaluationMetric.Result parseMetric(String evaluationName, XContentParser parser) throws IOException {
6061
String metricName = parser.currentName();
6162
try {
62-
return parser.namedObject(EvaluationMetric.Result.class, metricName, null);
63+
return parser.namedObject(EvaluationMetric.Result.class, registeredMetricName(evaluationName, metricName), null);
6364
} catch (NamedObjectNotFoundException e) {
6465
parser.skipChildren();
6566
// Metric name not recognized. Return {@code null} value here and filter it out later.

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

Lines changed: 85 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,36 @@
2020

2121
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
2222
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
23-
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
2423
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
24+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
2525
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
2626
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
27-
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
28-
import org.elasticsearch.common.ParseField;
29-
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
30-
import org.elasticsearch.plugins.spi.NamedXContentProvider;
3127
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
28+
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
3229
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
3330
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
3431
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
32+
import org.elasticsearch.common.ParseField;
33+
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
34+
import org.elasticsearch.plugins.spi.NamedXContentProvider;
3535

3636
import java.util.Arrays;
3737
import java.util.List;
3838

3939
public class MlEvaluationNamedXContentProvider implements NamedXContentProvider {
4040

41+
/**
42+
* Constructs the name under which a metric (or metric result) is registered.
43+
* The name is prefixed with the evaluation name and a dot (i.e. {@code evaluationName.metricName}) so that registered names are unique.
44+
*
45+
* @param evaluationName name of the evaluation
46+
* @param metricName name of the metric
47+
* @return name of the form {@code evaluationName.metricName}, appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry}
48+
*/
49+
public static String registeredMetricName(String evaluationName, String metricName) {
50+
return evaluationName + "." + metricName;
51+
}
52+
4153
@Override
4254
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
4355
return Arrays.asList(
@@ -47,39 +59,91 @@ Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClass
4759
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Classification.NAME), Classification::fromXContent),
4860
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
4961
// Evaluation metrics
50-
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
51-
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent),
52-
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
5362
new NamedXContentRegistry.Entry(
54-
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
63+
EvaluationMetric.class,
64+
new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME)),
65+
AucRocMetric::fromXContent),
5566
new NamedXContentRegistry.Entry(
56-
EvaluationMetric.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric::fromXContent),
67+
EvaluationMetric.class,
68+
new ParseField(registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME)),
69+
PrecisionMetric::fromXContent),
5770
new NamedXContentRegistry.Entry(
5871
EvaluationMetric.class,
59-
new ParseField(MulticlassConfusionMatrixMetric.NAME),
72+
new ParseField(registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME)),
73+
RecallMetric::fromXContent),
74+
new NamedXContentRegistry.Entry(
75+
EvaluationMetric.class,
76+
new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME)),
77+
ConfusionMatrixMetric::fromXContent),
78+
new NamedXContentRegistry.Entry(
79+
EvaluationMetric.class,
80+
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
81+
AccuracyMetric::fromXContent),
82+
new NamedXContentRegistry.Entry(
83+
EvaluationMetric.class,
84+
new ParseField(registeredMetricName(
85+
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
86+
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric::fromXContent),
87+
new NamedXContentRegistry.Entry(
88+
EvaluationMetric.class,
89+
new ParseField(registeredMetricName(
90+
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
91+
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric::fromXContent),
92+
new NamedXContentRegistry.Entry(
93+
EvaluationMetric.class,
94+
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
6095
MulticlassConfusionMatrixMetric::fromXContent),
6196
new NamedXContentRegistry.Entry(
62-
EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
97+
EvaluationMetric.class,
98+
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
99+
MeanSquaredErrorMetric::fromXContent),
63100
new NamedXContentRegistry.Entry(
64-
EvaluationMetric.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric::fromXContent),
101+
EvaluationMetric.class,
102+
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
103+
RSquaredMetric::fromXContent),
65104
// Evaluation metrics results
66105
new NamedXContentRegistry.Entry(
67-
EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
106+
EvaluationMetric.Result.class,
107+
new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME)),
108+
AucRocMetric.Result::fromXContent),
109+
new NamedXContentRegistry.Entry(
110+
EvaluationMetric.Result.class,
111+
new ParseField(registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME)),
112+
PrecisionMetric.Result::fromXContent),
68113
new NamedXContentRegistry.Entry(
69-
EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
114+
EvaluationMetric.Result.class,
115+
new ParseField(registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME)),
116+
RecallMetric.Result::fromXContent),
70117
new NamedXContentRegistry.Entry(
71-
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
118+
EvaluationMetric.Result.class,
119+
new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME)),
120+
ConfusionMatrixMetric.Result::fromXContent),
72121
new NamedXContentRegistry.Entry(
73-
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
122+
EvaluationMetric.Result.class,
123+
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
124+
AccuracyMetric.Result::fromXContent),
74125
new NamedXContentRegistry.Entry(
75-
EvaluationMetric.Result.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric.Result::fromXContent),
126+
EvaluationMetric.Result.class,
127+
new ParseField(registeredMetricName(
128+
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
129+
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result::fromXContent),
76130
new NamedXContentRegistry.Entry(
77131
EvaluationMetric.Result.class,
78-
new ParseField(MulticlassConfusionMatrixMetric.NAME),
132+
new ParseField(registeredMetricName(
133+
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
134+
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result::fromXContent),
135+
new NamedXContentRegistry.Entry(
136+
EvaluationMetric.Result.class,
137+
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
79138
MulticlassConfusionMatrixMetric.Result::fromXContent),
80139
new NamedXContentRegistry.Entry(
81-
EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
140+
EvaluationMetric.Result.class,
141+
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
142+
MeanSquaredErrorMetric.Result::fromXContent),
82143
new NamedXContentRegistry.Entry(
83-
EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent));
144+
EvaluationMetric.Result.class,
145+
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
146+
RSquaredMetric.Result::fromXContent)
147+
);
84148
}
85149
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
import java.util.List;
3333
import java.util.Objects;
3434

35+
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
36+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
37+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
38+
3539
/**
3640
* Evaluation of classification results.
3741
*/
@@ -48,10 +52,10 @@ public class Classification implements Evaluation {
4852
NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
4953

5054
static {
51-
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
52-
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
53-
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
54-
(p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
55+
PARSER.declareString(constructorArg(), ACTUAL_FIELD);
56+
PARSER.declareString(constructorArg(), PREDICTED_FIELD);
57+
PARSER.declareNamedObjects(
58+
optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
5559
}
5660

5761
public static Classification fromXContent(XContentParser parser) {

0 commit comments

Comments
 (0)