Skip to content

Commit 21ec1c5

Browse files
committed
If the stream passed to Http1Connection does not support gathered writes, add a WriteBufferingStream on top of it.
1 parent 59b936e commit 21ec1c5

10 files changed

+470
-73
lines changed

NetworkToolkit/Connections/Connection.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,7 @@ public async ValueTask DisposeAsync(CancellationToken cancellationToken)
6161

6262
_stream = null;
6363

64-
await (stream is ICancellableAsyncDisposable cancellable ?
65-
cancellable.DisposeAsync(cancellationToken) :
66-
stream.DisposeAsync()).ConfigureAwait(false);
64+
await stream.DisposeAsync(cancellationToken).ConfigureAwait(false);
6765
}
6866

6967
/// <summary>

NetworkToolkit/Connections/FilteringConnection.cs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,17 @@ public FilteringConnection(Connection baseConnection, Stream stream) : base(stre
3636
}
3737

3838
/// <inheritdoc/>
39-
protected override ValueTask DisposeAsyncCore(CancellationToken cancellationToken) =>
40-
BaseConnection.DisposeAsync(cancellationToken);
39+
protected override async ValueTask DisposeAsyncCore(CancellationToken cancellationToken)
40+
{
41+
await Stream.DisposeAsync(cancellationToken).ConfigureAwait(false);
42+
await BaseConnection.DisposeAsync(cancellationToken).ConfigureAwait(false);
43+
}
4144

4245
/// <inheritdoc/>
43-
public override ValueTask CompleteWritesAsync(CancellationToken cancellationToken) =>
44-
BaseConnection.CompleteWritesAsync(cancellationToken);
46+
public override async ValueTask CompleteWritesAsync(CancellationToken cancellationToken)
47+
{
48+
await Stream.FlushAsync(cancellationToken).ConfigureAwait(false);
49+
await BaseConnection.CompleteWritesAsync(cancellationToken).ConfigureAwait(false);
50+
}
4551
}
4652
}

