33
33
import java .util .Iterator ;
34
34
import java .util .List ;
35
35
import java .util .concurrent .ConcurrentHashMap ;
36
+ import java .util .concurrent .Phaser ;
36
37
import java .util .stream .Collectors ;
37
38
38
39
/**
@@ -56,6 +57,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
56
57
private final ClusterService clusterService ;
57
58
private final JobManager jobManager ;
58
59
private final JobResultsProvider jobResultsProvider ;
60
+ private final Phaser stopPhaser ;
59
61
private volatile boolean isMaster ;
60
62
private volatile Instant lastUpdateTime ;
61
63
private volatile Duration reassignmentRecheckInterval ;
@@ -66,6 +68,7 @@ public MlMemoryTracker(Settings settings, ClusterService clusterService, ThreadP
66
68
this .clusterService = clusterService ;
67
69
this .jobManager = jobManager ;
68
70
this .jobResultsProvider = jobResultsProvider ;
71
+ this .stopPhaser = new Phaser (1 );
69
72
setReassignmentRecheckInterval (PersistentTasksClusterService .CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING .get (settings ));
70
73
clusterService .addLocalNodeMasterListener (this );
71
74
clusterService .getClusterSettings ().addSettingsUpdateConsumer (
@@ -90,6 +93,23 @@ public void offMaster() {
90
93
lastUpdateTime = null ;
91
94
}
92
95
96
+ /**
97
+ * Wait for all outstanding searches to complete.
98
+ * After returning, no new searches can be started.
99
+ */
100
+ public void stop () {
101
+ logger .trace ("ML memory tracker stop called" );
102
+ // We never terminate the phaser
103
+ assert stopPhaser .isTerminated () == false ;
104
+ // If there are no registered parties or no unarrived parties then there is a flaw
105
+ // in the register/arrive/unregister logic in another method that uses the phaser
106
+ assert stopPhaser .getRegisteredParties () > 0 ;
107
+ assert stopPhaser .getUnarrivedParties () > 0 ;
108
+ stopPhaser .arriveAndAwaitAdvance ();
109
+ assert stopPhaser .getPhase () > 0 ;
110
+ logger .debug ("ML memory tracker stopped" );
111
+ }
112
+
93
113
@ Override
94
114
public String executorName () {
95
115
return MachineLearning .UTILITY_THREAD_POOL_NAME ;
@@ -153,13 +173,13 @@ public boolean asyncRefresh() {
153
173
try {
154
174
ActionListener <Void > listener = ActionListener .wrap (
155
175
aVoid -> logger .trace ("Job memory requirement refresh request completed successfully" ),
156
- e -> logger .error ("Failed to refresh job memory requirements" , e )
176
+ e -> logger .warn ("Failed to refresh job memory requirements" , e )
157
177
);
158
178
threadPool .executor (executorName ()).execute (
159
179
() -> refresh (clusterService .state ().getMetaData ().custom (PersistentTasksCustomMetaData .TYPE ), listener ));
160
180
return true ;
161
181
} catch (EsRejectedExecutionException e ) {
162
- logger .debug ("Couldn't schedule ML memory update - node might be shutting down" , e );
182
+ logger .warn ("Couldn't schedule ML memory update - node might be shutting down" , e );
163
183
}
164
184
}
165
185
@@ -253,25 +273,43 @@ public void refreshJobMemory(String jobId, ActionListener<Long> listener) {
253
273
return ;
254
274
}
255
275
276
+ // The phaser prevents searches being started after the memory tracker's stop() method has returned
277
+ if (stopPhaser .register () != 0 ) {
278
+ // Phases above 0 mean we've been stopped, so don't do any operations that involve external interaction
279
+ stopPhaser .arriveAndDeregister ();
280
+ listener .onFailure (new EsRejectedExecutionException ("Couldn't run ML memory update - node is shutting down" ));
281
+ return ;
282
+ }
283
+ ActionListener <Long > phaserListener = ActionListener .wrap (
284
+ r -> {
285
+ stopPhaser .arriveAndDeregister ();
286
+ listener .onResponse (r );
287
+ },
288
+ e -> {
289
+ stopPhaser .arriveAndDeregister ();
290
+ listener .onFailure (e );
291
+ }
292
+ );
293
+
256
294
try {
257
295
jobResultsProvider .getEstablishedMemoryUsage (jobId , null , null ,
258
296
establishedModelMemoryBytes -> {
259
297
if (establishedModelMemoryBytes <= 0L ) {
260
- setJobMemoryToLimit (jobId , listener );
298
+ setJobMemoryToLimit (jobId , phaserListener );
261
299
} else {
262
300
Long memoryRequirementBytes = establishedModelMemoryBytes + Job .PROCESS_MEMORY_OVERHEAD .getBytes ();
263
301
memoryRequirementByJob .put (jobId , memoryRequirementBytes );
264
- listener .onResponse (memoryRequirementBytes );
302
+ phaserListener .onResponse (memoryRequirementBytes );
265
303
}
266
304
},
267
305
e -> {
268
306
logger .error ("[" + jobId + "] failed to calculate job established model memory requirement" , e );
269
- setJobMemoryToLimit (jobId , listener );
307
+ setJobMemoryToLimit (jobId , phaserListener );
270
308
}
271
309
);
272
310
} catch (Exception e ) {
273
311
logger .error ("[" + jobId + "] failed to calculate job established model memory requirement" , e );
274
- setJobMemoryToLimit (jobId , listener );
312
+ setJobMemoryToLimit (jobId , phaserListener );
275
313
}
276
314
}
277
315
0 commit comments