@@ -5,6 +5,7 @@ namespace DotNetty.Handlers.Tls
5
5
{
6
6
using System ;
7
7
using System . Collections . Generic ;
8
+ using System . Diagnostics ;
8
9
using System . Diagnostics . Contracts ;
9
10
using System . IO ;
10
11
using System . Net . Security ;
@@ -41,7 +42,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
41
42
Task < int > pendingSslStreamReadFuture ;
42
43
43
44
public TlsHandler ( TlsSettings settings )
44
- : this ( stream => new SslStream ( stream , false ) , settings )
45
+ : this ( stream => new SslStream ( stream , true ) , settings )
45
46
{
46
47
}
47
48
@@ -69,8 +70,6 @@ public TlsHandler(Func<Stream, SslStream> sslStreamFactory, TlsSettings settings
69
70
70
71
bool IsServer => this . settings is ServerTlsSettings ;
71
72
72
- public void Dispose ( ) => this . sslStream ? . Dispose ( ) ;
73
-
74
73
public override void ChannelActive ( IChannelHandlerContext context )
75
74
{
76
75
base . ChannelActive ( context ) ;
@@ -344,6 +343,9 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
344
343
345
344
outputBuffer = this . pendingSslStreamReadBuffer ;
346
345
outputBufferLength = outputBuffer . WritableBytes ;
346
+
347
+ this . pendingSslStreamReadFuture = null ;
348
+ this . pendingSslStreamReadBuffer = null ;
347
349
}
348
350
else
349
351
{
@@ -363,17 +365,23 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
363
365
if ( ! currentReadFuture . IsCompleted )
364
366
{
365
367
// we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
366
- Contract . Assert ( this . mediationStream . SourceReadableBytes == 0 ) ;
367
368
368
369
continue ;
369
370
}
370
371
371
372
int read = currentReadFuture . Result ;
372
373
374
+ if ( read == 0 )
375
+ {
376
+ //Stream closed
377
+ return ;
378
+ }
379
+
373
380
// Now output the result of previous read and decide whether to do an extra read on the same source or move forward
374
381
AddBufferToOutput ( outputBuffer , read , output ) ;
375
382
376
383
currentReadFuture = null ;
384
+ outputBuffer = null ;
377
385
if ( this . mediationStream . SourceReadableBytes == 0 )
378
386
{
379
387
// we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
@@ -620,6 +628,7 @@ void HandleFailure(Exception cause)
620
628
// Release all resources such as internal buffers that SSLEngine
621
629
// is managing.
622
630
631
+ this . mediationStream . Dispose ( ) ;
623
632
try
624
633
{
625
634
this . sslStream . Dispose ( ) ;
@@ -701,14 +710,13 @@ public void ExpandSource(int count)
701
710
702
711
this . inputLength += count ;
703
712
704
- TaskCompletionSource < int > promise = this . readCompletionSource ;
705
- if ( promise == null )
713
+ ArraySegment < byte > sslBuffer = this . sslOwnedBuffer ;
714
+ if ( sslBuffer . Array == null )
706
715
{
707
716
// there is no pending read operation - keep for future
708
717
return ;
709
718
}
710
-
711
- ArraySegment < byte > sslBuffer = this . sslOwnedBuffer ;
719
+ this . sslOwnedBuffer = default ( ArraySegment < byte > ) ;
712
720
713
721
#if NETSTANDARD1_3
714
722
this . readByteCount = this . ReadFromInput ( sslBuffer . Array , sslBuffer . Offset , sslBuffer . Count ) ;
@@ -718,29 +726,35 @@ public void ExpandSource(int count)
718
726
{
719
727
var self = ( MediationStream ) ms ;
720
728
TaskCompletionSource < int > p = self . readCompletionSource ;
721
- this . readCompletionSource = null ;
729
+ self . readCompletionSource = null ;
722
730
p . TrySetResult ( self . readByteCount ) ;
723
731
} ,
724
732
this )
725
733
. RunSynchronously ( TaskScheduler . Default ) ;
726
734
#else
727
735
int read = this . ReadFromInput ( sslBuffer . Array , sslBuffer . Offset , sslBuffer . Count ) ;
736
+
737
+ TaskCompletionSource < int > promise = this . readCompletionSource ;
728
738
this . readCompletionSource = null ;
729
739
promise . TrySetResult ( read ) ;
730
- this . readCallback ? . Invoke ( promise . Task ) ;
740
+
741
+ AsyncCallback callback = this . readCallback ;
742
+ this . readCallback = null ;
743
+ callback ? . Invoke ( promise . Task ) ;
731
744
#endif
732
745
}
733
746
734
747
#if NETSTANDARD1_3
735
748
public override Task < int > ReadAsync ( byte [ ] buffer , int offset , int count , CancellationToken cancellationToken )
736
749
{
737
- if ( this . inputLength - this . inputOffset > 0 )
750
+ if ( this . SourceReadableBytes > 0 )
738
751
{
739
752
// we have the bytes available upfront - write out synchronously
740
753
int read = this . ReadFromInput ( buffer , offset , count ) ;
741
754
return Task . FromResult ( read ) ;
742
755
}
743
756
757
+ Contract . Assert ( this . sslOwnedBuffer . Array == null ) ;
744
758
// take note of buffer - we will pass bytes there once available
745
759
this . sslOwnedBuffer = new ArraySegment < byte > ( buffer , offset , count ) ;
746
760
this . readCompletionSource = new TaskCompletionSource < int > ( ) ;
@@ -749,13 +763,16 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
749
763
#else
750
764
public override IAsyncResult BeginRead ( byte [ ] buffer , int offset , int count , AsyncCallback callback , object state )
751
765
{
752
- if ( this . inputLength - this . inputOffset > 0 )
766
+ if ( this . SourceReadableBytes > 0 )
753
767
{
754
768
// we have the bytes available upfront - write out synchronously
755
769
int read = this . ReadFromInput ( buffer , offset , count ) ;
756
- return this . PrepareSyncReadResult ( read , state ) ;
770
+ var res = this . PrepareSyncReadResult ( read , state ) ;
771
+ callback ? . Invoke ( res ) ;
772
+ return res ;
757
773
}
758
774
775
+ Contract . Assert ( this . sslOwnedBuffer . Array == null ) ;
759
776
// take note of buffer - we will pass bytes there once available
760
777
this . sslOwnedBuffer = new ArraySegment < byte > ( buffer , offset , count ) ;
761
778
this . readCompletionSource = new TaskCompletionSource < int > ( state ) ;
@@ -771,6 +788,7 @@ public override int EndRead(IAsyncResult asyncResult)
771
788
return syncResult . Result ;
772
789
}
773
790
791
+ Debug . Assert ( this . readCompletionSource == null || this . readCompletionSource . Task == asyncResult ) ;
774
792
Contract . Assert ( ! ( ( Task < int > ) asyncResult ) . IsCanceled ) ;
775
793
776
794
try
@@ -782,12 +800,6 @@ public override int EndRead(IAsyncResult asyncResult)
782
800
ExceptionDispatchInfo . Capture ( ex . InnerException ) . Throw ( ) ;
783
801
throw ; // unreachable
784
802
}
785
- finally
786
- {
787
- this . readCompletionSource = null ;
788
- this . readCallback = null ;
789
- this . sslOwnedBuffer = default ( ArraySegment < byte > ) ;
790
- }
791
803
}
792
804
793
805
IAsyncResult PrepareSyncReadResult ( int readBytes , object state )
@@ -817,51 +829,63 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As
817
829
// write+flush completed synchronously (and successfully)
818
830
var result = new SynchronousAsyncResult < int > ( ) ;
819
831
result . AsyncState = state ;
820
- callback ( result ) ;
832
+ callback ? . Invoke ( result ) ;
821
833
return result ;
822
834
default :
823
- this . writeCallback = callback ;
824
- var tcs = new TaskCompletionSource ( state ) ;
825
- this . writeCompletion = tcs ;
826
- task . ContinueWith ( WriteCompleteCallback , this , TaskContinuationOptions . ExecuteSynchronously ) ;
827
- return tcs . Task ;
835
+ if ( callback != null || state != task . AsyncState )
836
+ {
837
+ Contract . Assert ( this . writeCompletion == null ) ;
838
+ this . writeCallback = callback ;
839
+ var tcs = new TaskCompletionSource ( state ) ;
840
+ this . writeCompletion = tcs ;
841
+ task . ContinueWith ( WriteCompleteCallback , this , TaskContinuationOptions . ExecuteSynchronously ) ;
842
+ return tcs . Task ;
843
+ }
844
+ else
845
+ {
846
+ return task ;
847
+ }
828
848
}
829
849
}
830
850
831
851
static void HandleChannelWriteComplete ( Task writeTask , object state )
832
852
{
833
853
var self = ( MediationStream ) state ;
854
+
855
+ AsyncCallback callback = self . writeCallback ;
856
+ self . writeCallback = null ;
857
+
858
+ var promise = self . writeCompletion ;
859
+ self . writeCompletion = null ;
860
+
834
861
switch ( writeTask . Status )
835
862
{
836
863
case TaskStatus . RanToCompletion :
837
- self . writeCompletion . TryComplete ( ) ;
864
+ promise . TryComplete ( ) ;
838
865
break ;
839
866
case TaskStatus . Canceled :
840
- self . writeCompletion . TrySetCanceled ( ) ;
867
+ promise . TrySetCanceled ( ) ;
841
868
break ;
842
869
case TaskStatus . Faulted :
843
- self . writeCompletion . TrySetException ( writeTask . Exception ) ;
870
+ promise . TrySetException ( writeTask . Exception ) ;
844
871
break ;
845
872
default :
846
873
throw new ArgumentOutOfRangeException ( "Unexpected task status: " + writeTask . Status ) ;
847
874
}
848
875
849
- self . writeCallback ? . Invoke ( self . writeCompletion . Task ) ;
876
+ callback ? . Invoke ( promise . Task ) ;
850
877
}
851
878
852
879
public override void EndWrite ( IAsyncResult asyncResult )
853
880
{
854
- this . writeCallback = null ;
855
- this . writeCompletion = null ;
856
-
857
881
if ( asyncResult is SynchronousAsyncResult < int > )
858
882
{
859
883
return ;
860
884
}
861
885
862
886
try
863
887
{
864
- ( ( Task < int > ) asyncResult ) . Wait ( ) ;
888
+ ( ( Task ) asyncResult ) . Wait ( ) ;
865
889
}
866
890
catch ( AggregateException ex )
867
891
{
@@ -876,7 +900,7 @@ int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapa
876
900
Contract . Assert ( destination != null ) ;
877
901
878
902
byte [ ] source = this . input ;
879
- int readableBytes = this . inputLength - this . inputOffset ;
903
+ int readableBytes = this . SourceReadableBytes ;
880
904
int length = Math . Min ( readableBytes , destinationCapacity ) ;
881
905
Buffer . BlockCopy ( source , this . inputStartOffset + this . inputOffset , destination , destinationOffset , length ) ;
882
906
this . inputOffset += length ;
@@ -894,8 +918,11 @@ protected override void Dispose(bool disposing)
894
918
if ( disposing )
895
919
{
896
920
TaskCompletionSource < int > p = this . readCompletionSource ;
897
- this . readCompletionSource = null ;
898
- p ? . TrySetResult ( 0 ) ;
921
+ if ( p != null )
922
+ {
923
+ this . readCompletionSource = null ;
924
+ p . TrySetResult ( 0 ) ;
925
+ }
899
926
}
900
927
}
901
928
0 commit comments