Skip to content

Commit

Permalink
TaskExecutor to cancel all tasks on exception (#12689)
Browse files Browse the repository at this point in the history
When operations are parallelized, like query rewrite, or search, or
createWeight, one of the tasks may throw an exception. In that case we
wait for all tasks to be completed before re-throwing the exception that
were caught. Tasks that were not started when the exception is captured
though can be safely skipped. Ideally we would also cancel ongoing tasks
but I left that for another time.
  • Loading branch information
javanna authored Oct 24, 2023
1 parent 71c4ea7 commit 1200ecc
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 70 deletions.
164 changes: 112 additions & 52 deletions lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.ThreadInterruptedException;

Expand Down Expand Up @@ -64,70 +67,127 @@ public final class TaskExecutor {
* @param <T> the return type of the task execution
*/
public <T> List<T> invokeAll(Collection<Callable<T>> callables) throws IOException {
List<Task<T>> tasks = new ArrayList<>(callables.size());
boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0;
for (Callable<T> callable : callables) {
Task<T> task = new Task<>(callable);
tasks.add(task);
if (runOnCallerThread) {
task.run();
} else {
executor.execute(task);
TaskGroup<T> taskGroup = new TaskGroup<>(callables);
return taskGroup.invokeAll(executor);
}

@Override
public String toString() {
return "TaskExecutor(" + "executor=" + executor + ')';
}

/**
* Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and
* exposes the ability to invoke such tasks and wait for them all to complete their execution and
* provide their results. Ensures that each task does not get parallelized further: this is
* important to avoid a deadlock in situations where one executor thread waits on other executor
* threads to complete before it can progress. This happens in situations where for instance
* {@link Query#createWeight(IndexSearcher, ScoreMode, float)} is called as part of searching each
* slice, like {@link TopFieldCollector#populateScores(ScoreDoc[], IndexSearcher, Query)} does.
* Additionally, if one task throws an exception, all other tasks from the same group are
* cancelled, to avoid needless computation as their results would not be exposed anyways. Creates
* one {@link FutureTask} for each {@link Callable} provided
*
* @param <T> the return type of all the callables
*/
private static final class TaskGroup<T> {
private final Collection<RunnableFuture<T>> futures;

TaskGroup(Collection<Callable<T>> callables) {
List<RunnableFuture<T>> tasks = new ArrayList<>(callables.size());
for (Callable<T> callable : callables) {
tasks.add(createTask(callable));
}
this.futures = Collections.unmodifiableCollection(tasks);
}

Throwable exc = null;
final List<T> results = new ArrayList<>();
for (Future<T> future : tasks) {
try {
results.add(future.get());
} catch (InterruptedException e) {
var newException = new ThreadInterruptedException(e);
if (exc == null) {
exc = newException;
} else {
exc.addSuppressed(newException);
RunnableFuture<T> createTask(Callable<T> callable) {
// -1: cancelled; 0: not yet started; 1: started
AtomicBoolean startedOrCancelled = new AtomicBoolean(false);
return new FutureTask<>(
() -> {
if (startedOrCancelled.compareAndSet(false, true)) {
try {
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter + 1);
return callable.call();
} catch (Throwable t) {
cancelAll();
throw t;
} finally {
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter - 1);
}
}
// task is cancelled hence it has no results to return. That's fine: they would be
// ignored anyway.
return null;
}) {
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
assert mayInterruptIfRunning == false
: "cancelling tasks that are running is not supported";
/*
Future#get (called in invokeAll) throws CancellationException when invoked against a running task that has been cancelled but
leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns.
Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will
wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and
make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
*/
return startedOrCancelled.compareAndSet(false, true);
}
} catch (ExecutionException e) {
if (exc == null) {
exc = e.getCause();
};
}

List<T> invokeAll(Executor executor) throws IOException {
boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0;
for (Runnable runnable : futures) {
if (runOnCallerThread) {
runnable.run();
} else {
exc.addSuppressed(e.getCause());
executor.execute(runnable);
}
}
Throwable exc = null;
List<T> results = new ArrayList<>(futures.size());
for (Future<T> future : futures) {
try {
results.add(future.get());
} catch (InterruptedException e) {
var newException = new ThreadInterruptedException(e);
if (exc == null) {
exc = newException;
} else {
exc.addSuppressed(newException);
}
} catch (ExecutionException e) {
if (exc == null) {
exc = e.getCause();
} else {
exc.addSuppressed(e.getCause());
}
}
}
assert assertAllFuturesCompleted() : "Some tasks are still running?";
if (exc != null) {
throw IOUtils.rethrowAlways(exc);
}
return results;
}
if (exc != null) {
throw IOUtils.rethrowAlways(exc);
}
return results;
}

/**
* Extension of {@link FutureTask} that tracks the number of tasks that are running in each
* thread.
*
* @param <V> the return type of the task
*/
private static final class Task<V> extends FutureTask<V> {
private Task(Callable<V> callable) {
super(callable);
private boolean assertAllFuturesCompleted() {
for (RunnableFuture<T> future : futures) {
if (future.isDone() == false) {
return false;
}
}
return true;
}

@Override
public void run() {
try {
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter + 1);
super.run();
} finally {
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter - 1);
private void cancelAll() {
for (Future<T> future : futures) {
future.cancel(false);
}
}
}

@Override
public String toString() {
return "TaskExecutor(" + "executor=" + executor + ')';
}
}
105 changes: 87 additions & 18 deletions lucene/core/src/test/org/apache/lucene/search/TestTaskExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
Expand All @@ -32,6 +34,8 @@
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.NamedThreadFactory;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.AfterClass;
import org.junit.BeforeClass;

Expand All @@ -43,7 +47,8 @@ public class TestTaskExecutor extends LuceneTestCase {
public static void createExecutor() {
executorService =
Executors.newFixedThreadPool(
1, new NamedThreadFactory(TestTaskExecutor.class.getSimpleName()));
random().nextBoolean() ? 1 : 2,
new NamedThreadFactory(TestTaskExecutor.class.getSimpleName()));
}

@AfterClass
Expand Down Expand Up @@ -228,11 +233,21 @@ public Void reduce(Collection<Collector> collectors) {
}

public void testInvokeAllDoesNotLeaveTasksBehind() {
TaskExecutor taskExecutor = new TaskExecutor(executorService);
AtomicInteger tasksStarted = new AtomicInteger(0);
TaskExecutor taskExecutor =
new TaskExecutor(
command -> {
executorService.execute(
() -> {
tasksStarted.incrementAndGet();
command.run();
});
});
AtomicInteger tasksExecuted = new AtomicInteger(0);
List<Callable<Void>> callables = new ArrayList<>();
callables.add(
() -> {
tasksExecuted.incrementAndGet();
throw new RuntimeException();
});
int tasksWithNormalExit = 99;
Expand All @@ -244,7 +259,14 @@ public void testInvokeAllDoesNotLeaveTasksBehind() {
});
}
expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables));
assertEquals(tasksWithNormalExit, tasksExecuted.get());
int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize();
if (maximumPoolSize == 1) {
assertEquals(1, tasksExecuted.get());
} else {
MatcherAssert.assertThat(tasksExecuted.get(), Matchers.greaterThanOrEqualTo(1));
}
// the callables are technically all run, but the cancelled ones will be no-op
assertEquals(100, tasksStarted.get());
}

/**
Expand All @@ -253,36 +275,83 @@ public void testInvokeAllDoesNotLeaveTasksBehind() {
*/
public void testInvokeAllCatchesMultipleExceptions() {
TaskExecutor taskExecutor = new TaskExecutor(executorService);
AtomicInteger tasksExecuted = new AtomicInteger(0);
List<Callable<Void>> callables = new ArrayList<>();
int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize();
// if we have multiple threads, make sure both are started before an exception is thrown,
// otherwise there may or may not be a suppressed exception
CountDownLatch latchA = new CountDownLatch(1);
CountDownLatch latchB = new CountDownLatch(1);
callables.add(
() -> {
if (maximumPoolSize > 1) {
latchA.countDown();
latchB.await();
}
throw new RuntimeException("exception A");
});
int tasksWithNormalExit = 50;
for (int i = 0; i < tasksWithNormalExit; i++) {
callables.add(
() -> {
tasksExecuted.incrementAndGet();
return null;
});
}
callables.add(
() -> {
if (maximumPoolSize > 1) {
latchB.countDown();
latchA.await();
}
throw new IllegalStateException("exception B");
});

RuntimeException exc =
expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables));
Throwable[] suppressed = exc.getSuppressed();
assertEquals(1, suppressed.length);
if (exc.getMessage().equals("exception A")) {
assertEquals("exception B", suppressed[0].getMessage());

if (maximumPoolSize == 1) {
assertEquals(0, suppressed.length);
} else {
assertEquals("exception A", suppressed[0].getMessage());
assertEquals("exception B", exc.getMessage());
assertEquals(1, suppressed.length);
if (exc.getMessage().equals("exception A")) {
assertEquals("exception B", suppressed[0].getMessage());
} else {
assertEquals("exception A", suppressed[0].getMessage());
assertEquals("exception B", exc.getMessage());
}
}
}

