[7.x][ML] Stratified cross validation split for classification (#54087) #54104

Merged

Classification.java
@@ -54,7 +54,7 @@ public class Classification implements DataFrameAnalysis {
     /**
      * The max number of classes classification supports
      */
-    private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
+    public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;

     private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
         ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
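
The constant's visibility widens from private to public because CrossValidationSplitterFactory (below), which lives in a different package, now references it to size the terms aggregation over the dependent variable.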

AnalyticsProcessManager.java
@@ -25,7 +25,6 @@
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.MlStatsIndex;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
-import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
@@ -162,7 +161,7 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
         AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
         try {
             writeHeaderRecord(dataExtractor, process);
-            writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker(),
+            writeDataRows(dataExtractor, process, config, task.getStatsHolder().getProgressTracker(),
                 task.getStatsHolder().getDataCountsTracker());
             processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()),
                 DataCounts::documentId);
@@ -214,11 +213,12 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
         }
     }

-    private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process, DataFrameAnalysis analysis,
-                               ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) throws IOException {
+    private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
+                               DataFrameAnalyticsConfig config, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker)
+        throws IOException {

-        CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames())
-            .create(analysis);
+        CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(client, config, dataExtractor.getFieldNames())
+            .create();

         // The extra fields are for the doc hash and the control field (should be an empty string)
         String[] record = new String[dataExtractor.getFieldNames().size() + 2];
@@ -324,7 +324,8 @@ private void refreshIndices(String jobId) {
         );
         refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());

-        LOGGER.debug("[{}] Refreshing indices {}", jobId, Arrays.toString(refreshRequest.indices()));
+        LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}",
+            jobId, Arrays.toString(refreshRequest.indices())));

         try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
             client.admin().indices().refresh(refreshRequest).actionGet();
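
The logging change in refreshIndices swaps the eager debug call for Log4j's Supplier overload, so the ParameterizedMessage (and the Arrays.toString call feeding it) is only built when the DEBUG level is actually enabled. A minimal standalone sketch of the pattern, not taken from this PR (the class and arguments are hypothetical; only the standard Log4j 2 API is assumed):

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;

import java.util.Arrays;

class LazyDebugLogging {

    private static final Logger LOGGER = LogManager.getLogger(LazyDebugLogging.class);

    void refresh(String jobId, String[] indices) {
        // Eager form: Arrays.toString(indices) runs even when DEBUG is off.
        LOGGER.debug("[{}] Refreshing indices {}", jobId, Arrays.toString(indices));

        // Lazy form: the lambda is only invoked if DEBUG is enabled, so the
        // message construction and Arrays.toString are skipped otherwise.
        LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}",
            jobId, Arrays.toString(indices)));
    }
}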

CrossValidationSplitterFactory.java
@@ -5,32 +5,81 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;

+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.search.SearchRequestBuilder;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.xpack.core.ClientHelper;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
-import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;

+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;

 public class CrossValidationSplitterFactory {

+    private static final Logger LOGGER = LogManager.getLogger(CrossValidationSplitterFactory.class);
+
+    private final Client client;
+    private final DataFrameAnalyticsConfig config;
     private final List<String> fieldNames;

-    public CrossValidationSplitterFactory(List<String> fieldNames) {
+    public CrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
+        this.client = Objects.requireNonNull(client);
+        this.config = Objects.requireNonNull(config);
         this.fieldNames = Objects.requireNonNull(fieldNames);
     }

-    public CrossValidationSplitter create(DataFrameAnalysis analysis) {
-        if (analysis instanceof Regression) {
-            Regression regression = (Regression) analysis;
-            return new RandomCrossValidationSplitter(
-                fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
+    public CrossValidationSplitter create() {
+        if (config.getAnalysis() instanceof Regression) {
+            return createRandomSplitter();
         }
-        if (analysis instanceof Classification) {
-            Classification classification = (Classification) analysis;
-            return new RandomCrossValidationSplitter(
-                fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
+        if (config.getAnalysis() instanceof Classification) {
+            return createStratifiedSplitter((Classification) config.getAnalysis());
         }
         return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run();
     }

+    private CrossValidationSplitter createRandomSplitter() {
+        Regression regression = (Regression) config.getAnalysis();
+        return new RandomCrossValidationSplitter(
+            fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
+    }
+
+    private CrossValidationSplitter createStratifiedSplitter(Classification classification) {
+        String aggName = "dependent_variable_terms";
+        SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
+            .setSize(0)
+            .setAllowPartialSearchResults(false)
+            .addAggregation(AggregationBuilders.terms(aggName)
+                .field(classification.getDependentVariable())
+                .size(Classification.MAX_DEPENDENT_VARIABLE_CARDINALITY));
+
+        try {
+            SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
+                searchRequestBuilder::get);
+            Aggregations aggs = searchResponse.getAggregations();
+            Terms terms = aggs.get(aggName);
+            Map<String, Long> classCardinalities = new HashMap<>();
+            for (Terms.Bucket bucket : terms.getBuckets()) {
+                classCardinalities.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
+            }
+
+            return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCardinalities,
+                classification.getTrainingPercent(), classification.getRandomizeSeed());
+        } catch (Exception e) {
+            ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId());
+            LOGGER.error(msg, e);
+            throw new ElasticsearchException(msg.getFormattedMessage(), e);
+        }
+    }
 }
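
Two details follow from the factory code itself rather than from anything stated elsewhere in the PR: any analysis other than regression or classification falls through to a splitter that routes every row to training (such analyses have no train/test split), and the terms aggregation is capped at Classification.MAX_DEPENDENT_VARIABLE_CARDINALITY buckets, so a dependent variable that passed classification's cardinality validation is guaranteed to fit in a single response.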

RandomCrossValidationSplitter.java
@@ -25,18 +25,18 @@ class RandomCrossValidationSplitter implements CrossValidationSplitter {
     private boolean isFirstRow = true;

     RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
         assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
         this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
         this.trainingPercent = trainingPercent;
         this.random = new Random(randomizeSeed);
     }

     private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
-        for (int i = 0; i < fieldNames.size(); i++) {
-            if (fieldNames.get(i).equals(dependentVariable)) {
-                return i;
-            }
-        }
-        throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
+        int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
+        if (dependentVariableIndex < 0) {
+            throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
+        }
+        return dependentVariableIndex;
     }

     @Override
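
The only change to RandomCrossValidationSplitter is this refactor of findDependentVariableIndex: List.indexOf replaces the hand-written scan, with the same serverError thrown when the dependent variable is missing. The identical helper appears in the new StratifiedCrossValidationSplitter below.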

StratifiedCrossValidationSplitter.java (new file)
@@ -0,0 +1,90 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;

import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Given a dependent variable, randomly splits the dataset trying
 * to preserve the proportion of each class in the training sample.
 */
public class StratifiedCrossValidationSplitter implements CrossValidationSplitter {

    private final int dependentVariableIndex;
    private final double samplingRatio;
    private final Random random;
    private final Map<String, ClassSample> classSamples;

    public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCardinalities,
                                             double trainingPercent, long randomizeSeed) {
        assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
        this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
        this.samplingRatio = trainingPercent / 100.0;
        this.random = new Random(randomizeSeed);
        this.classSamples = new HashMap<>();
        classCardinalities.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new ClassSample(entry.getValue())));
    }

    private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
        int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
        if (dependentVariableIndex < 0) {
            throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
        }
        return dependentVariableIndex;
    }

    @Override
    public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {

        if (canBeUsedForTraining(row) == false) {
            incrementTestDocs.run();
            return;
        }

        String classValue = row[dependentVariableIndex];
        ClassSample sample = classSamples.get(classValue);
        if (sample == null) {
            throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet());
        }

        // The probability of assigning this row to the training set adapts as rows stream through:
        // it rises towards 1.0 when the unseen rows of this class are barely enough to reach the
        // target proportion, and falls when training assignments are already ahead of target.
        double p = (samplingRatio * sample.cardinality - sample.training) / (sample.cardinality - sample.observed);

        boolean isTraining = random.nextDouble() <= p;

        sample.observed++;
        if (isTraining) {
            sample.training++;
            incrementTrainingDocs.run();
        } else {
            row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
            incrementTestDocs.run();
        }
    }

    private boolean canBeUsedForTraining(String[] row) {
        return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
    }

    private static class ClassSample {

        private final long cardinality;
        private long training;
        private long observed;

        private ClassSample(long cardinality) {
            this.cardinality = cardinality;
        }
    }
}
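
To see why the adaptive probability preserves class proportions, take a class with cardinality 100 and an 80 percent training fraction: the first row is accepted with p = 0.8, but if the first 20 rows all happened to land in the test set, the next row would see p = (0.8 * 100 - 0) / (100 - 20) = 1.0, forcing every remaining row of that class into training. The standalone sketch below is not part of the PR; the class counts, seed, and 80 percent split are made-up values. It applies the same update rule as process() above to an imbalanced two-class problem and prints the per-class training fraction:

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

public class StratifiedSplitDemo {

    public static void main(String[] args) {
        double samplingRatio = 0.8;                 // trainingPercent / 100
        Random random = new Random(42L);            // fixed seed, as in the splitter
        // Hypothetical class counts: 900 "yes" docs vs 100 "no" docs.
        Map<String, long[]> classes = new LinkedHashMap<>();
        classes.put("yes", new long[] {900, 0, 0}); // {cardinality, training, observed}
        classes.put("no", new long[] {100, 0, 0});

        for (Map.Entry<String, long[]> entry : classes.entrySet()) {
            long[] s = entry.getValue();
            for (long i = 0; i < s[0]; i++) {
                // Same update rule as StratifiedCrossValidationSplitter.process.
                double p = (samplingRatio * s[0] - s[1]) / (s[0] - s[2]);
                if (random.nextDouble() <= p) {
                    s[1]++;                         // row goes to training
                }
                s[2]++;                             // row observed either way
            }
            System.out.printf("%s: %d of %d in training (%.1f%%)%n",
                entry.getKey(), s[1], s[0], 100.0 * s[1] / s[0]);
        }
    }
}

Both classes should come out within a few documents of the 80 percent target, which is the point of stratification: with a plain random split the minority class's training share can drift well away from the target by chance.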