Skip to content

Commit cbebc26

Browse files
[7.x][ML] Retry persisting DF Analytics results (#52048) (#52160)
Employs `ResultsPersisterService` from `DataFrameRowsJoiner` in order to add retries when a data frame analytics job is persisting the results to the destination data frame. Backport of #52048
1 parent 2f1631d commit cbebc26

File tree

6 files changed

+64
-46
lines changed

6 files changed

+64
-46
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
634634

635635
// Data frame analytics components
636636
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
637-
dataFrameAnalyticsAuditor, trainedModelProvider);
637+
dataFrameAnalyticsAuditor, trainedModelProvider, resultsPersisterService);
638638
MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
639639
new MemoryUsageEstimationProcessManager(
640640
threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
3939
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
4040
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
41+
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
4142

4243
import java.io.IOException;
4344
import java.util.List;
@@ -62,19 +63,22 @@ public class AnalyticsProcessManager {
6263
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
6364
private final DataFrameAnalyticsAuditor auditor;
6465
private final TrainedModelProvider trainedModelProvider;
66+
private final ResultsPersisterService resultsPersisterService;
6567

6668
public AnalyticsProcessManager(Client client,
6769
ThreadPool threadPool,
6870
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
6971
DataFrameAnalyticsAuditor auditor,
70-
TrainedModelProvider trainedModelProvider) {
72+
TrainedModelProvider trainedModelProvider,
73+
ResultsPersisterService resultsPersisterService) {
7174
this(
7275
client,
7376
threadPool.generic(),
7477
threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
7578
analyticsProcessFactory,
7679
auditor,
77-
trainedModelProvider);
80+
trainedModelProvider,
81+
resultsPersisterService);
7882
}
7983

8084
// Visible for testing
@@ -83,13 +87,15 @@ public AnalyticsProcessManager(Client client,
8387
ExecutorService executorServiceForProcess,
8488
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
8589
DataFrameAnalyticsAuditor auditor,
86-
TrainedModelProvider trainedModelProvider) {
90+
TrainedModelProvider trainedModelProvider,
91+
ResultsPersisterService resultsPersisterService) {
8792
this.client = Objects.requireNonNull(client);
8893
this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
8994
this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
9095
this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
9196
this.auditor = Objects.requireNonNull(auditor);
9297
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
98+
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
9399
}
94100

95101
public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory) {
@@ -419,7 +425,7 @@ private AnalyticsProcessConfig createProcessConfig(DataFrameDataExtractor dataEx
419425
private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task,
420426
DataFrameDataExtractorFactory dataExtractorFactory) {
421427
DataFrameRowsJoiner dataFrameRowsJoiner =
422-
new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true));
428+
new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService);
423429
return new AnalyticsResultProcessor(
424430
config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.get().getFieldNames());
425431
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,14 @@
99
import org.apache.logging.log4j.Logger;
1010
import org.apache.logging.log4j.message.ParameterizedMessage;
1111
import org.elasticsearch.action.DocWriteRequest;
12-
import org.elasticsearch.action.bulk.BulkAction;
1312
import org.elasticsearch.action.bulk.BulkRequest;
14-
import org.elasticsearch.action.bulk.BulkResponse;
1513
import org.elasticsearch.action.index.IndexRequest;
16-
import org.elasticsearch.client.Client;
1714
import org.elasticsearch.common.Nullable;
1815
import org.elasticsearch.search.SearchHit;
19-
import org.elasticsearch.xpack.core.ClientHelper;
2016
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2117
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
2218
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
19+
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
2320

