Skip to content

Commit 8093c52

Browse files
geoffkizerGeoffrey Kizer
andauthored
add Task-based DisconnectAsync and refactor APM methods on top of it (#51213)
* add Task-based DisconnectAsync and refactor APM methods on top of it * fix BeginDisconnect to throw synchronously and add relevant tests * remove #region stuff in Socket.cs and add link to github issue Co-authored-by: Geoffrey Kizer <geoffrek@windows.microsoft.com>
1 parent d2daf0b commit 8093c52

12 files changed

+236
-267
lines changed

src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ public void Connect(string host, int port) { }
338338
public static bool ConnectAsync(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
339339
public void Disconnect(bool reuseSocket) { }
340340
public bool DisconnectAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
341+
public System.Threading.Tasks.ValueTask DisconnectAsync(bool reuseSocket, System.Threading.CancellationToken cancellationToken = default) { throw null; }
341342
public void Dispose() { }
342343
protected virtual void Dispose(bool disposing) { }
343344
[System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]

src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
<Compile Include="System\Net\Sockets\UdpReceiveResult.cs" />
4646
<Compile Include="System\Net\Sockets\AcceptOverlappedAsyncResult.cs" />
4747
<Compile Include="System\Net\Sockets\BaseOverlappedAsyncResult.cs" />
48-
<Compile Include="System\Net\Sockets\DisconnectOverlappedAsyncResult.cs" />
4948
<Compile Include="System\Net\Sockets\UnixDomainSocketEndPoint.cs" />
5049
<!-- Logging -->
5150
<Compile Include="$(CommonPath)System\Net\Logging\NetEventSource.Common.cs"
@@ -187,7 +186,6 @@
187186
<ItemGroup Condition="'$(TargetsUnix)' == 'true'">
188187
<Compile Include="System\Net\Sockets\AcceptOverlappedAsyncResult.Unix.cs" />
189188
<Compile Include="System\Net\Sockets\BaseOverlappedAsyncResult.Unix.cs" />
190-
<Compile Include="System\Net\Sockets\DisconnectOverlappedAsyncResult.Unix.cs" />
191189
<Compile Include="System\Net\Sockets\SafeSocketHandle.Unix.cs" />
192190
<Compile Include="System\Net\Sockets\Socket.Unix.cs" />
193191
<Compile Include="System\Net\Sockets\SocketAsyncContext.Unix.cs" />

src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.Unix.cs

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.cs

Lines changed: 0 additions & 27 deletions
This file was deleted.

src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,29 @@ public ValueTask ConnectAsync(string host, int port, CancellationToken cancellat
263263
return ConnectAsync(ep, cancellationToken);
264264
}
265265

266+
/// <summary>
267+
/// Disconnects a connected socket from the remote host.
268+
/// </summary>
269+
/// <param name="reuseSocket">Indicates whether the socket should be available for reuse after disconnect.</param>
270+
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
271+
/// <returns>An asynchronous task that completes when the socket is disconnected.</returns>
272+
public ValueTask DisconnectAsync(bool reuseSocket, CancellationToken cancellationToken = default)
273+
{
274+
if (cancellationToken.IsCancellationRequested)
275+
{
276+
return ValueTask.FromCanceled(cancellationToken);
277+
}
278+
279+
AwaitableSocketAsyncEventArgs saea =
280+
Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ??
281+
new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false);
282+
283+
saea.DisconnectReuseSocket = reuseSocket;
284+
saea.WrapExceptionsForNetworkStream = false;
285+
286+
return saea.DisconnectAsync(this, cancellationToken);
287+
}
288+
266289
/// <summary>
267290
/// Receives data from a connected socket.
268291
/// </summary>
@@ -1028,6 +1051,25 @@ public ValueTask ConnectAsync(Socket socket)
10281051
ValueTask.FromException(CreateException(error));
10291052
}
10301053