assertEquals(tasksWithNormalExit, tasksExecuted.get());
public void testCancelTasksOnException() {
TaskExecutor taskExecutor = new TaskExecutor(executorService);
int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize();
final int numTasks = random().nextInt(10, 50);
final int throwingTask = random().nextInt(numTasks);
boolean error = random().nextBoolean();
List<Callable<Void>> tasks = new ArrayList<>(numTasks);
AtomicInteger executedTasks = new AtomicInteger(0);
for (int i = 0; i < numTasks; i++) {
final int index = i;
tasks.add(
() -> {
if (index == throwingTask) {
if (error) {
throw new OutOfMemoryError();
} else {
throw new RuntimeException();
}
}
if (index > throwingTask && maximumPoolSize == 1) {
throw new AssertionError("task should not have started");
}
executedTasks.incrementAndGet();
return null;
});
}
Throwable throwable;
if (error) {
throwable = expectThrows(OutOfMemoryError.class, () -> taskExecutor.invokeAll(tasks));
} else {
throwable = expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(tasks));
}
assertEquals(0, throwable.getSuppressed().length);
if (maximumPoolSize == 1) {
assertEquals(throwingTask, executedTasks.get());
} else {
MatcherAssert.assertThat(executedTasks.get(), Matchers.greaterThanOrEqualTo(throwingTask));
}
}
}

0 comments on commit 1200ecc

Please sign in to comment.