Skip to content

Commit dccd684

Browse files
committed
[ML] Stop the ML memory tracker before closing node (#39111)
The ML memory tracker does searches against ML results and config indices. These searches can be asynchronous, and if they are running while the node is closing then they can cause problems for other components. This change adds a stop() method to the MlMemoryTracker that waits for in-flight searches to complete. Once stop() has returned the MlMemoryTracker will not kick off any new searches. The MlLifeCycleService now calls MlMemoryTracker.stop() before stopping the node. Fixes #37117
1 parent e83361d commit dccd684

File tree

4 files changed

+68
-13
lines changed

4 files changed

+68
-13
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,10 +428,10 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
428428
DatafeedManager datafeedManager = new DatafeedManager(threadPool, client, clusterService, datafeedJobBuilder,
429429
System::currentTimeMillis, auditor, autodetectProcessManager);
430430
this.datafeedManager.set(datafeedManager);
431-
MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager,
432-
autodetectProcessManager);
433431
MlMemoryTracker memoryTracker = new MlMemoryTracker(settings, clusterService, threadPool, jobManager, jobResultsProvider);
434432
this.memoryTracker.set(memoryTracker);
433+
MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager,
434+
autodetectProcessManager, memoryTracker);
435435

436436
// This object's constructor attaches to the license state, so there's no need to retain another reference to it
437437
new InvalidLicenseEnforcer(getLicenseState(), threadPool, datafeedManager, autodetectProcessManager);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlLifeCycleService.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.common.component.LifecycleListener;
1111
import org.elasticsearch.env.Environment;
1212
import org.elasticsearch.xpack.ml.datafeed.DatafeedManager;
13+
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
1314
import org.elasticsearch.xpack.ml.process.NativeController;
1415
import org.elasticsearch.xpack.ml.process.NativeControllerHolder;
1516
import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager;
@@ -21,16 +22,14 @@ public class MlLifeCycleService extends AbstractComponent {
2122
private final Environment environment;
2223
private final DatafeedManager datafeedManager;
2324
private final AutodetectProcessManager autodetectProcessManager;
24-
25-
public MlLifeCycleService(Environment environment, ClusterService clusterService) {
26-
this(environment, clusterService, null, null);
27-
}
25+
private final MlMemoryTracker memoryTracker;
2826

2927
public MlLifeCycleService(Environment environment, ClusterService clusterService, DatafeedManager datafeedManager,
30-
AutodetectProcessManager autodetectProcessManager) {
28+
AutodetectProcessManager autodetectProcessManager, MlMemoryTracker memoryTracker) {
3129
this.environment = environment;
3230
this.datafeedManager = datafeedManager;
3331
this.autodetectProcessManager = autodetectProcessManager;
32+
this.memoryTracker = memoryTracker;
3433
clusterService.addLifecycleListener(new LifecycleListener() {
3534
@Override
3635
public void beforeStop() {
@@ -60,5 +59,8 @@ public synchronized void stop() {
6059
} catch (IOException e) {
6160
// We're stopping anyway, so don't let this complicate the shutdown sequence
6261
}
62+
if (memoryTracker != null) {
63+
memoryTracker.stop();
64+
}
6365
}
6466
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.util.Iterator;
3434
import java.util.List;
3535
import java.util.concurrent.ConcurrentHashMap;
36+
import java.util.concurrent.Phaser;
3637
import java.util.stream.Collectors;
3738

3839
/**
@@ -56,6 +57,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
5657
private final ClusterService clusterService;
5758
private final JobManager jobManager;
5859
private final JobResultsProvider jobResultsProvider;
60+
private final Phaser stopPhaser;
5961
private volatile boolean isMaster;
6062
private volatile Instant lastUpdateTime;
6163
private volatile Duration reassignmentRecheckInterval;
@@ -66,6 +68,7 @@ public MlMemoryTracker(Settings settings, ClusterService clusterService, ThreadP
6668
this.clusterService = clusterService;
6769
this.jobManager = jobManager;
6870
this.jobResultsProvider = jobResultsProvider;
71+
this.stopPhaser = new Phaser(1);
6972
setReassignmentRecheckInterval(PersistentTasksClusterService.CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING.get(settings));
7073
clusterService.addLocalNodeMasterListener(this);
7174
clusterService.getClusterSettings().addSettingsUpdateConsumer(
@@ -90,6 +93,23 @@ public void offMaster() {
9093
lastUpdateTime = null;
9194
}
9295

96+
/**
97+
* Wait for all outstanding searches to complete.
98+
* After returning, no new searches can be started.
99+
*/
100+
public void stop() {
101+
logger.trace("ML memory tracker stop called");
102+
// We never terminate the phaser
103+
assert stopPhaser.isTerminated() == false;
104+
// If there are no registered parties or no unarrived parties then there is a flaw
105+
// in the register/arrive/unregister logic in another method that uses the phaser
106+
assert stopPhaser.getRegisteredParties() > 0;
107+
assert stopPhaser.getUnarrivedParties() > 0;
108+
stopPhaser.arriveAndAwaitAdvance();
109+
assert stopPhaser.getPhase() > 0;
110+
logger.debug("ML memory tracker stopped");
111+
}
112+
93113
@Override
94114
public String executorName() {
95115
return MachineLearning.UTILITY_THREAD_POOL_NAME;
@@ -153,13 +173,13 @@ public boolean asyncRefresh() {
153173
try {
154174
ActionListener<Void> listener = ActionListener.wrap(
155175
aVoid -> logger.trace("Job memory requirement refresh request completed successfully"),
156-
e -> logger.error("Failed to refresh job memory requirements", e)
176+
e -> logger.warn("Failed to refresh job memory requirements", e)
157177
);
158178
threadPool.executor(executorName()).execute(
159179
() -> refresh(clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE), listener));
160180
return true;
161181
} catch (EsRejectedExecutionException e) {
162-
logger.debug("Couldn't schedule ML memory update - node might be shutting down", e);
182+
logger.warn("Couldn't schedule ML memory update - node might be shutting down", e);
163183
}
164184
}
165185

@@ -253,25 +273,43 @@ public void refreshJobMemory(String jobId, ActionListener<Long> listener) {
253273
return;
254274
}
255275

276+
// The phaser prevents searches being started after the memory tracker's stop() method has returned
277+
if (stopPhaser.register() != 0) {
278+
// Phases above 0 mean we've been stopped, so don't do any operations that involve external interaction
279+
stopPhaser.arriveAndDeregister();
280+
listener.onFailure(new EsRejectedExecutionException("Couldn't run ML memory update - node is shutting down"));
281+
return;
282+
}
283+
ActionListener<Long> phaserListener = ActionListener.wrap(
284+
r -> {
285+
stopPhaser.arriveAndDeregister();
286+
listener.onResponse(r);
287+
},
288+
e -> {
289+
stopPhaser.arriveAndDeregister();
290+
listener.onFailure(e);
291+
}
292+
);
293+
256294
try {
257295
jobResultsProvider.getEstablishedMemoryUsage(jobId, null, null,
258296
establishedModelMemoryBytes -> {
259297
if (establishedModelMemoryBytes <= 0L) {
260-
setJobMemoryToLimit(jobId, listener);
298+
setJobMemoryToLimit(jobId, phaserListener);
261299
} else {
262300
Long memoryRequirementBytes = establishedModelMemoryBytes + Job.PROCESS_MEMORY_OVERHEAD.getBytes();
263301
memoryRequirementByJob.put(jobId, memoryRequirementBytes);
264-
listener.onResponse(memoryRequirementBytes);
302+
phaserListener.onResponse(memoryRequirementBytes);
265303
}
266304
},
267305
e -> {
268306
logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e);
269-
setJobMemoryToLimit(jobId, listener);
307+
setJobMemoryToLimit(jobId, phaserListener);
270308
}
271309
);
272310
} catch (Exception e) {
273311
logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e);
274-
setJobMemoryToLimit(jobId, listener);
312+
setJobMemoryToLimit(jobId, phaserListener);
275313
}
276314
}
277315

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.common.settings.ClusterSettings;
1111
import org.elasticsearch.common.settings.Settings;
1212
import org.elasticsearch.common.unit.ByteSizeUnit;
13+
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
1314
import org.elasticsearch.persistent.PersistentTasksClusterService;
1415
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
1516
import org.elasticsearch.test.ESTestCase;
@@ -29,6 +30,7 @@
2930
import java.util.concurrent.atomic.AtomicReference;
3031
import java.util.function.Consumer;
3132

33+
import static org.hamcrest.CoreMatchers.instanceOf;
3234
import static org.mockito.Matchers.any;
3335
import static org.mockito.Matchers.eq;
3436
import static org.mockito.Mockito.anyString;
@@ -157,6 +159,19 @@ public void testRefreshOne() {
157159
assertNull(memoryTracker.getJobMemoryRequirement(jobId));
158160
}
159161

162+
public void testStop() {
163+
164+
memoryTracker.onMaster();
165+
memoryTracker.stop();
166+
167+
AtomicReference<Exception> exception = new AtomicReference<>();
168+
memoryTracker.refreshJobMemory("job", ActionListener.wrap(ESTestCase::assertNull, exception::set));
169+
170+
assertNotNull(exception.get());
171+
assertThat(exception.get(), instanceOf(EsRejectedExecutionException.class));
172+
assertEquals("Couldn't run ML memory update - node is shutting down", exception.get().getMessage());
173+
}
174+
160175
private PersistentTasksCustomMetaData.PersistentTask<OpenJobAction.JobParams> makeTestTask(String jobId) {
161176
return new PersistentTasksCustomMetaData.PersistentTask<>("job-" + jobId, MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams(jobId),
162177
0, PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT);

0 commit comments

Comments
 (0)