39
39
import java .nio .channels .SocketChannel ;
40
40
import java .nio .channels .WritePendingException ;
41
41
import java .util .Iterator ;
42
- import java .util .concurrent .CancellationException ;
42
+ import java .util .concurrent .ConcurrentHashMap ;
43
43
import java .util .concurrent .ConcurrentLinkedQueue ;
44
44
import java .util .concurrent .CountDownLatch ;
45
45
import java .util .concurrent .ExecutorService ;
@@ -100,19 +100,15 @@ class RegisteredSocket {
100
100
/** Bitwise union of pending operation to be registered in the selector */
101
101
final AtomicInteger pendingOps = new AtomicInteger ();
102
102
103
- RegisteredSocket (TlsChannel tlsChannel , SocketChannel socketChannel )
104
- throws ClosedChannelException {
103
+ RegisteredSocket (TlsChannel tlsChannel , SocketChannel socketChannel ) {
105
104
this .tlsChannel = tlsChannel ;
106
105
this .socketChannel = socketChannel ;
107
106
}
108
107
109
108
public void close () {
110
- doCancelRead (this , null );
111
- doCancelWrite (this , null );
112
109
if (key != null ) {
113
110
key .cancel ();
114
111
}
115
- currentRegistrations .getAndDecrement ();
116
112
/*
117
113
* Actual de-registration from the selector will happen asynchronously.
118
114
*/
@@ -195,8 +191,7 @@ private enum Shutdown {
195
191
private LongAdder cancelledReads = new LongAdder ();
196
192
private LongAdder cancelledWrites = new LongAdder ();
197
193
198
- // used for synchronization
199
- private AtomicInteger currentRegistrations = new AtomicInteger ();
194
+ private final ConcurrentHashMap <RegisteredSocket , Boolean > registrations = new ConcurrentHashMap <>();
200
195
201
196
private LongAdder currentReads = new LongAdder ();
202
197
private LongAdder currentWrites = new LongAdder ();
@@ -232,13 +227,11 @@ public AsynchronousTlsChannelGroup() {
232
227
this (Runtime .getRuntime ().availableProcessors ());
233
228
}
234
229
235
- RegisteredSocket registerSocket (TlsChannel reader , SocketChannel socketChannel )
236
- throws ClosedChannelException {
230
+ RegisteredSocket registerSocket (TlsChannel reader , SocketChannel socketChannel ) {
237
231
if (shutdown != Shutdown .No ) {
238
232
throw new ShutdownChannelGroupException ();
239
233
}
240
234
RegisteredSocket socket = new RegisteredSocket (reader , socketChannel );
241
- currentRegistrations .getAndIncrement ();
242
235
pendingRegistrations .add (socket );
243
236
selector .wakeup ();
244
237
return socket ;
@@ -247,18 +240,13 @@ RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel)
247
240
boolean doCancelRead (RegisteredSocket socket , ReadOperation op ) {
248
241
socket .readLock .lock ();
249
242
try {
250
- // a null op means cancel any operation
251
- if (op != null && socket .readOperation == op || op == null && socket .readOperation != null ) {
252
- if (op == null ) {
253
- socket .readOperation .onFailure .accept (new CancellationException ());
254
- }
255
- socket .readOperation = null ;
256
- cancelledReads .increment ();
257
- currentReads .decrement ();
258
- return true ;
259
- } else {
243
+ if (op != socket .readOperation ) {
260
244
return false ;
261
245
}
246
+ socket .readOperation = null ;
247
+ cancelledReads .increment ();
248
+ currentReads .decrement ();
249
+ return true ;
262
250
} finally {
263
251
socket .readLock .unlock ();
264
252
}
@@ -267,18 +255,13 @@ boolean doCancelRead(RegisteredSocket socket, ReadOperation op) {
267
255
boolean doCancelWrite (RegisteredSocket socket , WriteOperation op ) {
268
256
socket .writeLock .lock ();
269
257
try {
270
- // a null op means cancel any operation
271
- if (op != null && socket .writeOperation == op || op == null && socket .writeOperation != null ) {
272
- if (op == null ) {
273
- socket .writeOperation .onFailure .accept (new CancellationException ());
274
- }
275
- socket .writeOperation = null ;
276
- cancelledWrites .increment ();
277
- currentWrites .decrement ();
278
- return true ;
279
- } else {
258
+ if (op != socket .writeOperation ) {
280
259
return false ;
281
260
}
261
+ socket .writeOperation = null ;
262
+ cancelledWrites .increment ();
263
+ currentWrites .decrement ();
264
+ return true ;
282
265
} finally {
283
266
socket .writeLock .unlock ();
284
267
}
@@ -295,13 +278,23 @@ ReadOperation startRead(
295
278
checkTerminated ();
296
279
Util .assertTrue (buffer .hasRemaining ());
297
280
waitForSocketRegistration (socket );
298
- ReadOperation op ;
299
281
socket .readLock .lock ();
300
282
try {
301
283
if (socket .readOperation != null ) {
302
284
throw new ReadPendingException ();
303
285
}
304
- op = new ReadOperation (buffer , onSuccess , onFailure );
286
+ ReadOperation op = new ReadOperation (buffer , onSuccess , onFailure );
287
+
288
+ startedReads .increment ();
289
+ currentReads .increment ();
290
+
291
+ if (!registrations .containsKey (socket )) {
292
+ op .onFailure .accept (new ClosedChannelException ());
293
+ failedReads .increment ();
294
+ currentReads .decrement ();
295
+ return op ;
296
+ }
297
+
305
298
/*
306
299
* we do not try to outsmart the TLS state machine and register for both IO operations for each new socket
307
300
* operation
@@ -324,9 +317,7 @@ ReadOperation startRead(
324
317
socket .readLock .unlock ();
325
318
}
326
319
selector .wakeup ();
327
- startedReads .increment ();
328
- currentReads .increment ();
329
- return op ;
320
+ return socket .readOperation ;
330
321
}
331
322
332
323
WriteOperation startWrite (
@@ -340,13 +331,23 @@ WriteOperation startWrite(
340
331
checkTerminated ();
341
332
Util .assertTrue (buffer .hasRemaining ());
342
333
waitForSocketRegistration (socket );
343
- WriteOperation op ;
344
334
socket .writeLock .lock ();
345
335
try {
346
336
if (socket .writeOperation != null ) {
347
337
throw new WritePendingException ();
348
338
}
349
- op = new WriteOperation (buffer , onSuccess , onFailure );
339
+ WriteOperation op = new WriteOperation (buffer , onSuccess , onFailure );
340
+
341
+ startedWrites .increment ();
342
+ currentWrites .increment ();
343
+
344
+ if (!registrations .containsKey (socket )) {
345
+ op .onFailure .accept (new ClosedChannelException ());
346
+ failedWrites .increment ();
347
+ currentWrites .decrement ();
348
+ return op ;
349
+ }
350
+
350
351
/*
351
352
* we do not try to outsmart the TLS state machine and register for both IO operations for each new socket
352
353
* operation
@@ -369,9 +370,7 @@ WriteOperation startWrite(
369
370
socket .writeLock .unlock ();
370
371
}
371
372
selector .wakeup ();
372
- startedWrites .increment ();
373
- currentWrites .increment ();
374
- return op ;
373
+ return socket .writeOperation ;
375
374
}
376
375
377
376
private void checkTerminated () {
@@ -391,8 +390,11 @@ private void waitForSocketRegistration(RegisteredSocket socket) {
391
390
private void loop () {
392
391
try {
393
392
while (shutdown == Shutdown .No
394
- || shutdown == Shutdown .Wait && currentRegistrations .intValue () > 0 ) {
395
- int c = selector .select (); // block
393
+ || shutdown == Shutdown .Wait
394
+ && (!pendingRegistrations .isEmpty () || !registrations .isEmpty ())) {
395
+ // most state-changing operations will wake the selector up, however, asynchronous closings
396
+ // of the channels won't, so we have to timeout to allow checking those cases
397
+ int c = selector .select (100 ); // block
396
398
selectionCount .increment ();
397
399
// avoid unnecessary creation of iterator object
398
400
if (c > 0 ) {
@@ -413,24 +415,20 @@ private void loop() {
413
415
}
414
416
registerPendingSockets ();
415
417
processPendingInterests ();
418
+ checkClosings ();
416
419
}
417
420
} catch (Throwable e ) {
418
421
LOGGER .error ("error in selector loop" , e );
419
422
} finally {
420
423
executor .shutdown ();
421
424
// use shutdownNow to stop delayed tasks
422
425
timeoutExecutor .shutdownNow ();
423
- if (shutdown == Shutdown .Immediate ) {
424
- for (SelectionKey key : selector .keys ()) {
425
- RegisteredSocket socket = (RegisteredSocket ) key .attachment ();
426
- socket .close ();
427
- }
428
- }
429
426
try {
430
427
selector .close ();
431
428
} catch (IOException e ) {
432
429
LOGGER .warn ("error closing selector: " + e .getMessage ());
433
430
}
431
+ checkClosings ();
434
432
}
435
433
}
436
434
@@ -606,14 +604,67 @@ private long readHandlingTasks(RegisteredSocket socket, ReadOperation op) throws
606
604
}
607
605
}
608
606
609
- private void registerPendingSockets () throws ClosedChannelException {
607
+ private void registerPendingSockets () {
610
608
RegisteredSocket socket ;
611
609
while ((socket = pendingRegistrations .poll ()) != null ) {
612
- socket .key = socket .socketChannel .register (selector , 0 , socket );
613
- if (LOGGER .isTraceEnabled ()) {
614
- LOGGER .trace ("registered key: " + socket .key );
610
+ try {
611
+ socket .key = socket .socketChannel .register (selector , 0 , socket );
612
+ registrations .put (socket , true );
613
+ } catch (ClosedChannelException e ) {
614
+ // can happen when channels are closed right after creation
615
+ } finally {
616
+ // decrement the count of the latch even in case of exceptions, so the waiting thread
617
+ // is unlocked; it will have to check the result, though
618
+ socket .registered .countDown ();
619
+ }
620
+ }
621
+ }
622
+
623
+ /**
624
+ * Channels that are closed asynchronously are silently removed from selectors. This method will
625
+ * check them using the internal catalog and do the proper cleanup.
626
+ */
627
+ private void checkClosings () {
628
+ for (RegisteredSocket socket : registrations .keySet ()) {
629
+ if (!socket .key .isValid () || shutdown == Shutdown .Immediate ) {
630
+ registrations .remove (socket );
631
+ failCurrentRead (socket );
632
+ failCurrentWrite (socket );
615
633
}
616
- socket .registered .countDown ();
634
+ }
635
+ }
636
+
637
+ private void failCurrentRead (RegisteredSocket socket ) {
638
+ socket .readLock .lock ();
639
+ try {
640
+ if (socket .readOperation != null ) {
641
+ socket .readOperation .onFailure .accept (new ClosedChannelException ());
642
+ if (socket .readOperation .timeoutFuture != null ) {
643
+ socket .readOperation .timeoutFuture .cancel (false );
644
+ }
645
+ socket .readOperation = null ;
646
+ failedReads .increment ();
647
+ currentReads .decrement ();
648
+ }
649
+ } finally {
650
+ socket .readLock .unlock ();
651
+ }
652
+ }
653
+
654
+ private void failCurrentWrite (RegisteredSocket socket ) {
655
+ socket .writeLock .lock ();
656
+ try {
657
+ if (socket .writeOperation != null ) {
658
+ socket .writeOperation .onFailure .accept (new ClosedChannelException ());
659
+ if (socket .writeOperation .timeoutFuture != null ) {
660
+ socket .writeOperation .timeoutFuture .cancel (false );
661
+ }
662
+ socket .writeOperation = null ;
663
+ failedWrites .increment ();
664
+ currentWrites .decrement ();
665
+ }
666
+ } finally {
667
+ socket .writeLock .unlock ();
617
668
}
618
669
}
619
670
@@ -769,6 +820,6 @@ public long getCurrentWriteCount() {
769
820
* @return number of sockets
770
821
*/
771
822
public long getCurrentRegistrationCount () {
772
- return currentRegistrations . longValue ();
823
+ return registrations . mappingCount ();
773
824
}
774
825
}
0 commit comments