Skip to content

Fix exception propagation in Async API methods #1479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import com.mongodb.lang.Nullable;

import java.util.concurrent.atomic.AtomicBoolean;

/**
* See {@link AsyncRunnable}
* <p>
Expand All @@ -33,4 +35,28 @@ public interface AsyncFunction<T, R> {
* @param callback the callback
*/
void unsafeFinish(T value, SingleResultCallback<R> callback);

/**
* Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
*
* @param callback the callback provided by the method the chain is used in.
*/
default void finish(final T value, final SingleResultCallback<R> callback) {
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
try {
this.unsafeFinish(value, (v, e) -> {
if (!callbackInvoked.compareAndSet(false, true)) {
throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
}
callback.onResult(v, e);
});
} catch (Throwable t) {
if (!callbackInvoked.compareAndSet(false, true)) {
throw t;
} else {
callback.completeExceptionally(t);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) {
return (c) -> {
this.unsafeFinish((r, e) -> {
if (e == null) {
runnable.unsafeFinish(c);
/* If 'runnable' is executed on a different thread from the one that executed the initial 'finish()',
then invoking 'finish()' within 'runnable' will catch and propagate any exceptions to 'c' (the callback). */
runnable.finish(c);
} else {
c.completeExceptionally(e);
}
Expand Down Expand Up @@ -236,7 +238,7 @@ default AsyncRunnable thenRunIf(final Supplier<Boolean> condition, final AsyncRu
return;
}
if (matched) {
runnable.unsafeFinish(callback);
runnable.finish(callback);
} else {
callback.complete(callback);
}
Expand All @@ -253,7 +255,7 @@ default <R> AsyncSupplier<R> thenSupply(final AsyncSupplier<R> supplier) {
return (c) -> {
this.unsafeFinish((r, e) -> {
if (e == null) {
supplier.unsafeFinish(c);
supplier.finish(c);
} else {
c.completeExceptionally(e);
}
Expand Down
24 changes: 16 additions & 8 deletions driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.mongodb.lang.Nullable;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;


Expand Down Expand Up @@ -54,18 +55,25 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback
}

/**
* Must be invoked at end of async chain.
* Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
*
* @see #thenApply(AsyncFunction)
* @see #thenConsume(AsyncConsumer)
* @see #onErrorIf(Predicate, AsyncFunction)
* @param callback the callback provided by the method the chain is used in
*/
default void finish(final SingleResultCallback<T> callback) {
final boolean[] callbackInvoked = {false};
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
try {
this.unsafeFinish((v, e) -> {
callbackInvoked[0] = true;
if (!callbackInvoked.compareAndSet(false, true)) {
throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
}
callback.onResult(v, e);
});
} catch (Throwable t) {
if (callbackInvoked[0]) {
if (!callbackInvoked.compareAndSet(false, true)) {
throw t;
} else {
callback.completeExceptionally(t);
Expand All @@ -80,9 +88,9 @@ default void finish(final SingleResultCallback<T> callback) {
*/
default <R> AsyncSupplier<R> thenApply(final AsyncFunction<T, R> function) {
return (c) -> {
this.unsafeFinish((v, e) -> {
this.finish((v, e) -> {
if (e == null) {
function.unsafeFinish(v, c);
function.finish(v, c);
} else {
c.completeExceptionally(e);
}
Expand All @@ -99,7 +107,7 @@ default AsyncRunnable thenConsume(final AsyncConsumer<T> consumer) {
return (c) -> {
this.unsafeFinish((v, e) -> {
if (e == null) {
consumer.unsafeFinish(v, c);
consumer.finish(v, c);
} else {
c.completeExceptionally(e);
}
Expand Down Expand Up @@ -131,7 +139,7 @@ default AsyncSupplier<T> onErrorIf(
return;
}
if (errorMatched) {
errorFunction.unsafeFinish(e, callback);
errorFunction.finish(e, callback);
} else {
callback.completeExceptionally(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
return;
}
assertNotNull(responseBuffers);
T commandResult;
try {
updateSessionContext(operationContext.getSessionContext(), responseBuffers);
boolean commandOk =
Expand All @@ -624,13 +625,14 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
}
commandEventSender.sendSucceededEvent(responseBuffers);

T result1 = getCommandResult(decoder, responseBuffers, messageId, operationContext.getTimeoutContext());
callback.onResult(result1, null);
commandResult = getCommandResult(decoder, responseBuffers, messageId, operationContext.getTimeoutContext());
} catch (Throwable localThrowable) {
callback.onResult(null, localThrowable);
return;
} finally {
responseBuffers.close();
}
callback.onResult(commandResult, null);
}));
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,14 @@ public void startHandshakeAsync(final InternalConnection internalConnection, fin
callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t);
} else {
setSpeculativeAuthenticateResponse(helloResult);
callback.onResult(createInitializationDescription(helloResult, internalConnection, startTime), null);
InternalConnectionInitializationDescription initializationDescription;
try {
initializationDescription = createInitializationDescription(helloResult, internalConnection, startTime);
} catch (Throwable localThrowable) {
callback.onResult(null, localThrowable);
return;
}
callback.onResult(initializationDescription, null);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@

import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
import static org.junit.jupiter.api.Assertions.assertThrows;

final class AsyncFunctionsTest extends AsyncFunctionsTestAbstract {
abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase {
private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0));

@Test
void test1Method() {
// the number of expected variations is often: 1 + N methods invoked
Expand Down Expand Up @@ -760,25 +760,6 @@ void testVariables() {
});
}

@Test
void testInvalid() {
setIsTestingAbruptCompletion(false);
setAsyncStep(true);
assertThrows(IllegalStateException.class, () -> {
beginAsync().thenRun(c -> {
async(3, c);
throw new IllegalStateException("must not cause second callback invocation");
}).finish((v, e) -> {});
});
assertThrows(IllegalStateException.class, () -> {
beginAsync().thenRun(c -> {
async(3, c);
}).finish((v, e) -> {
throw new IllegalStateException("must not cause second callback invocation");
});
});
}

@Test
void testDerivation() {
// Demonstrates the progression from nested async to the API.
Expand Down Expand Up @@ -866,5 +847,4 @@ void testDerivation() {
}).finish(callback);
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,17 @@
package com.mongodb.internal.async;

import com.mongodb.client.TestListener;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.opentest4j.AssertionFailedError;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;
Expand All @@ -31,11 +37,12 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

public class AsyncFunctionsTestAbstract {
public abstract class AsyncFunctionsTestBase {

private final TestListener listener = new TestListener();
private final InvocationTracker invocationTracker = new InvocationTracker();
private boolean isTestingAbruptCompletion = false;
private ExecutorService asyncExecutor;

void setIsTestingAbruptCompletion(final boolean b) {
isTestingAbruptCompletion = b;
Expand All @@ -53,6 +60,23 @@ public void listenerAdd(final String s) {
listener.add(s);
}

/**
* Create an executor service for async operations before each test.
*
* @return the executor service.
*/
public abstract ExecutorService createAsyncExecutor();

@BeforeEach
public void setUp() {
asyncExecutor = createAsyncExecutor();
}

@AfterEach
public void shutDown() {
asyncExecutor.shutdownNow();
}

void plain(final int i) {
int cur = invocationTracker.getNextOption(2);
if (cur == 0) {
Expand Down Expand Up @@ -98,32 +122,47 @@ Integer syncReturns(final int i) {
return affectedReturns(i);
}


public void submit(final Runnable task) {
asyncExecutor.execute(task);
}
void async(final int i, final SingleResultCallback<Void> callback) {
assertTrue(invocationTracker.isAsyncStep);
if (isTestingAbruptCompletion) {
/* We should not test for abrupt completion in a separate thread. Once a callback is registered for an async operation,
the Async Framework does not handle exceptions thrown outside of callbacks by the executing thread. Such exception management
should be the responsibility of the thread conducting the asynchronous operations. */
affected(i);
callback.complete(callback);

} else {
try {
affected(i);
submit(() -> {
callback.complete(callback);
} catch (Throwable t) {
callback.onResult(null, t);
}
});
} else {
submit(() -> {
try {
affected(i);
callback.complete(callback);
} catch (Throwable t) {
callback.onResult(null, t);
}
});
}
}

void asyncReturns(final int i, final SingleResultCallback<Integer> callback) {
assertTrue(invocationTracker.isAsyncStep);
if (isTestingAbruptCompletion) {
callback.complete(affectedReturns(i));
int result = affectedReturns(i);
submit(() -> {
callback.complete(result);
});
} else {
try {
callback.complete(affectedReturns(i));
} catch (Throwable t) {
callback.onResult(null, t);
}
submit(() -> {
try {
callback.complete(affectedReturns(i));
} catch (Throwable t) {
callback.onResult(null, t);
}
});
}
}

Expand Down Expand Up @@ -200,24 +239,26 @@ private <T> void assertBehavesSame(final Supplier<T> sync, final Runnable betwee

AtomicReference<T> actualValue = new AtomicReference<>();
AtomicReference<Throwable> actualException = new AtomicReference<>();
AtomicBoolean wasCalled = new AtomicBoolean(false);
CompletableFuture<Void> wasCalledFuture = new CompletableFuture<>();
try {
async.accept((v, e) -> {
actualValue.set(v);
actualException.set(e);
if (wasCalled.get()) {
if (wasCalledFuture.isDone()) {
fail();
}
wasCalled.set(true);
wasCalledFuture.complete(null);
});
} catch (Throwable e) {
fail("async threw instead of using callback");
}

await(wasCalledFuture, "Callback should have been called");

// The following code can be used to debug variations:
// System.out.println("===VARIATION START");
// System.out.println("sync: " + expectedEvents);
// System.out.println("callback called?: " + wasCalled.get());
// System.out.println("callback called?: " + wasCalledFuture.isDone());
// System.out.println("value -- sync: " + expectedValue + " -- async: " + actualValue.get());
// System.out.println("excep -- sync: " + expectedException + " -- async: " + actualException.get());
// System.out.println("exception mode: " + (isTestingAbruptCompletion
Expand All @@ -229,7 +270,7 @@ private <T> void assertBehavesSame(final Supplier<T> sync, final Runnable betwee
throw (AssertionFailedError) actualException.get();
}

assertTrue(wasCalled.get(), "callback should have been called");
assertTrue(wasCalledFuture.isDone(), "callback should have been called");
assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
assertEquals(expectedValue, actualValue.get());
assertEquals(expectedException == null, actualException.get() == null,
Expand All @@ -242,6 +283,14 @@ private <T> void assertBehavesSame(final Supplier<T> sync, final Runnable betwee
listener.clear();
}

protected <T> T await(final CompletableFuture<T> voidCompletableFuture, final String errorMessage) {
try {
return voidCompletableFuture.get(1, TimeUnit.MINUTES);
} catch (InterruptedException | ExecutionException | TimeoutException e) {
throw new AssertionError(errorMessage);
}
}

/**
* Tracks invocations: allows testing of all variations of a method calls
*/
Expand Down
Loading