Skip to content

Commit 27e6b37

Browse files
authored
[ML] Wait for test to finish (#110542)
The tests can kick off tasks on another thread. We should wait for those threads to join back before we begin making assertions. Fix #110536
1 parent 5e09657 commit 27e6b37

File tree

3 files changed

+20
-42
lines changed

3 files changed

+20
-42
lines changed

muted-tests.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,6 @@ tests:
100100
- class: org.elasticsearch.test.rest.yaml.CcsCommonYamlTestSuiteIT
101101
method: test {p0=search.vectors/41_knn_search_half_byte_quantized/Test create, merge, and search cosine}
102102
issue: https://github.com/elastic/elasticsearch/issues/109978
103-
- class: org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentNodeServiceTests
104-
method: testLoadQueuedModelsWhenOneFails
105-
issue: https://github.com/elastic/elasticsearch/issues/110536
106103

107104
# Examples:
108105
#

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ void stop() {
184184

185185
void loadQueuedModels(ActionListener<Boolean> rescheduleImmediately) {
186186
if (stopped) {
187+
rescheduleImmediately.onResponse(false);
187188
return;
188189
}
189190
if (latestState != null) {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java

Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.action.search.SearchPhaseExecutionException;
1313
import org.elasticsearch.action.search.ShardSearchFailure;
14-
import org.elasticsearch.action.support.SubscribableListener;
1514
import org.elasticsearch.action.support.master.AcknowledgedResponse;
1615
import org.elasticsearch.cluster.ClusterChangedEvent;
1716
import org.elasticsearch.cluster.ClusterName;
@@ -50,13 +49,12 @@
5049
import java.util.List;
5150
import java.util.concurrent.CountDownLatch;
5251
import java.util.concurrent.TimeUnit;
53-
import java.util.concurrent.atomic.AtomicInteger;
54-
import java.util.function.BiConsumer;
52+
import java.util.concurrent.atomic.AtomicReference;
5553

5654
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
5755
import static org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentClusterServiceTests.shutdownMetadata;
5856
import static org.hamcrest.Matchers.equalTo;
59-
import static org.hamcrest.Matchers.is;
57+
import static org.hamcrest.Matchers.notNullValue;
6058
import static org.mockito.ArgumentMatchers.any;
6159
import static org.mockito.Mockito.doAnswer;
6260
import static org.mockito.Mockito.mock;
@@ -122,41 +120,20 @@ private void loadQueuedModels(TrainedModelAssignmentNodeService trainedModelAssi
122120
loadQueuedModels(trainedModelAssignmentNodeService, false);
123121
}
124122

125-
private void loadQueuedModels(TrainedModelAssignmentNodeService trainedModelAssignmentNodeService, boolean expectedRunImmediately) {
126-
trainedModelAssignmentNodeService.loadQueuedModels(ActionListener.wrap(actualRunImmediately -> {
127-
assertThat(
128-
"We should rerun immediately if there are still model loading tasks to process.",
129-
actualRunImmediately,
130-
equalTo(expectedRunImmediately)
131-
);
132-
}, e -> fail("We should never call the onFailure method of this listener.")));
133-
}
134-
135-
private void loadQueuedModels(TrainedModelAssignmentNodeService trainedModelAssignmentNodeService, int times)
123+
private void loadQueuedModels(TrainedModelAssignmentNodeService trainedModelAssignmentNodeService, boolean expectedRunImmediately)
136124
throws InterruptedException {
137-
var modelQueueSize = new AtomicInteger(times);
138-
BiConsumer<ActionListener<Object>, Boolean> verifyRerunningImmediately = (listener, result) -> {
139-
var runImmediately = modelQueueSize.decrementAndGet() > 0;
140-
assertThat(
141-
"We should rerun immediately if there are still model loading tasks to process. Models remaining: " + modelQueueSize.get(),
142-
result,
143-
is(runImmediately)
144-
);
145-
listener.onResponse(null);
146-
};
147-
148-
var chain = SubscribableListener.newForked(
149-
l -> trainedModelAssignmentNodeService.loadQueuedModels(l.delegateFailure(verifyRerunningImmediately))
150-
);
151-
for (int i = 1; i < times; i++) {
152-
chain = chain.andThen(
153-
(l, r) -> trainedModelAssignmentNodeService.loadQueuedModels(l.delegateFailure(verifyRerunningImmediately))
154-
);
155-
}
156-
157125
var latch = new CountDownLatch(1);
158-
chain.addListener(ActionListener.running(latch::countDown));
126+
var actual = new AtomicReference<Boolean>(); // AtomicReference for nullable
127+
trainedModelAssignmentNodeService.loadQueuedModels(
128+
ActionListener.runAfter(ActionListener.wrap(actual::set, e -> {}), latch::countDown)
129+
);
159130
assertTrue("Timed out waiting for loadQueuedModels to finish.", latch.await(10, TimeUnit.SECONDS));
131+
assertThat("Test failed to call the onResponse handler.", actual.get(), notNullValue());
132+
assertThat(
133+
"We should rerun immediately if there are still model loading tasks to process.",
134+
actual.get(),
135+
equalTo(expectedRunImmediately)
136+
);
160137
}
161138

162139
public void testLoadQueuedModels() throws InterruptedException {
@@ -237,7 +214,7 @@ public void testLoadQueuedModelsWhenFailureIsRetried() throws InterruptedExcepti
237214
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
238215
}
239216

240-
public void testLoadQueuedModelsWhenStopped() {
217+
public void testLoadQueuedModelsWhenStopped() throws InterruptedException {
241218
TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
242219

243220
// When there are no queued models
@@ -247,8 +224,11 @@ public void testLoadQueuedModelsWhenStopped() {
247224
trainedModelAssignmentNodeService.prepareModelToLoad(newParams(modelToLoad, modelToLoad));
248225
trainedModelAssignmentNodeService.stop();
249226

250-
trainedModelAssignmentNodeService.loadQueuedModels(
251-
ActionListener.running(() -> fail("When stopped, then loadQueuedModels should never run."))
227+
var latch = new CountDownLatch(1);
228+
trainedModelAssignmentNodeService.loadQueuedModels(ActionListener.running(latch::countDown));
229+
assertTrue(
230+
"loadQueuedModels should immediately call the listener without forking to another thread.",
231+
latch.await(0, TimeUnit.SECONDS)
252232
);
253233
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
254234
}

0 commit comments

Comments
 (0)