Skip to content

Commit eec0eed

Browse files
authored
Merge changes from tls-channel for race condition manifested when closing async sockets right after creation (#851)
This is a backport of #848 JAVA-4417
1 parent 3084f18 commit eec0eed

File tree

1 file changed

+106
-55
lines changed

1 file changed

+106
-55
lines changed

driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java

Lines changed: 106 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import java.nio.channels.SocketChannel;
4040
import java.nio.channels.WritePendingException;
4141
import java.util.Iterator;
42-
import java.util.concurrent.CancellationException;
42+
import java.util.concurrent.ConcurrentHashMap;
4343
import java.util.concurrent.ConcurrentLinkedQueue;
4444
import java.util.concurrent.CountDownLatch;
4545
import java.util.concurrent.ExecutorService;
@@ -100,19 +100,15 @@ class RegisteredSocket {
100100
/** Bitwise union of pending operation to be registered in the selector */
101101
final AtomicInteger pendingOps = new AtomicInteger();
102102

103-
RegisteredSocket(TlsChannel tlsChannel, SocketChannel socketChannel)
104-
throws ClosedChannelException {
103+
RegisteredSocket(TlsChannel tlsChannel, SocketChannel socketChannel) {
105104
this.tlsChannel = tlsChannel;
106105
this.socketChannel = socketChannel;
107106
}
108107

109108
public void close() {
110-
doCancelRead(this, null);
111-
doCancelWrite(this, null);
112109
if (key != null) {
113110
key.cancel();
114111
}
115-
currentRegistrations.getAndDecrement();
116112
/*
117113
* Actual de-registration from the selector will happen asynchronously.
118114
*/
@@ -195,8 +191,7 @@ private enum Shutdown {
195191
private LongAdder cancelledReads = new LongAdder();
196192
private LongAdder cancelledWrites = new LongAdder();
197193

198-
// used for synchronization
199-
private AtomicInteger currentRegistrations = new AtomicInteger();
194+
private final ConcurrentHashMap<RegisteredSocket, Boolean> registrations = new ConcurrentHashMap<>();
200195

201196
private LongAdder currentReads = new LongAdder();
202197
private LongAdder currentWrites = new LongAdder();
@@ -232,13 +227,11 @@ public AsynchronousTlsChannelGroup() {
232227
this(Runtime.getRuntime().availableProcessors());
233228
}
234229

235-
RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel)
236-
throws ClosedChannelException {
230+
RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel) {
237231
if (shutdown != Shutdown.No) {
238232
throw new ShutdownChannelGroupException();
239233
}
240234
RegisteredSocket socket = new RegisteredSocket(reader, socketChannel);
241-
currentRegistrations.getAndIncrement();
242235
pendingRegistrations.add(socket);
243236
selector.wakeup();
244237
return socket;
@@ -247,18 +240,13 @@ RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel)
247240
boolean doCancelRead(RegisteredSocket socket, ReadOperation op) {
248241
socket.readLock.lock();
249242
try {
250-
// a null op means cancel any operation
251-
if (op != null && socket.readOperation == op || op == null && socket.readOperation != null) {
252-
if (op == null) {
253-
socket.readOperation.onFailure.accept(new CancellationException());
254-
}
255-
socket.readOperation = null;
256-
cancelledReads.increment();
257-
currentReads.decrement();
258-
return true;
259-
} else {
243+
if (op != socket.readOperation) {
260244
return false;
261245
}
246+
socket.readOperation = null;
247+
cancelledReads.increment();
248+
currentReads.decrement();
249+
return true;
262250
} finally {
263251
socket.readLock.unlock();
264252
}
@@ -267,18 +255,13 @@ boolean doCancelRead(RegisteredSocket socket, ReadOperation op) {
267255
boolean doCancelWrite(RegisteredSocket socket, WriteOperation op) {
268256
socket.writeLock.lock();
269257
try {
270-
// a null op means cancel any operation
271-
if (op != null && socket.writeOperation == op || op == null && socket.writeOperation != null) {
272-
if (op == null) {
273-
socket.writeOperation.onFailure.accept(new CancellationException());
274-
}
275-
socket.writeOperation = null;
276-
cancelledWrites.increment();
277-
currentWrites.decrement();
278-
return true;
279-
} else {
258+
if (op != socket.writeOperation) {
280259
return false;
281260
}
261+
socket.writeOperation = null;
262+
cancelledWrites.increment();
263+
currentWrites.decrement();
264+
return true;
282265
} finally {
283266
socket.writeLock.unlock();
284267
}
@@ -295,13 +278,23 @@ ReadOperation startRead(
295278
checkTerminated();
296279
Util.assertTrue(buffer.hasRemaining());
297280
waitForSocketRegistration(socket);
298-
ReadOperation op;
299281
socket.readLock.lock();
300282
try {
301283
if (socket.readOperation != null) {
302284
throw new ReadPendingException();
303285
}
304-
op = new ReadOperation(buffer, onSuccess, onFailure);
286+
ReadOperation op = new ReadOperation(buffer, onSuccess, onFailure);
287+
288+
startedReads.increment();
289+
currentReads.increment();
290+
291+
if (!registrations.containsKey(socket)) {
292+
op.onFailure.accept(new ClosedChannelException());
293+
failedReads.increment();
294+
currentReads.decrement();
295+
return op;
296+
}
297+
305298
/*
306299
* we do not try to outsmart the TLS state machine and register for both IO operations for each new socket
307300
* operation
@@ -324,9 +317,7 @@ ReadOperation startRead(
324317
socket.readLock.unlock();
325318
}
326319
selector.wakeup();
327-
startedReads.increment();
328-
currentReads.increment();
329-
return op;
320+
return socket.readOperation;
330321
}
331322

332323
WriteOperation startWrite(
@@ -340,13 +331,23 @@ WriteOperation startWrite(
340331
checkTerminated();
341332
Util.assertTrue(buffer.hasRemaining());
342333
waitForSocketRegistration(socket);
343-
WriteOperation op;
344334
socket.writeLock.lock();
345335
try {
346336
if (socket.writeOperation != null) {
347337
throw new WritePendingException();
348338
}
349-
op = new WriteOperation(buffer, onSuccess, onFailure);
339+
WriteOperation op = new WriteOperation(buffer, onSuccess, onFailure);
340+
341+
startedWrites.increment();
342+
currentWrites.increment();
343+
344+
if (!registrations.containsKey(socket)) {
345+
op.onFailure.accept(new ClosedChannelException());
346+
failedWrites.increment();
347+
currentWrites.decrement();
348+
return op;
349+
}
350+
350351
/*
351352
* we do not try to outsmart the TLS state machine and register for both IO operations for each new socket
352353
* operation
@@ -369,9 +370,7 @@ WriteOperation startWrite(
369370
socket.writeLock.unlock();
370371
}
371372
selector.wakeup();
372-
startedWrites.increment();
373-
currentWrites.increment();
374-
return op;
373+
return socket.writeOperation;
375374
}
376375

377376
private void checkTerminated() {
@@ -391,8 +390,11 @@ private void waitForSocketRegistration(RegisteredSocket socket) {
391390
private void loop() {
392391
try {
393392
while (shutdown == Shutdown.No
394-
|| shutdown == Shutdown.Wait && currentRegistrations.intValue() > 0) {
395-
int c = selector.select(); // block
393+
|| shutdown == Shutdown.Wait
394+
&& (!pendingRegistrations.isEmpty() || !registrations.isEmpty())) {
395+
// most state-changing operations will wake the selector up, however, asynchronous closings
396+
// of the channels won't, so we have to timeout to allow checking those cases
397+
int c = selector.select(100); // block
396398
selectionCount.increment();
397399
// avoid unnecessary creation of iterator object
398400
if (c > 0) {
@@ -413,24 +415,20 @@ private void loop() {
413415
}
414416
registerPendingSockets();
415417
processPendingInterests();
418+
checkClosings();
416419
}
417420
} catch (Throwable e) {
418421
LOGGER.error("error in selector loop", e);
419422
} finally {
420423
executor.shutdown();
421424
// use shutdownNow to stop delayed tasks
422425
timeoutExecutor.shutdownNow();
423-
if (shutdown == Shutdown.Immediate) {
424-
for (SelectionKey key : selector.keys()) {
425-
RegisteredSocket socket = (RegisteredSocket) key.attachment();
426-
socket.close();
427-
}
428-
}
429426
try {
430427
selector.close();
431428
} catch (IOException e) {
432429
LOGGER.warn("error closing selector: " + e.getMessage());
433430
}
431+
checkClosings();
434432
}
435433
}
436434

@@ -606,14 +604,67 @@ private long readHandlingTasks(RegisteredSocket socket, ReadOperation op) throws
606604
}
607605
}
608606

609-
private void registerPendingSockets() throws ClosedChannelException {
607+
private void registerPendingSockets() {
610608
RegisteredSocket socket;
611609
while ((socket = pendingRegistrations.poll()) != null) {
612-
socket.key = socket.socketChannel.register(selector, 0, socket);
613-
if (LOGGER.isTraceEnabled()) {
614-
LOGGER.trace("registered key: " + socket.key);
610+
try {
611+
socket.key = socket.socketChannel.register(selector, 0, socket);
612+
registrations.put(socket, true);
613+
} catch (ClosedChannelException e) {
614+
// can happen when channels are closed right after creation
615+
} finally {
616+
// decrement the count of the latch even in case of exceptions, so the waiting thread
617+
// is unlocked; it will have to check the result, though
618+
socket.registered.countDown();
619+
}
620+
}
621+
}
622+
623+
/**
624+
* Channels that are closed asynchronously are silently removed from selectors. This method will
625+
* check them using the internal catalog and do the proper cleanup.
626+
*/
627+
private void checkClosings() {
628+
for (RegisteredSocket socket : registrations.keySet()) {
629+
if (!socket.key.isValid() || shutdown == Shutdown.Immediate) {
630+
registrations.remove(socket);
631+
failCurrentRead(socket);
632+
failCurrentWrite(socket);
615633
}
616-
socket.registered.countDown();
634+
}
635+
}
636+
637+
private void failCurrentRead(RegisteredSocket socket) {
638+
socket.readLock.lock();
639+
try {
640+
if (socket.readOperation != null) {
641+
socket.readOperation.onFailure.accept(new ClosedChannelException());
642+
if (socket.readOperation.timeoutFuture != null) {
643+
socket.readOperation.timeoutFuture.cancel(false);
644+
}
645+
socket.readOperation = null;
646+
failedReads.increment();
647+
currentReads.decrement();
648+
}
649+
} finally {
650+
socket.readLock.unlock();
651+
}
652+
}
653+
654+
private void failCurrentWrite(RegisteredSocket socket) {
655+
socket.writeLock.lock();
656+
try {
657+
if (socket.writeOperation != null) {
658+
socket.writeOperation.onFailure.accept(new ClosedChannelException());
659+
if (socket.writeOperation.timeoutFuture != null) {
660+
socket.writeOperation.timeoutFuture.cancel(false);
661+
}
662+
socket.writeOperation = null;
663+
failedWrites.increment();
664+
currentWrites.decrement();
665+
}
666+
} finally {
667+
socket.writeLock.unlock();
617668
}
618669
}
619670

@@ -769,6 +820,6 @@ public long getCurrentWriteCount() {
769820
* @return number of sockets
770821
*/
771822
public long getCurrentRegistrationCount() {
772-
return currentRegistrations.longValue();
823+
return registrations.mappingCount();
773824
}
774825
}

0 commit comments

Comments
 (0)