Skip to content

Add connection timeout to TLS Channel #1686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.net.StandardSocketOptions;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import java.nio.channels.InterruptedByTimeoutException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
Expand All @@ -49,7 +50,9 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static com.mongodb.assertions.Assertions.assertFalse;
import static com.mongodb.assertions.Assertions.assertTrue;
import static com.mongodb.assertions.Assertions.isTrue;
import static com.mongodb.internal.connection.ServerAddressHelper.getSocketAddresses;
Expand Down Expand Up @@ -99,19 +102,39 @@ public void close() {

private static class SelectorMonitor implements Closeable {

private static final class Pair {
static final class SocketRegistration {
private final SocketChannel socketChannel;
private final Runnable attachment;
private final AtomicReference<ConnectionRegistrationState> connectionRegistrationState;

private Pair(final SocketChannel socketChannel, final Runnable attachment) {
/**
* Progress of a single connection attempt tracked by a {@code SocketRegistration}.
* In the code visible here, both transitions out of {@code CONNECTING} are performed with
* compare-and-set, so only one of them can succeed for a given registration.
*/
enum ConnectionRegistrationState {
// Initial state assigned by the SocketRegistration constructor.
CONNECTING,
// Target state of markConnectionEstablishmentCompleted().
CONNECTED,
// Target state of markConnectionEstablishmentTimedOut().
TIMEOUT_OUT
}

/**
 * Creates a registration whose state starts as {@code CONNECTING}.
 *
 * @param socketChannel the channel this registration refers to
 * @param attachment    the action stored alongside the channel
 */
private SocketRegistration(final SocketChannel socketChannel, final Runnable attachment) {
    // The three initialisations are independent of one another.
    this.connectionRegistrationState = new AtomicReference<>(ConnectionRegistrationState.CONNECTING);
    this.attachment = attachment;
    this.socketChannel = socketChannel;
}

/**
 * Attempts to move this registration from {@code CONNECTING} to {@code TIMEOUT_OUT}.
 *
 * @return {@code true} if this call performed the transition, {@code false} if the state
 *         was not {@code CONNECTING} when the compare-and-set ran
 */
public boolean markConnectionEstablishmentTimedOut() {
    final ConnectionRegistrationState expectedState = ConnectionRegistrationState.CONNECTING;
    final ConnectionRegistrationState timedOutState = ConnectionRegistrationState.TIMEOUT_OUT;
    return connectionRegistrationState.compareAndSet(expectedState, timedOutState);
}

/**
 * Attempts to move this registration from {@code CONNECTING} to {@code CONNECTED}.
 *
 * @return {@code true} if this call performed the transition, {@code false} if the state
 *         was not {@code CONNECTING} when the compare-and-set ran
 */
public boolean markConnectionEstablishmentCompleted() {
    final ConnectionRegistrationState expectedState = ConnectionRegistrationState.CONNECTING;
    final ConnectionRegistrationState connectedState = ConnectionRegistrationState.CONNECTED;
    return connectionRegistrationState.compareAndSet(expectedState, connectedState);
}
}

private final Selector selector;
private volatile boolean isClosed;
private final ConcurrentLinkedDeque<Pair> pendingRegistrations = new ConcurrentLinkedDeque<>();
private final ConcurrentLinkedDeque<SocketRegistration> pendingRegistrations = new ConcurrentLinkedDeque<>();

SelectorMonitor() {
try {
Expand All @@ -121,23 +144,29 @@ private Pair(final SocketChannel socketChannel, final Runnable attachment) {
}
}

// Monitors OP_CONNECT events.
void start() {
Thread selectorThread = new Thread(() -> {
try {
while (!isClosed) {
try {
selector.select();

for (SelectionKey selectionKey : selector.selectedKeys()) {
selectionKey.cancel();
Runnable runnable = (Runnable) selectionKey.attachment();
runnable.run();
SocketRegistration socketRegistration = (SocketRegistration) selectionKey.attachment();

boolean markedCompleted = socketRegistration.markConnectionEstablishmentCompleted();
if (markedCompleted) {
Runnable runnable = socketRegistration.attachment;
runnable.run();
} else {
assertFalse(socketRegistration.socketChannel.isOpen());
}
}

for (Iterator<Pair> iter = pendingRegistrations.iterator(); iter.hasNext();) {
Pair pendingRegistration = iter.next();
pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT,
pendingRegistration.attachment);
for (Iterator<SocketRegistration> iter = pendingRegistrations.iterator(); iter.hasNext();) {
SocketRegistration pendingRegistration = iter.next();
pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT, pendingRegistration);
iter.remove();
}
} catch (Exception e) {
Expand All @@ -156,8 +185,9 @@ void start() {
selectorThread.start();
}

void register(final SocketChannel channel, final Runnable attachment) {
pendingRegistrations.add(new Pair(channel, attachment));

/**
 * Queues a registration for the selector thread and wakes the selector.
 *
 * @param registration the registration to enqueue
 */
void register(final SocketRegistration registration) {
    // Append at the tail of the pending queue; the selector thread drains it after select() returns.
    pendingRegistrations.addLast(registration);
    // Wake the selector only after enqueueing, so the woken thread can see the new entry.
    selector.wakeup();
}

Expand Down Expand Up @@ -203,41 +233,75 @@ public void openAsync(final OperationContext operationContext, final AsyncComple

socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0));

selectorMonitor.register(socketChannel, () -> {
try {
if (!socketChannel.finishConnect()) {
throw new MongoSocketOpenException("Failed to finish connect", getServerAddress());
}
SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration(
socketChannel, () -> initializeTslChannel(handler, socketChannel));

SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(),
getServerAddress().getPort());
sslEngine.setUseClientMode(true);
int connectTimeoutMs = getSettings().getConnectTimeout(TimeUnit.MILLISECONDS);

SSLParameters sslParameters = sslEngine.getSSLParameters();
enableSni(getServerAddress().getHost(), sslParameters);
group.getTimeoutExecutor().schedule(() -> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For better scalability, I believe we should use a dedicated timeout executor for connection timeouts.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't group.getTimeoutExecutor() dedicated to timeouts?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is. What I meant is that if an application schedules read and write timeouts at a high rate, a single thread might not keep up, so some timeouts could fire later than scheduled. A dedicated thread for connection timeouts could help offload some of that work. However, this is a speculative assumption, and one thread might be sufficient.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propose that we not create any more internal threads than we already have, unless that is necessary or we have evidence that the existing ones do not suffice.

boolean markedTimedOut = socketRegistration.markConnectionEstablishmentTimedOut();
if (markedTimedOut) {
closeAndTimeout(handler, socketChannel);
}
}, connectTimeoutMs, TimeUnit.MILLISECONDS);

if (!sslSettings.isInvalidHostNameAllowed()) {
enableHostNameVerification(sslParameters);
}
sslEngine.setSSLParameters(sslParameters);
selectorMonitor.register(socketRegistration);
} catch (IOException e) {
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
} catch (Throwable t) {
handler.failed(t);
}
}

BufferAllocator bufferAllocator = new BufferProviderAllocator();
private void closeAndTimeout(final AsyncCompletionHandler<Void> handler, final SocketChannel socketChannel) {
// We check if this stream was closed before timeout exception.
boolean streamClosed = isClosed();

TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine)
.withEncryptedBufferAllocator(bufferAllocator)
.withPlainBufferAllocator(bufferAllocator)
.build();
//TODO refactor this draft
InterruptedByTimeoutException timeoutException = new InterruptedByTimeoutException();
try {
socketChannel.close();
} catch (Exception e) {
//TODO should ignore this exception? We seem to do so in other places
timeoutException.addSuppressed(e);
}

// build asynchronous channel, based on the TLS channel and associated with the global group.
setChannel(new AsynchronousTlsChannelAdapter(new AsynchronousTlsChannel(group, tlsChannel, socketChannel)));
if (streamClosed) {
handler.completed(null);
} else {
handler.failed(new MongoSocketOpenException("Exception opening socket", getAddress(), new InterruptedByTimeoutException()));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For other reviewers:

There is a race condition here. Observing that streamClosed is false does not mean that this TlsChannelStream wasn't closed concurrently with the execution of the closeAndTimeout method. But that's fine, as it affects only whether we call handler.completed or handler.failed. And since an observer of the handler's side effects cannot have an expectation on whether TlsChannelStream.close happens-before the connect timeout expiring, we are free to choose either outcome, including the arbitrary choice we make here.

}

handler.completed(null);
} catch (IOException e) {
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
} catch (Throwable t) {
handler.failed(t);
}
});
private void initializeTslChannel(final AsyncCompletionHandler<Void> handler, final SocketChannel socketChannel) {
try {
if (!socketChannel.finishConnect()) {
throw new MongoSocketOpenException("Failed to finish connect", getServerAddress());
}

SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(),
getServerAddress().getPort());
sslEngine.setUseClientMode(true);

SSLParameters sslParameters = sslEngine.getSSLParameters();
enableSni(getServerAddress().getHost(), sslParameters);

if (!sslSettings.isInvalidHostNameAllowed()) {
enableHostNameVerification(sslParameters);
}
sslEngine.setSSLParameters(sslParameters);

BufferAllocator bufferAllocator = new BufferProviderAllocator();

TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine)
.withEncryptedBufferAllocator(bufferAllocator)
.withPlainBufferAllocator(bufferAllocator)
.build();

// build asynchronous channel, based on the TLS channel and associated with the global group.
setChannel(new AsynchronousTlsChannelAdapter(new AsynchronousTlsChannel(group, tlsChannel, socketChannel)));

handler.completed(null);
} catch (IOException e) {
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
} catch (Throwable t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -823,4 +823,13 @@ public long getCurrentWriteCount() {
public long getCurrentRegistrationCount() {
return registrations.mappingCount();
}

/**
* Returns the timeout executor used by this channel group.
*
* <p>The value of the {@code timeoutExecutor} field is returned directly, so callers operate on
* the group's own executor instance rather than on a copy.</p>
*
* <p>NOTE(review): presumably callers should only schedule tasks on it and must not shut it
* down - confirm against the owner of the field.</p>
*
* @return the timeout executor
*/
public ScheduledThreadPoolExecutor getTimeoutExecutor() {
return timeoutExecutor;
}
}