NetworkToolkit/Connections/MemoryConnection.cs

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public override ValueTask CompleteWritesAsync(CancellationToken cancellationToke
5656

5757
private sealed class MemoryConnectionStream : Stream, IGatheringStream
5858
{
59-
PipeReader _reader;
59+
PipeReader? _reader;
6060
PipeWriter? _writer;
6161

6262
public override bool CanRead => true;
@@ -77,12 +77,12 @@ public MemoryConnectionStream(PipeReader reader, PipeWriter writer)
7777

7878
protected override void Dispose(bool disposing)
7979
{
80-
if (disposing && _writer != null)
80+
if (disposing && _reader is PipeReader reader)
8181
{
82-
_writer.Complete();
83-
_reader.Complete();
84-
_writer = null!;
85-
_reader = null!;
82+
_writer?.Complete();
83+
reader.Complete();
84+
_writer = null;
85+
_reader = null;
8686
}
8787
}
8888

@@ -106,13 +106,13 @@ public override int Read(byte[] buffer, int offset, int count)
106106

107107
public override int Read(Span<byte> buffer)
108108
{
109-
if (_reader == null) throw new ObjectDisposedException(nameof(MemoryConnectionStream));
109+
if (_reader is not PipeReader reader) throw new ObjectDisposedException(nameof(MemoryConnectionStream));
110110

111111
try
112112
{
113-
return FinishRead(buffer, Tools.BlockForResult(_reader.ReadAsync()));
113+
return FinishRead(reader, buffer, Tools.BlockForResult(_reader.ReadAsync()), CancellationToken.None);
114114
}
115-
catch (Exception ex)
115+
catch (Exception ex) when (ex is not OperationCanceledException)
116116
{
117117
throw new IOException(ex.Message, ex);
118118
}
@@ -129,24 +129,25 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
129129

130130
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
131131
{
132-
if (_reader == null) throw new ObjectDisposedException(nameof(MemoryConnectionStream));
132+
if (_reader is not PipeReader reader) throw new ObjectDisposedException(nameof(MemoryConnectionStream));
133133

134134
try
135135
{
136-
ReadResult result = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
137-
return FinishRead(buffer.Span, result);
136+
ReadResult result = await reader.ReadAsync(cancellationToken).ConfigureAwait(false);
137+
return FinishRead(reader, buffer.Span, result, cancellationToken);
138138
}
139-
catch (Exception ex)
139+
catch (Exception ex) when(ex is not OperationCanceledException)
140140
{
141141
throw new IOException(ex.Message, ex);
142142
}
143143
}
144144

145-
private int FinishRead(Span<byte> buffer, in ReadResult result)
145+
private static int FinishRead(PipeReader reader, Span<byte> buffer, in ReadResult result, CancellationToken cancellationToken)
146146
{
147147
if (result.IsCanceled)
148148
{
149-
throw new SocketException((int)SocketError.OperationAborted);
149+
cancellationToken.ThrowIfCancellationRequested();
150+
throw new OperationCanceledException();
150151
}
151152

152153
ReadOnlySequence<byte> sequence = result.Buffer;
@@ -171,7 +172,7 @@ private int FinishRead(Span<byte> buffer, in ReadResult result)
171172
}
172173
finally
173174
{
174-
_reader.AdvanceTo(consumed);
175+
reader.AdvanceTo(consumed);
175176
}
176177
}
177178

@@ -180,7 +181,8 @@ public override void Write(byte[] buffer, int offset, int count) =>
180181

181182
public override void Write(ReadOnlySpan<byte> buffer)
182183
{
183-
if (_writer == null) throw new ObjectDisposedException(nameof(MemoryConnectionStream));
184+
if (_reader == null) throw new ObjectDisposedException(nameof(MemoryConnectionStream));
185+
if (_writer == null) throw new InvalidOperationException($"{nameof(MemoryConnectionStream)} cannot be written to after writes have been completed.");
184186

185187
try
186188
{
@@ -191,10 +193,10 @@ public override void Write(ReadOnlySpan<byte> buffer)
191193

192194
if (res.IsCanceled)
193195
{
194-
throw new SocketException((int)SocketError.OperationAborted);
196+
throw new OperationCanceledException();
195197
}
196198
}
197-
catch (Exception ex)
199+
catch (Exception ex) when (ex is not OperationCanceledException)
198200
{
199201
throw new IOException(ex.Message, ex);
200202
}
@@ -205,48 +207,58 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati
205207

206208
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
207209
{
208-
if (_writer == null) throw new ObjectDisposedException(nameof(MemoryConnectionStream));
210+
if (_reader is null) throw new ObjectDisposedException(nameof(MemoryConnectionStream));
211+
if (_writer is not PipeWriter writer) throw new InvalidOperationException($"{nameof(MemoryConnectionStream)} cannot be written to after writes have been completed.");
209212

210213
try
211214
{
212-
FlushResult res = await _writer.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
215+
FlushResult res = await writer.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
213216

214217
if (res.IsCanceled)
215218
{
216-
throw new SocketException((int)SocketError.OperationAborted);
219+
cancellationToken.ThrowIfCancellationRequested();
220+
throw new OperationCanceledException();
217221
}
218222
}
219-
catch (Exception ex)
223+
catch (Exception ex) when (ex is not OperationCanceledException)
220224
{
221225
throw new IOException(ex.Message, ex);
222226
}
223227
}
224228

225229
public async ValueTask WriteAsync(IReadOnlyList<ReadOnlyMemory<byte>> buffers, CancellationToken cancellationToken = default)
226230
{
227-
if (_writer == null) throw new ObjectDisposedException(nameof(MemoryConnectionStream));
231+
if (_reader == null) throw new ObjectDisposedException(nameof(MemoryConnectionStream));
232+
if (_writer is not PipeWriter writer) throw new InvalidOperationException($"{nameof(MemoryConnectionStream)} cannot be written to after writes have been completed.");
228233

229234
try
230235
{
231236
foreach (ReadOnlyMemory<byte> buffer in buffers)
232237
{
233-
buffer.Span.CopyTo(_writer.GetSpan(buffer.Length));
234-
_writer.Advance(buffer.Length);
238+
buffer.Span.CopyTo(writer.GetSpan(buffer.Length));
239+
writer.Advance(buffer.Length);
235240
}
236241

237-
FlushResult res = await _writer.FlushAsync(cancellationToken).ConfigureAwait(false);
242+
FlushResult res = await writer.FlushAsync(cancellationToken).ConfigureAwait(false);
238243

239244
if (res.IsCanceled)
240245
{
241-
throw new SocketException((int)SocketError.OperationAborted);
246+
cancellationToken.ThrowIfCancellationRequested();
247+
throw new OperationCanceledException();
242248
}
243249
}
244-
catch (Exception ex)
250+
catch (Exception ex) when (ex is not OperationCanceledException)
245251
{
246252
throw new IOException(ex.Message, ex);
247253
}
248254
}
249255

256+
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) =>
257+
TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state);
258+
259+
public override void EndWrite(IAsyncResult asyncResult) =>
260+
TaskToApm.End(asyncResult);
261+
250262
public override void Flush()
251263
{
252264
}
@@ -265,6 +277,22 @@ public override void SetLength(long value)
265277
{
266278
throw new NotImplementedException();
267279
}
280+
281+
public override void CopyTo(Stream destination, int bufferSize) =>
282+
CopyToAsync(destination, bufferSize, CancellationToken.None).GetAwaiter().GetResult();
283+
284+
public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
285+
{
286+
if (_reader is not PipeReader reader) throw new ObjectDisposedException(nameof(MemoryConnectionStream));
287+
try
288+
{
289+
await reader.CopyToAsync(destination, cancellationToken).ConfigureAwait(false);
290+
}
291+
catch (Exception ex) when (ex is not OperationCanceledException)
292+
{
293+
throw new IOException(ex.Message, ex);
294+
}
295+
}
268296
}
269297
}
270298
}

