Skip to content

Commit 3c510b2

Browse files
committed
Prefer killing speculative tasks in LeastWastedEffortTaskLowMemoryKiller
1 parent ac2d379 commit 3c510b2

File tree

2 files changed

+89
-31
lines changed

2 files changed

+89
-31
lines changed

core/trino-main/src/main/java/io/trino/memory/LeastWastedEffortTaskLowMemoryKiller.java

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.Map;
3030
import java.util.Optional;
3131
import java.util.Set;
32+
import java.util.stream.Stream;
3233

3334
import static com.google.common.collect.ImmutableMap.toImmutableMap;
3435
import static com.google.common.collect.ImmutableSet.toImmutableSet;
@@ -69,21 +70,8 @@ public Optional<KillTarget> chooseTargetToKill(List<RunningQueryInfo> runningQue
6970
continue;
7071
}
7172

72-
memoryPool.getTaskMemoryReservations().entrySet().stream()
73-
.map(entry -> new SimpleEntry<>(TaskId.valueOf(entry.getKey()), entry.getValue()))
74-
.filter(entry -> queriesWithTaskRetryPolicy.contains(entry.getKey().getQueryId()))
75-
.max(comparing(entry -> {
76-
TaskId taskId = entry.getKey();
77-
Long memoryUsed = entry.getValue();
78-
long wallTime = 0;
79-
if (taskInfos.containsKey(taskId)) {
80-
TaskStats stats = taskInfos.get(taskId).getStats();
81-
wallTime = stats.getTotalScheduledTime().toMillis() + stats.getTotalBlockedTime().toMillis();
82-
}
83-
wallTime = Math.max(wallTime, MIN_WALL_TIME); // only look at memory consumption for fairly short-lived tasks
84-
return (double) memoryUsed / wallTime;
85-
}))
86-
.map(SimpleEntry::getKey)
73+
findBiggestTask(queriesWithTaskRetryPolicy, taskInfos, memoryPool, true) // try just speculative
74+
.or(() -> findBiggestTask(queriesWithTaskRetryPolicy, taskInfos, memoryPool, false)) // fallback to any task
8775
.ifPresent(tasksToKillBuilder::add);
8876
}
8977
Set<TaskId> tasksToKill = tasksToKillBuilder.build();
@@ -92,4 +80,35 @@ public Optional<KillTarget> chooseTargetToKill(List<RunningQueryInfo> runningQue
9280
}
9381
return Optional.of(KillTarget.selectedTasks(tasksToKill));
9482
}
83+
84+
private static Optional<TaskId> findBiggestTask(Set<QueryId> queriesWithTaskRetryPolicy, Map<TaskId, TaskInfo> taskInfos, MemoryPoolInfo memoryPool, boolean onlySpeculative)
85+
{
86+
Stream<SimpleEntry<TaskId, Long>> stream = memoryPool.getTaskMemoryReservations().entrySet().stream()
87+
.map(entry -> new SimpleEntry<>(TaskId.valueOf(entry.getKey()), entry.getValue()))
88+
.filter(entry -> queriesWithTaskRetryPolicy.contains(entry.getKey().getQueryId()));
89+
90+
if (onlySpeculative) {
91+
stream = stream.filter(entry -> {
92+
TaskInfo taskInfo = taskInfos.get(entry.getKey());
93+
if (taskInfo == null) {
94+
return false;
95+
}
96+
return taskInfo.getTaskStatus().isSpeculative();
97+
});
98+
}
99+
100+
return stream
101+
.max(comparing(entry -> {
102+
TaskId taskId = entry.getKey();
103+
Long memoryUsed = entry.getValue();
104+
long wallTime = 0;
105+
if (taskInfos.containsKey(taskId)) {
106+
TaskStats stats = taskInfos.get(taskId).getStats();
107+
wallTime = stats.getTotalScheduledTime().toMillis() + stats.getTotalBlockedTime().toMillis();
108+
}
109+
wallTime = Math.max(wallTime, MIN_WALL_TIME); // only look at memory consumption for fairly short-lived tasks
110+
return (double) memoryUsed / wallTime;
111+
}))
112+
.map(SimpleEntry::getKey);
113+
}
95114
}