2421
import java.io.IOException;
2522
import java.util.Collections;
@@ -38,16 +35,17 @@ class DataFrameRowsJoiner implements AutoCloseable {
3835
private static final int RESULTS_BATCH_SIZE = 1000;
3936

4037
private final String analyticsId;
41-
private final Client client;
4238
private final DataFrameDataExtractor dataExtractor;
39+
private final ResultsPersisterService resultsPersisterService;
4340
private final Iterator<DataFrameDataExtractor.Row> dataFrameRowsIterator;
4441
private LinkedList<RowResults> currentResults;
4542
private volatile String failure;
4643

47-
DataFrameRowsJoiner(String analyticsId, Client client, DataFrameDataExtractor dataExtractor) {
44+
DataFrameRowsJoiner(String analyticsId, DataFrameDataExtractor dataExtractor,
45+
ResultsPersisterService resultsPersisterService) {
4846
this.analyticsId = Objects.requireNonNull(analyticsId);
49-
this.client = Objects.requireNonNull(client);
5047
this.dataExtractor = Objects.requireNonNull(dataExtractor);
48+
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
5149
this.dataFrameRowsIterator = new ResultMatchingDataFrameRows();
5250
this.currentResults = new LinkedList<>();
5351
}
@@ -88,7 +86,8 @@ private void joinCurrentResults() {
8886
bulkRequest.add(createIndexRequest(result, row.getHit()));
8987
}
9088
if (bulkRequest.numberOfActions() > 0) {
91-
executeBulkRequest(bulkRequest);
89+
resultsPersisterService.bulkIndexWithHeadersWithRetry(
90+
dataExtractor.getHeaders(), bulkRequest, analyticsId, () -> true, errorMsg -> {});
9291
}
9392
currentResults = new LinkedList<>();
9493
}
@@ -113,14 +112,6 @@ private IndexRequest createIndexRequest(RowResults result, SearchHit hit) {
113112
return indexRequest;
114113
}
115114

116-
private void executeBulkRequest(BulkRequest bulkRequest) {
117-
BulkResponse bulkResponse = ClientHelper.executeWithHeaders(dataExtractor.getHeaders(), ClientHelper.ML_ORIGIN, client,
118-
() -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet());
119-
if (bulkResponse.hasFailures()) {
120-
throw ExceptionsHelper.serverError("failures while writing results [" + bulkResponse.buildFailureMessage() + "]");
121-
}
122-
}
123-
124115
@Override
125116
public void close() {
126117
try {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.apache.logging.log4j.Logger;
1010
import org.apache.logging.log4j.message.ParameterizedMessage;
1111
import org.elasticsearch.ElasticsearchException;
12+
import org.elasticsearch.action.bulk.BulkAction;
1213
import org.elasticsearch.action.bulk.BulkItemResponse;
1314
import org.elasticsearch.action.bulk.BulkRequest;
1415
import org.elasticsearch.action.bulk.BulkResponse;
@@ -27,13 +28,16 @@
2728
import org.elasticsearch.common.xcontent.XContentBuilder;
2829
import org.elasticsearch.common.xcontent.XContentFactory;
2930
import org.elasticsearch.rest.RestStatus;
31+
import org.elasticsearch.xpack.core.ClientHelper;
3032

3133
import java.io.IOException;
3234
import java.time.Duration;
3335
import java.util.Arrays;
36+
import java.util.Map;
3437
import java.util.Random;
3538
import java.util.Set;
3639
import java.util.function.Consumer;
40+
import java.util.function.Function;
3741
import java.util.function.Supplier;
3842
import java.util.stream.Collectors;
3943

@@ -95,9 +99,28 @@ public BulkResponse bulkIndexWithRetry(BulkRequest bulkRequest,
9599
String jobId,
96100
Supplier<Boolean> shouldRetry,
97101
Consumer<String> msgHandler) {
102+
return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler,
103+
providedBulkRequest -> client.bulk(providedBulkRequest).actionGet());
104+
}
105+
106+
public BulkResponse bulkIndexWithHeadersWithRetry(Map<String, String> headers,
107+
BulkRequest bulkRequest,
108+
String jobId,
109+
Supplier<Boolean> shouldRetry,
110+
Consumer<String> msgHandler) {
111+
return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler,
112+
providedBulkRequest -> ClientHelper.executeWithHeaders(
113+
headers, ClientHelper.ML_ORIGIN, client, () -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet()));
114+
}
115+
116+
private BulkResponse bulkIndexWithRetry(BulkRequest bulkRequest,
117+
String jobId,
118+
Supplier<Boolean> shouldRetry,
119+
Consumer<String> msgHandler,
120+
Function<BulkRequest, BulkResponse> actionExecutor) {
98121
RetryContext retryContext = new RetryContext(jobId, shouldRetry, msgHandler);
99122
while (true) {
100-
BulkResponse bulkResponse = client.bulk(bulkRequest).actionGet();
123+
BulkResponse bulkResponse = actionExecutor.apply(bulkRequest);
101124
if (bulkResponse.hasFailures() == false) {
102125
return bulkResponse;
103126
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
2424
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
2525
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
26+
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
2627
import org.junit.Before;
2728
import org.mockito.InOrder;
2829

@@ -65,6 +66,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
6566
private DataFrameAnalyticsConfig dataFrameAnalyticsConfig;
6667
private DataFrameDataExtractorFactory dataExtractorFactory;
6768
private DataFrameDataExtractor dataExtractor;
69+
private ResultsPersisterService resultsPersisterService;
6870
private AnalyticsProcessManager processManager;
6971

7072
@SuppressWarnings("unchecked")
@@ -97,8 +99,10 @@ public void setUpMocks() {
9799
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
98100
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));
99101

100-
processManager = new AnalyticsProcessManager(
101-
client, executorServiceForJob, executorServiceForProcess, processFactory, auditor, trainedModelProvider);
102+
resultsPersisterService = mock(ResultsPersisterService.class);
103+
104+
processManager = new AnalyticsProcessManager(client, executorServiceForJob, executorServiceForProcess, processFactory, auditor,
105+
trainedModelProvider, resultsPersisterService);
102106
}
103107

104108
public void testRunJob_TaskIsStopping() {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,17 @@
55
*/
66
package org.elasticsearch.xpack.ml.dataframe.process;
77

8-
import org.elasticsearch.action.ActionFuture;
9-
import org.elasticsearch.action.bulk.BulkAction;
108
import org.elasticsearch.action.bulk.BulkItemResponse;
119
import org.elasticsearch.action.bulk.BulkRequest;
1210
import org.elasticsearch.action.bulk.BulkResponse;
1311
import org.elasticsearch.action.index.IndexRequest;
14-
import org.elasticsearch.client.Client;
1512
import org.elasticsearch.common.bytes.BytesArray;
16-
import org.elasticsearch.common.settings.Settings;
1713
import org.elasticsearch.common.text.Text;
18-
import org.elasticsearch.common.util.concurrent.ThreadContext;
1914
import org.elasticsearch.search.SearchHit;
2015
import org.elasticsearch.test.ESTestCase;
21-
import org.elasticsearch.threadpool.ThreadPool;
2216
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
2317
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
18+
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
2419
import org.junit.Before;
2520
import org.mockito.ArgumentCaptor;
2621

@@ -35,7 +30,8 @@
3530
import java.util.stream.IntStream;
3631

3732
import static org.hamcrest.Matchers.equalTo;
38-
import static org.mockito.Matchers.same;
33+
import static org.mockito.Matchers.any;
34+
import static org.mockito.Matchers.eq;
3935
import static org.mockito.Mockito.mock;
4036
import static org.mockito.Mockito.times;
4137
import static org.mockito.Mockito.verify;
@@ -46,19 +42,22 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
4642

4743
private static final String ANALYTICS_ID = "my_analytics";
4844

49-
private Client client;
45+
private static final Map<String, String> HEADERS = Collections.singletonMap("foo", "bar");
46+
5047
private DataFrameDataExtractor dataExtractor;
48+
private ResultsPersisterService resultsPersisterService;
5149
private ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
5250

5351
@Before
5452
public void setUpMocks() {
55-
client = mock(Client.class);
5653
dataExtractor = mock(DataFrameDataExtractor.class);
54+
when(dataExtractor.getHeaders()).thenReturn(HEADERS);
55+
resultsPersisterService = mock(ResultsPersisterService.class);
5756
}
5857

5958
public void testProcess_GivenNoResults() {
6059
givenProcessResults(Collections.emptyList());
61-
verifyNoMoreInteractions(client);
60+
verifyNoMoreInteractions(resultsPersisterService);
6261
}
6362

6463
public void testProcess_GivenSingleRowAndResult() throws IOException {
@@ -126,7 +125,7 @@ public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IO
126125
RowResults result = new RowResults(2, resultFields);
127126
givenProcessResults(Arrays.asList(result));
128127

129-
verifyNoMoreInteractions(client);
128+
verifyNoMoreInteractions(resultsPersisterService);
130129
}
131130

132131
public void testProcess_GivenSingleBatchWithSkippedRows() throws IOException {
@@ -204,7 +203,7 @@ public void testProcess_GivenMoreResultsThanRows() throws IOException {
204203
RowResults result2 = new RowResults(2, resultFields);
205204
givenProcessResults(Arrays.asList(result1, result2));
206205

207-
verifyNoMoreInteractions(client);
206+
verifyNoMoreInteractions(resultsPersisterService);
208207
}
209208

210209
public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException {
@@ -218,13 +217,13 @@ public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws
218217

219218
givenProcessResults(Collections.emptyList());
220219

221-
verifyNoMoreInteractions(client);
220+
verifyNoMoreInteractions(resultsPersisterService);
222221
verify(dataExtractor).cancel();
223222
verify(dataExtractor, times(2)).next();
224223
}
225224

226225
private void givenProcessResults(List<RowResults> results) {
227-
try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, client, dataExtractor)) {
226+
try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, dataExtractor, resultsPersisterService)) {
228227
results.forEach(joiner::processRowResults);
229228
}
230229
}
@@ -251,14 +250,9 @@ private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values,
251250
}
252251

253252
private void givenClientHasNoFailures() {
254-
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
255-
ThreadPool threadPool = mock(ThreadPool.class);
256-
when(threadPool.getThreadContext()).thenReturn(threadContext);
257-
@SuppressWarnings("unchecked")
258-
ActionFuture<BulkResponse> responseFuture = mock(ActionFuture.class);
259-
when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0));
260-
when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture);
261-
when(client.threadPool()).thenReturn(threadPool);
253+
when(resultsPersisterService.bulkIndexWithHeadersWithRetry(
254+
eq(HEADERS), bulkRequestCaptor.capture(), eq(ANALYTICS_ID), any(), any()))
255+
.thenReturn(new BulkResponse(new BulkItemResponse[0], 0));
262256
}
263257

264258
private static class DelegateStubDataExtractor {

0 commit comments

Comments
 (0)