Skip to content

[7.x] Fix accuracy metric (#50310) #50433

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
Expand All @@ -35,10 +36,25 @@
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;

/**
* {@link AccuracyMetric} is a metric that answers the question:
* "What fraction of examples have been classified correctly by the classifier?"
* {@link AccuracyMetric} is a metric that answers the following two questions:
*
* equation: accuracy = 1/n * Σ(y == y´)
* 1. What is the fraction of documents for which predicted class equals the actual class?
*
* equation: overall_accuracy = 1/n * Σ(y == y')
* where: n = total number of documents
* y = document's actual class
* y' = document's predicted class
*
* 2. For any given class X, what is the fraction of documents for which either
* a) both actual and predicted class are equal to X (true positives)
* or
* b) both actual and predicted class are not equal to X (true negatives)
*
* equation: accuracy(X) = 1/n * (TP(X) + TN(X))
* where: X = class being examined
* n = total number of documents
* TP(X) = number of true positives wrt X
* TN(X) = number of true negatives wrt X
*/
public class AccuracyMetric implements EvaluationMetric {

Expand Down Expand Up @@ -78,29 +94,29 @@ public int hashCode() {

public static class Result implements EvaluationMetric.Result {

private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
private static final ParseField CLASSES = new ParseField("classes");
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));

static {
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
}

public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

/** List of actual classes. */
private final List<ActualClass> actualClasses;
/** Fraction of documents predicted correctly. */
/** List of per-class results. */
private final List<PerClassResult> classes;
/** Fraction of documents for which predicted class equals the actual class. */
private final double overallAccuracy;

public Result(List<ActualClass> actualClasses, double overallAccuracy) {
this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
public Result(List<PerClassResult> classes, double overallAccuracy) {
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
this.overallAccuracy = overallAccuracy;
}

Expand All @@ -109,8 +125,8 @@ public String getMetricName() {
return NAME;
}

public List<ActualClass> getActualClasses() {
return actualClasses;
public List<PerClassResult> getClasses() {
return classes;
}

public double getOverallAccuracy() {
Expand All @@ -120,7 +136,7 @@ public double getOverallAccuracy() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
builder.field(CLASSES.getPreferredName(), classes);
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
builder.endObject();
return builder;
Expand All @@ -131,52 +147,42 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.actualClasses, that.actualClasses)
return Objects.equals(this.classes, that.classes)
&& this.overallAccuracy == that.overallAccuracy;
}

@Override
public int hashCode() {
return Objects.hash(actualClasses, overallAccuracy);
return Objects.hash(classes, overallAccuracy);
}
}

public static class ActualClass implements ToXContentObject {
public static class PerClassResult implements ToXContentObject {

private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField ACCURACY = new ParseField("accuracy");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));

static {
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareDouble(constructorArg(), ACCURACY);
}

/** Name of the actual class. */
private final String actualClass;
/** Number of documents (examples) belonging to the {@code actualClass} class. */
private final long actualClassDocCount;
/** Fraction of documents belonging to the {@code actualClass} class predicted correctly. */
/** Name of the class. */
private final String className;
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
private final double accuracy;

public ActualClass(
String actualClass, long actualClassDocCount, double accuracy) {
this.actualClass = Objects.requireNonNull(actualClass);
this.actualClassDocCount = actualClassDocCount;
public PerClassResult(String className, double accuracy) {
this.className = Objects.requireNonNull(className);
this.accuracy = accuracy;
}

public String getActualClass() {
return actualClass;
}

public long getActualClassDocCount() {
return actualClassDocCount;
public String getClassName() {
return className;
}

public double getAccuracy() {
Expand All @@ -186,8 +192,7 @@ public double getAccuracy() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(ACCURACY.getPreferredName(), accuracy);
builder.endObject();
return builder;
Expand All @@ -197,15 +202,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o;
return Objects.equals(this.actualClass, that.actualClass)
&& this.actualClassDocCount == that.actualClassDocCount
PerClassResult that = (PerClassResult) o;
return Objects.equals(this.className, that.className)
&& this.accuracy == that.accuracy;
}

@Override
public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, accuracy);
return Objects.hash(className, accuracy);
}

@Override
public String toString() {
return Strings.toString(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1849,15 +1849,15 @@ public void testEvaluateDataFrame_Classification() throws IOException {
AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
assertThat(
accuracyResult.getActualClasses(),
accuracyResult.getClasses(),
equalTo(
Arrays.asList(
// 3 out of 5 examples labeled as "cat" were classified correctly
new AccuracyMetric.ActualClass("cat", 5, 0.6),
// 3 out of 4 examples labeled as "dog" were classified correctly
new AccuracyMetric.ActualClass("dog", 4, 0.75),
// no examples labeled as "ant" were classified correctly
new AccuracyMetric.ActualClass("ant", 1, 0.0))));
// 9 out of 10 examples were either true positives or true negatives wrt "ant"
new AccuracyMetric.PerClassResult("ant", 0.9),
// 6 out of 10 examples were either true positives or true negatives wrt "cat"
new AccuracyMetric.PerClassResult("cat", 0.6),
// 8 out of 10 examples were either true positives or true negatives wrt "dog"
new AccuracyMetric.PerClassResult("dog", 0.8))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
}
{ // Precision
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
package org.elasticsearch.client.ml.dataframe.evaluation.classification;

import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.PerClassResult;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
Expand All @@ -41,13 +41,13 @@ protected NamedXContentRegistry xContentRegistry() {
public static Result randomResult() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
List<PerClassResult> classes = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
double accuracy = randomDoubleBetween(0.0, 1.0, true);
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
classes.add(new PerClassResult(classNames.get(i), accuracy));
}
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
return new Result(actualClasses, overallAccuracy);
return new Result(classes, overallAccuracy);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
* Gets the evaluation result for this metric.
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
*/
Optional<EvaluationMetricResult> getResult();
Optional<? extends EvaluationMetricResult> getResult();
}
Loading