[7.x][ML] Improve resuming a DFA job stopped during inference (#67623) #67669

Merged
MachineLearning.java
@@ -750,7 +750,6 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
analyticsProcessFactory,
dataFrameAnalyticsAuditor,
trainedModelProvider,
- modelLoadingService,
resultsPersisterService,
EsExecutors.allocatedProcessors(settings));
MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
@@ -759,8 +758,9 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
DataFrameAnalyticsConfigProvider dataFrameAnalyticsConfigProvider = new DataFrameAnalyticsConfigProvider(client, xContentRegistry,
dataFrameAnalyticsAuditor);
assert client instanceof NodeClient;
- DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager((NodeClient) client, clusterService,
-     dataFrameAnalyticsConfigProvider, analyticsProcessManager, dataFrameAnalyticsAuditor, indexNameExpressionResolver);
+ DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager(settings, (NodeClient) client, threadPool,
+     clusterService, dataFrameAnalyticsConfigProvider, analyticsProcessManager, dataFrameAnalyticsAuditor,
+     indexNameExpressionResolver, resultsPersisterService, modelLoadingService);
this.dataFrameAnalyticsManager.set(dataFrameAnalyticsManager);

// Components shared by anomaly detection and data frame analytics
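In createComponents, the model-loading dependency moves out of AnalyticsProcessManager and into DataFrameAnalyticsManager, together with Settings, ThreadPool, and ResultsPersisterService. A hedged sketch of where each new constructor argument ends up, based only on how the manager diff further down uses them (the annotations are the addition here, not the call itself):

// Sketch: destinations of the new DataFrameAnalyticsManager dependencies.
new DataFrameAnalyticsManager(
    settings,                        // passed through to InferenceRunner
    (NodeClient) client, threadPool, // threadPool drives the new InferenceStep
    clusterService, dataFrameAnalyticsConfigProvider,
    analyticsProcessManager,         // no longer holds modelLoadingService itself
    dataFrameAnalyticsAuditor, indexNameExpressionResolver,
    resultsPersisterService,         // passed through to InferenceRunner for persisting results
    modelLoadingService);            // lets the inference step load the trained model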
TransportStartDataFrameAnalyticsAction.java
@@ -268,6 +268,7 @@ private void getStartContext(String id, ActionListener<StartContext> finalListen
break;
case RESUMING_REINDEXING:
case RESUMING_ANALYZING:
+ case RESUMING_INFERENCE:
toValidateMappingsListener.onResponse(startContext);
break;
case FINISHED:
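The start-action change is one line: a job whose last incomplete phase is inference now takes the same path as the other resumable states, revalidating the destination-index mappings before proceeding. For readers new to this codebase, a minimal sketch of the listener hand-off idiom used in getStartContext (ActionListener.wrap is real and appears in the diff; the other names here are hypothetical):

// On success, forward the prepared context to the next stage; on failure,
// surface the exception to the caller's listener unchanged.
ActionListener<StartContext> toValidateMappingsListener = ActionListener.wrap(
    startContext -> validateDestIndexMappings(startContext, finalListener), // hypothetical next stage
    finalListener::onFailure);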
DataFrameAnalyticsManager.java
@@ -19,20 +19,30 @@
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.MappingMetadata;
import org.elasticsearch.cluster.service.ClusterService;
+ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexNotFoundException;
+ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+ import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
+ import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
+ import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner;
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager;
import org.elasticsearch.xpack.ml.dataframe.steps.AnalysisStep;
import org.elasticsearch.xpack.ml.dataframe.steps.DataFrameAnalyticsStep;
import org.elasticsearch.xpack.ml.dataframe.steps.FinalStep;
+ import org.elasticsearch.xpack.ml.dataframe.steps.InferenceStep;
import org.elasticsearch.xpack.ml.dataframe.steps.ReindexingStep;
import org.elasticsearch.xpack.ml.dataframe.steps.StepResponse;
+ import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
+ import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
+ import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -43,27 +53,36 @@ public class DataFrameAnalyticsManager {

private static final Logger LOGGER = LogManager.getLogger(DataFrameAnalyticsManager.class);

+ private final Settings settings;
/**
* We need a {@link NodeClient} to get the reindexing task and be able to report progress
*/
private final NodeClient client;
+ private final ThreadPool threadPool;
private final ClusterService clusterService;
private final DataFrameAnalyticsConfigProvider configProvider;
private final AnalyticsProcessManager processManager;
private final DataFrameAnalyticsAuditor auditor;
private final IndexNameExpressionResolver expressionResolver;
+ private final ResultsPersisterService resultsPersisterService;
+ private final ModelLoadingService modelLoadingService;
/** Indicates whether the node is shutting down. */
private final AtomicBoolean nodeShuttingDown = new AtomicBoolean();

- public DataFrameAnalyticsManager(NodeClient client, ClusterService clusterService, DataFrameAnalyticsConfigProvider configProvider,
-                                  AnalyticsProcessManager processManager, DataFrameAnalyticsAuditor auditor,
-                                  IndexNameExpressionResolver expressionResolver) {
+ public DataFrameAnalyticsManager(Settings settings, NodeClient client, ThreadPool threadPool, ClusterService clusterService,
+                                  DataFrameAnalyticsConfigProvider configProvider, AnalyticsProcessManager processManager,
+                                  DataFrameAnalyticsAuditor auditor, IndexNameExpressionResolver expressionResolver,
+                                  ResultsPersisterService resultsPersisterService, ModelLoadingService modelLoadingService) {
+ this.settings = Objects.requireNonNull(settings);
this.client = Objects.requireNonNull(client);
+ this.threadPool = Objects.requireNonNull(threadPool);
this.clusterService = Objects.requireNonNull(clusterService);
this.configProvider = Objects.requireNonNull(configProvider);
this.processManager = Objects.requireNonNull(processManager);
this.auditor = Objects.requireNonNull(auditor);
this.expressionResolver = Objects.requireNonNull(expressionResolver);
+ this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
+ this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
}

public void execute(DataFrameAnalyticsTask task, ClusterState clusterState) {
@@ -141,6 +160,12 @@ private void determineProgressAndResume(DataFrameAnalyticsTask task, DataFrameAn
case RESUMING_ANALYZING:
executeStep(task, config, new AnalysisStep(client, task, auditor, config, processManager));
break;
+ case RESUMING_INFERENCE:
+     buildInferenceStep(task, config, ActionListener.wrap(
+         inferenceStep -> executeStep(task, config, inferenceStep),
+         task::setFailed
+     ));
+     break;
case FINISHED:
default:
task.setFailed(ExceptionsHelper.serverError("Unexpected starting state [" + startingState + "]"));
@@ -162,7 +187,15 @@ private void executeStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c
executeStep(task, config, new AnalysisStep(client, task, auditor, config, processManager));
break;
case ANALYSIS:
- // This is the last step
+ buildInferenceStep(task, config, ActionListener.wrap(
+     inferenceStep -> executeStep(task, config, inferenceStep),
+     task::setFailed
+ ));
+ break;
+ case INFERENCE:
executeStep(task, config, new FinalStep(client, task, auditor, config));
break;
case FINAL:
LOGGER.info("[{}] Marking task completed", config.getId());
task.markAsCompleted();
break;
@@ -199,6 +232,24 @@ private void executeJobInMiddleOfReindexing(DataFrameAnalyticsTask task, DataFra
));
}

+ private void buildInferenceStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, ActionListener<InferenceStep> listener) {
+     ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
+
+     ActionListener<ExtractedFieldsDetector> extractedFieldsDetectorListener = ActionListener.wrap(
+         extractedFieldsDetector -> {
+             ExtractedFields extractedFields = extractedFieldsDetector.detect().v1();
+             InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService,
+                 resultsPersisterService, task.getParentTaskId(), config, extractedFields, task.getStatsHolder().getProgressTracker(),
+                 task.getStatsHolder().getDataCountsTracker());
+             InferenceStep inferenceStep = new InferenceStep(client, task, auditor, config, threadPool, inferenceRunner);
+             listener.onResponse(inferenceStep);
+         },
+         listener::onFailure
+     );
+
+     new ExtractedFieldsDetectorFactory(parentTaskClient).createFromDest(config, extractedFieldsDetectorListener);
+ }

public boolean isNodeShuttingDown() {
return nodeShuttingDown.get();
}
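Summing up the manager changes: analysis no longer chains straight into the final step, and the inference step is built asynchronously because it first needs the extracted fields of the destination index (fetched via ExtractedFieldsDetectorFactory). A compact, purely illustrative sketch of the new step order, with the case labels taken from the diff and everything else assumed:

// New progression: REINDEXING -> ANALYSIS -> INFERENCE -> FINAL.
switch (step.name()) {
    case REINDEXING: runNext(analysisStep); break;
    case ANALYSIS:   buildAndRunInferenceStep(); break; // async: needs dest index fields
    case INFERENCE:  runNext(finalStep); break;         // FinalStep now follows inference
    case FINAL:      task.markAsCompleted(); break;
}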
DataFrameAnalyticsTask.java
@@ -287,7 +287,7 @@ public void updateTaskProgress(ActionListener<Void> updateProgressListener) {
* {@code FINISHED} means the job had finished.
*/
public enum StartingState {
- FIRST_TIME, RESUMING_REINDEXING, RESUMING_ANALYZING, FINISHED
+ FIRST_TIME, RESUMING_REINDEXING, RESUMING_ANALYZING, RESUMING_INFERENCE, FINISHED
}

public StartingState determineStartingState() {
@@ -313,6 +313,9 @@ public static StartingState determineStartingState(String jobId, List<PhaseProgr
if (ProgressTracker.REINDEXING.equals(lastIncompletePhase.getPhase())) {
return lastIncompletePhase.getProgressPercent() == 0 ? StartingState.FIRST_TIME : StartingState.RESUMING_REINDEXING;
}
+ if (ProgressTracker.INFERENCE.equals(lastIncompletePhase.getPhase())) {
+     return StartingState.RESUMING_INFERENCE;
+ }
return StartingState.RESUMING_ANALYZING;
}
}
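The new check slots in after the reindexing test, so an unfinished inference phase resumes inference rather than falling through to RESUMING_ANALYZING. A hedged usage sketch, assuming PhaseProgress takes a phase name and a percent (the phase names besides the two ProgressTracker constants, and all values, are made up):

// Last incomplete phase is inference, so the job should resume there.
List<PhaseProgress> progress = List.of(
    new PhaseProgress(ProgressTracker.REINDEXING, 100),  // done
    new PhaseProgress("loading_data", 100),              // hypothetical analysis phase, done
    new PhaseProgress(ProgressTracker.INFERENCE, 40));   // interrupted here
assert DataFrameAnalyticsTask.determineStartingState("job-1", progress)
    == DataFrameAnalyticsTask.StartingState.RESUMING_INFERENCE;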
DataFrameDataExtractor.java
@@ -99,7 +99,7 @@ public boolean isCancelled() {
}

public void cancel() {
LOGGER.debug("[{}] Data extractor was cancelled", context.jobId);
LOGGER.debug(() -> new ParameterizedMessage("[{}] Data extractor was cancelled", context.jobId));
isCancelled = true;
}

@@ -127,7 +127,7 @@ private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request)
// We've set allow_partial_search_results to false which means if something
// goes wrong the request will throw.
SearchResponse searchResponse = request.get();
LOGGER.debug("[{}] Search response was obtained", context.jobId);
LOGGER.trace(() -> new ParameterizedMessage("[{}] Search response was obtained", context.jobId));

List<Row> rows = processSearchResponse(searchResponse);

@@ -153,7 +153,7 @@ private SearchRequestBuilder buildSearchRequest() {
long from = lastSortKey + 1;
long to = from + context.scrollSize;

- LOGGER.debug(() -> new ParameterizedMessage(
+ LOGGER.trace(() -> new ParameterizedMessage(
"[{}] Searching docs with [{}] in [{}, {})", context.jobId, DestinationIndex.INCREMENTAL_ID, from, to));

SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
@@ -283,7 +283,7 @@ private Row createRow(SearchHit hit) {
}
boolean isTraining = trainTestSplitter.get().isTraining(extractedValues);
Row row = new Row(extractedValues, hit, isTraining);
- LOGGER.debug(() -> new ParameterizedMessage("[{}] Extracted row: sort key = [{}], is_training = [{}], values = {}",
+ LOGGER.trace(() -> new ParameterizedMessage("[{}] Extracted row: sort key = [{}], is_training = [{}], values = {}",
context.jobId, row.getSortKey(), isTraining, Arrays.toString(row.values)));
return row;
}
@@ -306,7 +306,7 @@ public DataSummary collectDataSummary() {
SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder();
SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
long rows = searchResponse.getHits().getTotalHits().value;
LOGGER.debug("[{}] Data summary rows [{}]", context.jobId, rows);
LOGGER.debug(() -> new ParameterizedMessage("[{}] Data summary rows [{}]", context.jobId, rows));
return new DataSummary(rows, organicFeatures.length + processedFeatures.length);
}

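All the logging edits in this extractor follow one pattern: chatty per-search and per-row messages drop from debug to trace, and message construction moves behind a supplier so the ParameterizedMessage, and any expensive argument such as Arrays.toString over a whole row, is only built when that level is actually enabled. A self-contained Log4j 2 sketch of the idiom, outside the Elasticsearch codebase:

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;

public class LazyLoggingDemo {
    private static final Logger LOGGER = LogManager.getLogger(LazyLoggingDemo.class);

    public static void main(String[] args) {
        String jobId = "job-1";
        // Eager form: arguments are evaluated even when debug is disabled.
        LOGGER.debug("[{}] Data summary rows [{}]", jobId, 42);
        // Lazy form: the lambda runs only if trace is enabled, so the
        // expensive argument below costs nothing in normal operation.
        LOGGER.trace(() -> new ParameterizedMessage(
            "[{}] Extracted row: values = {}", jobId, expensiveRowDump()));
    }

    private static String expensiveRowDump() { // stand-in for an expensive computation
        return "[...]";
    }
}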