NetworkToolkit/Connections/SocketConnectionFactory.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,15 @@ private sealed class SocketConnection : Connection
219219

220220
private Socket Socket => ((NetworkStream)Stream).Socket;
221221

222-
public SocketConnection(Socket socket) : base(new GatheringNetworkStream(socket))
222+
public SocketConnection(Socket socket) : base(CreateStream(socket))
223223
{
224224
}
225225

226+
private static NetworkStream CreateStream(Socket socket) =>
227+
GatheringNetworkStream.IsSupported
228+
? new GatheringNetworkStream(socket)
229+
: new NetworkStream(socket);
230+
226231
protected override ValueTask DisposeAsyncCore(CancellationToken cancellationToken)
227232
=> default;
228233

NetworkToolkit/Connections/SslConnectionFactory.cs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,6 @@ public SslConnection(Connection baseConnection, SslStream stream) : base(baseCon
118118
{
119119
}
120120

121-
protected override async ValueTask DisposeAsyncCore(CancellationToken cancellationToken)
122-
{
123-
await Stream.DisposeAsync().ConfigureAwait(false);
124-
await base.DisposeAsyncCore(cancellationToken).ConfigureAwait(false);
125-
}
126-
127121
public override bool TryGetProperty(Type type, out object? value)
128122
{
129123
if (type == typeof(SslStream))
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using System.Net;
2+
using System.Threading;
3+
using System.Threading.Tasks;
4+
5+
namespace NetworkToolkit.Connections
6+
{
7+
/// <summary>
8+
/// A connection factory that adds write buffering to underlying connections.
9+
/// </summary>
10+
public sealed class WriteBufferingConnectionFactory : FilteringConnectionFactory
11+
{
12+
/// <summary>
13+
/// Instantiates a new <see cref="WriteBufferingConnectionFactory"/>.
14+
/// </summary>
15+
/// <param name="baseFactory">The underlying factory that will have write buffering added to its connections.</param>
16+
public WriteBufferingConnectionFactory(ConnectionFactory baseFactory) : base(baseFactory)
17+
{
18+
}
19+
20+
/// <inheritdoc/>
21+
public override async ValueTask<Connection> ConnectAsync(EndPoint endPoint, IConnectionProperties? options = null, CancellationToken cancellationToken = default)
22+
{
23+
Connection con = await BaseFactory.ConnectAsync(endPoint, options, cancellationToken).ConfigureAwait(false);
24+
return new FilteringConnection(con, new WriteBufferingStream(con.Stream));
25+
}
26+
27+
/// <inheritdoc/>
28+
public override async ValueTask<ConnectionListener> ListenAsync(EndPoint? endPoint = null, IConnectionProperties? options = null, CancellationToken cancellationToken = default)
29+
{
30+
ConnectionListener listener = await BaseFactory.ListenAsync(endPoint, options, cancellationToken).ConfigureAwait(false);
31+
return new WriteBufferingConnectionListener(listener);
32+
}
33+
34+
private sealed class WriteBufferingConnectionListener : FilteringConnectionListener
35+
{
36+
public WriteBufferingConnectionListener(ConnectionListener baseListener) : base(baseListener)
37+
{
38+
}
39+
40+
public override async ValueTask<Connection?> AcceptConnectionAsync(IConnectionProperties? options = null, CancellationToken cancellationToken = default)
41+
{
42+
Connection? con = await BaseListener.AcceptConnectionAsync(options, cancellationToken).ConfigureAwait(false);
43+
if (con == null) return con;
44+
45+
return new FilteringConnection(con, new WriteBufferingStream(con.Stream));
46+
}
47+
}
48+
}
49+
}

NetworkToolkit/GatheringNetworkStream.cs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
using System;
22
using System.Buffers;
33
using System.Collections.Generic;
4+
using System.Diagnostics;
45
using System.IO;
56
using System.Net.Sockets;
7+
using System.Reflection;
68
using System.Runtime.ExceptionServices;
79
using System.Runtime.InteropServices;
810
using System.Threading;
@@ -13,9 +15,25 @@ namespace NetworkToolkit
1315
internal sealed class GatheringNetworkStream : NetworkStream, IGatheringStream
1416
{
1517
private EventArgs? _gatheredEventArgs;
18+
private static Func<Socket, SocketAsyncEventArgs, CancellationToken, bool>? s_sendAsyncWithCancellation;
19+
20+
public static bool IsSupported => s_sendAsyncWithCancellation != null;
21+
22+
static GatheringNetworkStream()
23+
{
24+
MethodInfo? sendAsync = typeof(Socket).GetMethod("SendAsync", BindingFlags.NonPublic | BindingFlags.Instance, binder: null, new[] { typeof(SocketAsyncEventArgs), typeof(CancellationToken) }, modifiers: null);
25+
26+
if (sendAsync != null)
27+
{
28+
s_sendAsyncWithCancellation =
29+
(Func<Socket, SocketAsyncEventArgs, CancellationToken, bool>)
30+
Delegate.CreateDelegate(typeof(Func<Socket, SocketAsyncEventArgs, CancellationToken, bool>), firstArgument: null, sendAsync);
31+
}
32+
}
1633

1734
public GatheringNetworkStream(Socket socket) : base(socket, ownsSocket: true)
1835
{
36+
Debug.Assert(IsSupported);
1937
}
2038

2139
protected override void Dispose(bool disposing)
@@ -35,6 +53,12 @@ public ValueTask WriteAsync(IReadOnlyList<ReadOnlyMemory<byte>> buffers, Cancell
3553
return WriteAsync(buffers[0], cancellationToken);
3654
}
3755

56+
if (cancellationToken.IsCancellationRequested)
57+
{
58+
// There is no SocketAsyncEventArgs call that is cancellable...
59+
return ValueTask.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(cancellationToken)));
60+
}
61+
3862
_gatheredEventArgs ??= new EventArgs();
3963
return _gatheredEventArgs.WriteAsync(Socket, buffers, cancellationToken);
4064
}
@@ -44,7 +68,7 @@ private sealed class EventArgs : SocketTaskEventArgs<int>
4468
private List<ArraySegment<byte>>? _gatheredSegments;
4569
private List<byte[]>? _pooledArrays;
4670

47-
public ValueTask WriteAsync(Socket socket, IReadOnlyList<ReadOnlyMemory<byte>> buffers, CancellationToken cancellationToken = default)
71+
public ValueTask WriteAsync(Socket socket, IReadOnlyList<ReadOnlyMemory<byte>> buffers, CancellationToken cancellationToken)
4872
{
4973
int bufferCount = buffers.Count;
5074

@@ -71,7 +95,7 @@ public ValueTask WriteAsync(Socket socket, IReadOnlyList<ReadOnlyMemory<byte>> b
7195

7296
BufferList = _gatheredSegments;
7397
Reset();
74-
if (!socket.SendAsync(this))
98+
if (!s_sendAsyncWithCancellation!(socket, this, cancellationToken))
7599
{
76100
OnCompleted();
77101
}
@@ -104,7 +128,8 @@ public void OnCompleted()
104128
}
105129
else
106130
{
107-
SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new IOException($"{nameof(WriteAsync)} failed. See InnerException for more details.", new SocketException((int)SocketError))));
131+
var ex = new SocketException((int)SocketError);
132+
SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new IOException(ex.Message, ex)));
108133
}
109134
}
110135
}

0 commit comments

Comments
 (0)