11
11
import org .elasticsearch .action .ActionListener ;
12
12
import org .elasticsearch .action .search .SearchPhaseExecutionException ;
13
13
import org .elasticsearch .action .search .ShardSearchFailure ;
14
- import org .elasticsearch .action .support .SubscribableListener ;
15
14
import org .elasticsearch .action .support .master .AcknowledgedResponse ;
16
15
import org .elasticsearch .cluster .ClusterChangedEvent ;
17
16
import org .elasticsearch .cluster .ClusterName ;
50
49
import java .util .List ;
51
50
import java .util .concurrent .CountDownLatch ;
52
51
import java .util .concurrent .TimeUnit ;
53
- import java .util .concurrent .atomic .AtomicInteger ;
54
- import java .util .function .BiConsumer ;
52
+ import java .util .concurrent .atomic .AtomicReference ;
55
53
56
54
import static org .elasticsearch .xpack .ml .MachineLearning .UTILITY_THREAD_POOL_NAME ;
57
55
import static org .elasticsearch .xpack .ml .inference .assignment .TrainedModelAssignmentClusterServiceTests .shutdownMetadata ;
58
56
import static org .hamcrest .Matchers .equalTo ;
59
- import static org .hamcrest .Matchers .is ;
57
+ import static org .hamcrest .Matchers .notNullValue ;
60
58
import static org .mockito .ArgumentMatchers .any ;
61
59
import static org .mockito .Mockito .doAnswer ;
62
60
import static org .mockito .Mockito .mock ;
@@ -122,41 +120,20 @@ private void loadQueuedModels(TrainedModelAssignmentNodeService trainedModelAssi
122
120
loadQueuedModels (trainedModelAssignmentNodeService , false );
123
121
}
124
122
125
- private void loadQueuedModels (TrainedModelAssignmentNodeService trainedModelAssignmentNodeService , boolean expectedRunImmediately ) {
126
- trainedModelAssignmentNodeService .loadQueuedModels (ActionListener .wrap (actualRunImmediately -> {
127
- assertThat (
128
- "We should rerun immediately if there are still model loading tasks to process." ,
129
- actualRunImmediately ,
130
- equalTo (expectedRunImmediately )
131
- );
132
- }, e -> fail ("We should never call the onFailure method of this listener." )));
133
- }
134
-
135
- private void loadQueuedModels (TrainedModelAssignmentNodeService trainedModelAssignmentNodeService , int times )
123
+ private void loadQueuedModels (TrainedModelAssignmentNodeService trainedModelAssignmentNodeService , boolean expectedRunImmediately )
136
124
throws InterruptedException {
137
- var modelQueueSize = new AtomicInteger (times );
138
- BiConsumer <ActionListener <Object >, Boolean > verifyRerunningImmediately = (listener , result ) -> {
139
- var runImmediately = modelQueueSize .decrementAndGet () > 0 ;
140
- assertThat (
141
- "We should rerun immediately if there are still model loading tasks to process. Models remaining: " + modelQueueSize .get (),
142
- result ,
143
- is (runImmediately )
144
- );
145
- listener .onResponse (null );
146
- };
147
-
148
- var chain = SubscribableListener .newForked (
149
- l -> trainedModelAssignmentNodeService .loadQueuedModels (l .delegateFailure (verifyRerunningImmediately ))
150
- );
151
- for (int i = 1 ; i < times ; i ++) {
152
- chain = chain .andThen (
153
- (l , r ) -> trainedModelAssignmentNodeService .loadQueuedModels (l .delegateFailure (verifyRerunningImmediately ))
154
- );
155
- }
156
-
157
125
var latch = new CountDownLatch (1 );
158
- chain .addListener (ActionListener .running (latch ::countDown ));
126
+ var actual = new AtomicReference <Boolean >(); // AtomicReference for nullable
127
+ trainedModelAssignmentNodeService .loadQueuedModels (
128
+ ActionListener .runAfter (ActionListener .wrap (actual ::set , e -> {}), latch ::countDown )
129
+ );
159
130
assertTrue ("Timed out waiting for loadQueuedModels to finish." , latch .await (10 , TimeUnit .SECONDS ));
131
+ assertThat ("Test failed to call the onResponse handler." , actual .get (), notNullValue ());
132
+ assertThat (
133
+ "We should rerun immediately if there are still model loading tasks to process." ,
134
+ actual .get (),
135
+ equalTo (expectedRunImmediately )
136
+ );
160
137
}
161
138
162
139
public void testLoadQueuedModels () throws InterruptedException {
@@ -237,7 +214,7 @@ public void testLoadQueuedModelsWhenFailureIsRetried() throws InterruptedExcepti
237
214
verifyNoMoreInteractions (deploymentManager , trainedModelAssignmentService );
238
215
}
239
216
240
- public void testLoadQueuedModelsWhenStopped () {
217
+ public void testLoadQueuedModelsWhenStopped () throws InterruptedException {
241
218
TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService ();
242
219
243
220
// When there are no queued models
@@ -247,8 +224,11 @@ public void testLoadQueuedModelsWhenStopped() {
247
224
trainedModelAssignmentNodeService .prepareModelToLoad (newParams (modelToLoad , modelToLoad ));
248
225
trainedModelAssignmentNodeService .stop ();
249
226
250
- trainedModelAssignmentNodeService .loadQueuedModels (
251
- ActionListener .running (() -> fail ("When stopped, then loadQueuedModels should never run." ))
227
+ var latch = new CountDownLatch (1 );
228
+ trainedModelAssignmentNodeService .loadQueuedModels (ActionListener .running (latch ::countDown ));
229
+ assertTrue (
230
+ "loadQueuedModels should immediately call the listener without forking to another thread." ,
231
+ latch .await (0 , TimeUnit .SECONDS )
252
232
);
253
233
verifyNoMoreInteractions (deploymentManager , trainedModelAssignmentService );
254
234
}
0 commit comments