core/trino-main/src/test/java/io/trino/memory/TestLeastWastedEffortTaskLowMemoryKiller.java

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -150,16 +150,16 @@ private void testKillsBiggestTasksIfAllExecuteSameTime(Duration scheduledTime, D
150150
else {
151151
taskInfos = ImmutableMap.of(
152152
"q_1", ImmutableMap.of(
153-
1, buildTaskInfo(taskId("q_1", 1), TaskState.RUNNING, scheduledTime, blockedTime)),
153+
1, buildTaskInfo(taskId("q_1", 1), TaskState.RUNNING, scheduledTime, blockedTime, false)),
154154
"q_2", ImmutableMap.of(
155-
1, buildTaskInfo(taskId("q_2", 1), TaskState.RUNNING, scheduledTime, blockedTime),
156-
2, buildTaskInfo(taskId("q_2", 2), TaskState.RUNNING, scheduledTime, blockedTime),
157-
3, buildTaskInfo(taskId("q_2", 3), TaskState.RUNNING, scheduledTime, blockedTime),
158-
4, buildTaskInfo(taskId("q_2", 4), TaskState.RUNNING, scheduledTime, blockedTime),
159-
5, buildTaskInfo(taskId("q_2", 5), TaskState.RUNNING, scheduledTime, blockedTime),
160-
6, buildTaskInfo(taskId("q_2", 6), TaskState.RUNNING, scheduledTime, blockedTime),
161-
7, buildTaskInfo(taskId("q_2", 7), TaskState.RUNNING, scheduledTime, blockedTime),
162-
8, buildTaskInfo(taskId("q_2", 8), TaskState.RUNNING, scheduledTime, blockedTime)));
155+
1, buildTaskInfo(taskId("q_2", 1), TaskState.RUNNING, scheduledTime, blockedTime, false),
156+
2, buildTaskInfo(taskId("q_2", 2), TaskState.RUNNING, scheduledTime, blockedTime, false),
157+
3, buildTaskInfo(taskId("q_2", 3), TaskState.RUNNING, scheduledTime, blockedTime, false),
158+
4, buildTaskInfo(taskId("q_2", 4), TaskState.RUNNING, scheduledTime, blockedTime, false),
159+
5, buildTaskInfo(taskId("q_2", 5), TaskState.RUNNING, scheduledTime, blockedTime, false),
160+
6, buildTaskInfo(taskId("q_2", 6), TaskState.RUNNING, scheduledTime, blockedTime, false),
161+
7, buildTaskInfo(taskId("q_2", 7), TaskState.RUNNING, scheduledTime, blockedTime, false),
162+
8, buildTaskInfo(taskId("q_2", 8), TaskState.RUNNING, scheduledTime, blockedTime, false)));
163163
}
164164

165165
assertEquals(
@@ -194,12 +194,12 @@ public void testKillsSmallerTaskIfWastedEffortRatioIsBetter()
194194

195195
Map<String, Map<Integer, TaskInfo>> taskInfos = ImmutableMap.of(
196196
"q_1", ImmutableMap.of(
197-
1, buildTaskInfo(taskId("q_1", 1), TaskState.RUNNING, new Duration(30, SECONDS), new Duration(30, SECONDS)),
198-
2, buildTaskInfo(taskId("q_1", 2), TaskState.RUNNING, new Duration(400, SECONDS), new Duration(200, SECONDS))),
197+
1, buildTaskInfo(taskId("q_1", 1), TaskState.RUNNING, new Duration(30, SECONDS), new Duration(30, SECONDS), false),
198+
2, buildTaskInfo(taskId("q_1", 2), TaskState.RUNNING, new Duration(400, SECONDS), new Duration(200, SECONDS), false)),
199199
"q_2", ImmutableMap.of(
200-
1, buildTaskInfo(taskId("q_2", 1), TaskState.RUNNING, new Duration(30, SECONDS), new Duration(30, SECONDS)),
201-
2, buildTaskInfo(taskId("q_2", 2), TaskState.RUNNING, new Duration(100, SECONDS), new Duration(100, SECONDS)),
202-
3, buildTaskInfo(taskId("q_2", 3), TaskState.RUNNING, new Duration(30, SECONDS), new Duration(30, SECONDS))));
200+
1, buildTaskInfo(taskId("q_2", 1), TaskState.RUNNING, new Duration(30, SECONDS), new Duration(30, SECONDS), false),
201+
2, buildTaskInfo(taskId("q_2", 2), TaskState.RUNNING, new Duration(100, SECONDS), new Duration(100, SECONDS), false),
202+
3, buildTaskInfo(taskId("q_2", 3), TaskState.RUNNING, new Duration(30, SECONDS), new Duration(30, SECONDS), false)));
203203