1054+
public ValueTask DisconnectAsync(Socket socket, CancellationToken cancellationToken)
1055+
{
1056+
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");
1057+
1058+
if (socket.DisconnectAsync(this, cancellationToken))
1059+
{
1060+
_cancellationToken = cancellationToken;
1061+
return new ValueTask(this, _token);
1062+
}
1063+
1064+
SocketError error = SocketError;
1065+
1066+
Release();
1067+
1068+
return error == SocketError.Success ?
1069+
ValueTask.CompletedTask :
1070+
ValueTask.FromException(CreateException(error));
1071+
}
1072+
10311073
/// <summary>Gets the status of the operation.</summary>
10321074
public ValueTaskSourceStatus GetStatus(short token)
10331075
{

src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs

Lines changed: 36 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ private sealed class CacheSet
7373
private int _closeTimeout = Socket.DefaultCloseTimeout;
7474
private int _disposed; // 0 == false, anything else == true
7575

76-
#region Constructors
7776
public Socket(SocketType socketType, ProtocolType protocolType)
7877
: this(OSSupportsIPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork, socketType, protocolType)
7978
{
@@ -242,9 +241,10 @@ private static SafeSocketHandle ValidateHandle(SafeSocketHandle handle) =>
242241
handle is null ? throw new ArgumentNullException(nameof(handle)) :
243242
handle.IsInvalid ? throw new ArgumentException(SR.Arg_InvalidHandle, nameof(handle)) :
244243
handle;
245-
#endregion
246244

247-
#region Properties
245+
//
246+
// Properties
247+
//
248248

249249
// The CLR allows configuration of these properties, separately from whether the OS supports IPv4/6. We
250250
// do not provide these config options, so SupportsIPvX === OSSupportsIPvX.
@@ -761,9 +761,10 @@ internal bool CanTryAddressFamily(AddressFamily family)
761761
{
762762
return (family == _addressFamily) || (family == AddressFamily.InterNetwork && IsDualMode);
763763
}
764-
#endregion
765764

766-
#region Public Methods
765+
//
766+
// Public Methods
767+
//
767768

768769
// Associates a socket with an end point.
769770
public void Bind(EndPoint localEP)
@@ -2116,43 +2117,14 @@ public IAsyncResult BeginConnect(IPAddress address, int port, AsyncCallback? req
21162117
public IAsyncResult BeginConnect(IPAddress[] addresses, int port, AsyncCallback? requestCallback, object? state) =>
21172118
TaskToApm.Begin(ConnectAsync(addresses, port), requestCallback, state);
21182119

2119-
public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state)
2120+
public void EndConnect(IAsyncResult asyncResult)
21202121
{
21212122
ThrowIfDisposed();
2122-
2123-
// Start context-flowing op. No need to lock - we don't use the context till the callback.
2124-
DisconnectOverlappedAsyncResult asyncResult = new DisconnectOverlappedAsyncResult(this, state, callback);
2125-
asyncResult.StartPostingAsyncOp(false);
2126-
2127-
// Post the disconnect.
2128-
DoBeginDisconnect(reuseSocket, asyncResult);
2129-
2130-
// Finish flowing (or call the callback), and return.
2131-
asyncResult.FinishPostingAsyncOp();
2132-
return asyncResult;
2123+
TaskToApm.End(asyncResult);
21332124
}
21342125

2135-
private void DoBeginDisconnect(bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
2136-
{
2137-
SocketError errorCode = SocketError.Success;
2138-
2139-
errorCode = SocketPal.DisconnectAsync(this, _handle, reuseSocket, asyncResult);
2140-
2141-
if (errorCode == SocketError.Success)
2142-
{
2143-
SetToDisconnected();
2144-
_remoteEndPoint = null;
2145-
_localEndPoint = null;
2146-
}
2147-
2148-
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"UnsafeNclNativeMethods.OSSOCK.DisConnectEx returns:{errorCode}");
2149-
2150-
// If the call failed, update our status and throw
2151-
if (!CheckErrorAndUpdateStatus(errorCode))
2152-
{
2153-
throw new SocketException((int)errorCode);
2154-
}
2155-
}
2126+
public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state) =>
2127+
TaskToApmBeginWithSyncExceptions(DisconnectAsync(reuseSocket).AsTask(), callback, state);
21562128

21572129
public void Disconnect(bool reuseSocket)
21582130
{
@@ -2175,47 +2147,12 @@ public void Disconnect(bool reuseSocket)
21752147
_localEndPoint = null;
21762148
}
21772149

2178-
public void EndConnect(IAsyncResult asyncResult)
2150+
public void EndDisconnect(IAsyncResult asyncResult)
21792151
{
21802152
ThrowIfDisposed();
21812153
TaskToApm.End(asyncResult);
21822154
}
21832155

2184-
public void EndDisconnect(IAsyncResult asyncResult)
2185-
{
2186-
ThrowIfDisposed();
2187-
2188-
if (asyncResult == null)
2189-
{
2190-
throw new ArgumentNullException(nameof(asyncResult));
2191-
}
2192-
2193-
//get async result and check for errors
2194-
LazyAsyncResult? castedAsyncResult = asyncResult as LazyAsyncResult;
2195-
if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this)
2196-
{
2197-
throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult));
2198-
}
2199-
if (castedAsyncResult.EndCalled)
2200-
{
2201-
throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, nameof(EndDisconnect)));
2202-
}
2203-
2204-
//wait for completion if it hasn't occurred
2205-
castedAsyncResult.InternalWaitForCompletion();
2206-
castedAsyncResult.EndCalled = true;
2207-
2208-
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this);
2209-
2210-
//
2211-
// if the asynchronous native call failed asynchronously
2212-
// we'll throw a SocketException
2213-
//
2214-
if ((SocketError)castedAsyncResult.ErrorCode != SocketError.Success)
2215-
{
2216-
UpdateStatusAfterSocketErrorAndThrowException((SocketError)castedAsyncResult.ErrorCode);
2217-
}
2218-
}
22192156

22202157
public IAsyncResult BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback? callback, object? state)
22212158
{
@@ -2668,7 +2605,10 @@ public void Shutdown(SocketShutdown how)
26682605
InternalSetBlocking(_willBlockInternal);
26692606
}
26702607