204204
// q1_1; n1; walltime 60s; memory 3; ratio 0.05 (pick for n1)
205205
// q1_2; n2; walltime 600s; memory 8; ratio 0.0133
@@ -217,7 +217,46 @@ public void testKillsSmallerTaskIfWastedEffortRatioIsBetter()
217217
taskId("q_2", 3)))));
218218
}
219219

220-
private static TaskInfo buildTaskInfo(TaskId taskId, TaskState state, Duration scheduledTime, Duration blockedTime)
220+
@Test
221+
public void testPrefersKillingSpeculativeTasks()
222+
{
223+
int memoryPool = 8;
224+
Map<String, Map<String, Long>> queries = ImmutableMap.<String, Map<String, Long>>builder()
225+
.put("q_1", ImmutableMap.of("n1", 3L, "n2", 8L))
226+
.put("q_2", ImmutableMap.of("n1", 7L, "n2", 2L))
227+
.buildOrThrow();
228+
229+
Map<String, Map<String, Map<Integer, Long>>> tasks = ImmutableMap.<String, Map<String, Map<Integer, Long>>>builder()
230+
.put("q_1", ImmutableMap.of(
231+
"n1", ImmutableMap.of(1, 3L),
232+
"n2", ImmutableMap.of(2, 8L)))
233+
.put("q_2", ImmutableMap.of(
234+
"n1", ImmutableMap.of(
235+
1, 1L,
236+
2, 6L),
237+
"n2", ImmutableMap.of(3, 2L)))
238+
.buildOrThrow();
239+
240+
Map<String, Map<Integer, TaskInfo>> taskInfos = ImmutableMap.of(
241+
"q_1", ImmutableMap.of(
242+
1, buildTaskInfo(taskId("q_1", 1), TaskState.RUNNING, new Duration(30, SECONDS), new Duration(30, SECONDS), false),
243+
2, buildTaskInfo(taskId("q_1", 2), TaskState.RUNNING, new Duration(400, SECONDS), new Duration(200, SECONDS), false)),
244+
"q_2", ImmutableMap.of(
245+
1, buildTaskInfo(taskId("q_2", 1), TaskState.RUNNING, new Duration(30, SECONDS), new Duration(30, SECONDS), true),
246+
2, buildTaskInfo(taskId("q_2", 2), TaskState.RUNNING, new Duration(100, SECONDS), new Duration(100, SECONDS), false),
247+
3, buildTaskInfo(taskId("q_2", 3), TaskState.RUNNING, new Duration(30, SECONDS), new Duration(30, SECONDS), false)));
248+
249+
assertEquals(
250+
lowMemoryKiller.chooseTargetToKill(
251+
toRunningQueryInfoList(queries, ImmutableSet.of("q_1", "q_2"), taskInfos),
252+
toNodeMemoryInfoList(memoryPool, queries, tasks)),
253+
Optional.of(KillTarget.selectedTasks(
254+
ImmutableSet.of(
255+
taskId("q_2", 1), // if q_2_1 was not speculative then "q_1_1 would be picked
256+
taskId("q_2", 3)))));
257+
}
258+
259+
private static TaskInfo buildTaskInfo(TaskId taskId, TaskState state, Duration scheduledTime, Duration blockedTime, boolean speculative)
221260
{
222261
return new TaskInfo(
223262
new TaskStatus(
@@ -227,7 +266,7 @@ private static TaskInfo buildTaskInfo(TaskId taskId, TaskState state, Duration s
227266
state,
228267
URI.create("fake://task/" + taskId + "/node/some_node"),
229268
"some_node",
230-
false,
269+
speculative,
231270
ImmutableList.of(),
232271
0,
233272
0,

0 commit comments

Comments
 (0)