2671-
#region Async methods
2608+
//
2609+
// Async methods
2610+
//
2611+
26722612
public bool AcceptAsync(SocketAsyncEventArgs e)
26732613
{
26742614
ThrowIfDisposed();
@@ -2889,7 +2829,9 @@ public static void CancelConnectAsync(SocketAsyncEventArgs e)
28892829
e.CancelConnectAsync();
28902830
}
28912831

2892-
public bool DisconnectAsync(SocketAsyncEventArgs e)
2832+
public bool DisconnectAsync(SocketAsyncEventArgs e) => DisconnectAsync(e, default);
2833+
2834+
private bool DisconnectAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
28932835
{
28942836
// Throw if socket disposed
28952837
ThrowIfDisposed();
@@ -2904,7 +2846,7 @@ public bool DisconnectAsync(SocketAsyncEventArgs e)
29042846
SocketError socketError = SocketError.Success;
29052847
try
29062848
{
2907-
socketError = e.DoOperationDisconnect(this, _handle);
2849+
socketError = e.DoOperationDisconnect(this, _handle, cancellationToken);
29082850
}
29092851
catch
29102852
{
@@ -3155,10 +3097,10 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT
31553097

31563098
return socketError == SocketError.IOPending;
31573099
}
3158-
#endregion
3159-
#endregion
31603100

3161-
#region Internal and private properties
3101+
//
3102+
// Internal and private properties
3103+
//
31623104

31633105
private CacheSet Caches
31643106
{
@@ -3174,9 +3116,10 @@ private CacheSet Caches
31743116
}
31753117

31763118
internal bool Disposed => _disposed != 0;
3177-
#endregion
31783119

3179-
#region Internal and private methods
3120+
//
3121+
// Internal and private methods
3122+
//
31803123

31813124
internal static void GetIPProtocolInformation(AddressFamily addressFamily, Internals.SocketAddress socketAddress, out bool isIPv4, out bool isIPv6)
31823125
{
@@ -3889,6 +3832,16 @@ private static SocketError GetSocketErrorFromFaultedTask(Task t)
38893832
};
38903833
}
38913834

3892-
#endregion
3835+
// Helper to maintain existing behavior of Socket APM methods to throw synchronously from Begin*.
3836+
private static IAsyncResult TaskToApmBeginWithSyncExceptions(Task task, AsyncCallback? callback, object? state)
3837+
{
3838+
if (task.IsFaulted)
3839+
{
3840+
task.GetAwaiter().GetResult();
3841+
Debug.Fail("Task faulted but GetResult did not throw???");
3842+
}
3843+
3844+
return TaskToApm.Begin(task, callback, state);
3845+
}
38933846
}
38943847
}

src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle h
9393
return socketError;
9494
}
9595

96-
internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle)
96+
internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
9797
{
9898
SocketError socketError = SocketPal.Disconnect(socket, handle, _disconnectReuseSocket);
9999
FinishOperationSync(socketError, 0, SocketFlags.None);

src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,11 @@ internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle
364364
}
365365
}
366366

367-
internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle)
367+
internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
368368
{
369+
// Note: CancellationToken is ignored for now.
370+
// See https://github.com/dotnet/runtime/issues/51452
371+
369372
NativeOverlapped* overlapped = AllocateNativeOverlapped();
370373
try
371374
{
@@ -1188,6 +1191,7 @@ private unsafe SocketError FinishOperationConnect()
11881191
private void CompleteCore()
11891192
{
11901193
_strongThisRef.Value = null; // null out this reference from the overlapped so this isn't kept alive artificially
1194+
11911195
if (_singleBufferHandleState != SingleBufferHandleState.None)
11921196
{
11931197
// If the state isn't None, then either it's Set, in which case there's state to cleanup,
@@ -1213,6 +1217,8 @@ void CompleteCoreSpin()
12131217
sw.SpinOnce();
12141218
}
12151219

1220+
Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);
1221+
12161222
// Remove any cancellation registration. First dispose the registration
12171223
// to ensure that cancellation will either never fine or will have completed
12181224
// firing before we continue. Only then can we safely null out the overlapped.
@@ -1223,6 +1229,8 @@ void CompleteCoreSpin()
12231229
}
12241230

12251231
// Release any GC handles.
1232+
Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);
1233+
12261234
if (_singleBufferHandleState == SingleBufferHandleState.Set)
12271235
{
12281236
_singleBufferHandleState = SingleBufferHandleState.None;

src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,13 +1976,6 @@ public static SocketError AcceptAsync(Socket socket, SafeSocketHandle handle, Sa
19761976
return socketError;
19771977
}
19781978

1979-
internal static SocketError DisconnectAsync(Socket socket, SafeSocketHandle handle, bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
1980-
{
1981-
SocketError socketError = Disconnect(socket, handle, reuseSocket);
1982-
asyncResult.PostCompletion(socketError);
1983-
return socketError;
1984-
}
1985-
19861979
internal static SocketError Disconnect(Socket socket, SafeSocketHandle handle, bool reuseSocket)
19871980
{
19881981
handle.SetToDisconnected();

0 commit comments

Comments
 (0)