Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HTTP/2 extended connect hang #80066

Merged
merged 4 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ public async Task<HeadersFrame> ReadRequestHeaderFrameAsync(bool expectEndOfStre
return (HeadersFrame)frame;
}

public async Task<Frame> ReadDataFrameAsync()
public async Task<DataFrame> ReadDataFrameAsync()
{
// Receive DATA frame for request.
Frame frame = await ReadFrameAsync(_timeout).ConfigureAwait(false);
Expand All @@ -412,7 +412,7 @@ public async Task<Frame> ReadDataFrameAsync()
}

Assert.Equal(FrameType.Data, frame.Type);
return frame;
return Assert.IsType<DataFrame>(frame);
}

private static (int bytesConsumed, int value) DecodeInteger(ReadOnlySpan<byte> headerBlock, byte prefixMask)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,10 +500,14 @@ private async Task ProcessIncomingFramesAsync()

// Process the initial SETTINGS frame. This will send an ACK.
ProcessSettingsFrame(frameHeader, initialFrame: true);

Debug.Assert(InitialSettingsReceived.Task.IsCompleted);
}
catch (IOException e)
catch (Exception e)
{
throw new IOException(SR.net_http_http2_connection_not_established, e);
e = new IOException(SR.net_http_http2_connection_not_established, e);
InitialSettingsReceived.TrySetException(e);
throw e;
MihaZupan marked this conversation as resolved.
Show resolved Hide resolved
MihaZupan marked this conversation as resolved.
Show resolved Hide resolved
}

// Keep processing frames as they arrive.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2516,68 +2516,6 @@ public async Task PostAsyncDuplex_ServerSendsEndStream_Success()
}
}

[Fact]
public async Task ConnectAsync_ReadWriteWebSocketStream()
{
var clientMessage = new byte[] { 1, 2, 3 };
var serverMessage = new byte[] { 4, 5, 6, 7 };

using Http2LoopbackServer server = Http2LoopbackServer.CreateServer();
Http2LoopbackConnection connection = null;

Task serverTask = Task.Run(async () =>
{
connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });

// read request headers
(int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);

// send response headers
await connection.SendResponseHeadersAsync(streamId, endStream: false).ConfigureAwait(false);

// send reply
await connection.SendResponseDataAsync(streamId, serverMessage, endStream: false);

// send server EOS
await connection.SendResponseDataAsync(streamId, Array.Empty<byte>(), endStream: true);
});

StreamingHttpContent requestContent = new StreamingHttpContent();

using var handler = CreateSocketsHttpHandler(allowAllCertificates: true);
using HttpClient client = new HttpClient(handler);

HttpRequestMessage request = new(HttpMethod.Connect, server.Address);
request.Version = HttpVersion.Version20;
request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
request.Headers.Protocol = "websocket";

// initiate request
var responseTask = client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);

using HttpResponseMessage response = await responseTask.WaitAsync(TimeSpan.FromSeconds(10));

await serverTask.WaitAsync(TimeSpan.FromSeconds(60));

var responseStream = await response.Content.ReadAsStreamAsync();

// receive data
var readBuffer = new byte[10];
int bytesRead = await responseStream.ReadAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10));
Assert.Equal(bytesRead, serverMessage.Length);
Assert.Equal(serverMessage, readBuffer[..bytesRead]);

await responseStream.WriteAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10));

// Send client's EOS
requestContent.CompleteStream();
// Receive server's EOS
Assert.Equal(0, await responseStream.ReadAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10)));

Assert.NotNull(connection);
await connection.DisposeAsync();
}

[Fact]
public async Task PostAsyncDuplex_RequestContentException_ResetsStream()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO;
using System.Net.Test.Common;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;

namespace System.Net.Http.Functional.Tests
{
[ConditionalClass(typeof(SocketsHttpHandler), nameof(SocketsHttpHandler.IsSupported))]
public sealed class SocketsHttpHandler_Http2ExtendedConnect_Test : HttpClientHandlerTestBase
{
public SocketsHttpHandler_Http2ExtendedConnect_Test(ITestOutputHelper output) : base(output) { }

protected override Version UseVersion => HttpVersion.Version20;

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task Connect_ReadWriteResponseStream(bool useSsl)
{
byte[] clientMessage = new byte[] { 1, 2, 3 };
byte[] serverMessage = new byte[] { 4, 5, 6, 7 };

TaskCompletionSource clientCompleted = new(TaskCreationOptions.RunContinuationsAsynchronously);

await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(async uri =>
{
using HttpClient client = CreateHttpClient();

HttpRequestMessage request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true);
request.Headers.Protocol = "foo";

using HttpResponseMessage response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);

using Stream responseStream = await response.Content.ReadAsStreamAsync();

await responseStream.WriteAsync(clientMessage);

byte[] readBuffer = new byte[serverMessage.Length];
await responseStream.ReadExactlyAsync(readBuffer);
Assert.Equal(serverMessage, readBuffer);

// Receive server's EOS
Assert.Equal(0, await responseStream.ReadAsync(readBuffer));

clientCompleted.SetResult();
},
async server =>
{
await using Http2LoopbackConnection connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });

(int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);

await connection.SendResponseHeadersAsync(streamId, endStream: false).ConfigureAwait(false);

DataFrame dataFrame = await connection.ReadDataFrameAsync();
Assert.Equal(clientMessage, dataFrame.Data.ToArray());

await connection.SendResponseDataAsync(streamId, serverMessage, endStream: true);

await clientCompleted.Task.WaitAsync(TestHelper.PassingTestTimeout);
}, options: new GenericLoopbackOptions { UseSsl = useSsl });
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task Connect_ServerDoesNotSupportExtendedConnect_ClientIncludesExceptionData(bool useSsl)
{
TaskCompletionSource clientCompleted = new(TaskCreationOptions.RunContinuationsAsynchronously);

await LoopbackServerFactory.CreateClientAndServerAsync(async uri =>
{
using HttpClient client = CreateHttpClient();

HttpRequestMessage request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true);
request.Headers.Protocol = "foo";

HttpRequestException ex = await Assert.ThrowsAsync<HttpRequestException>(() => client.SendAsync(request));

Assert.Equal(false, ex.Data["SETTINGS_ENABLE_CONNECT_PROTOCOL"]);

clientCompleted.SetResult();
},
async server =>
{
try
{
await server.AcceptConnectionAsync(async connection =>
{
await clientCompleted.Task.WaitAsync(TestHelper.PassingTestTimeout);
});
}
catch (Exception ex)
{
_output.WriteLine($"Ignoring exception {ex}");
}
}, options: new GenericLoopbackOptions { UseSsl = useSsl });
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task Connect_Http11Endpoint_Throws(bool useSsl)
{
using var server = new LoopbackServer(new LoopbackServer.Options
{
UseSsl = useSsl
});

await server.ListenAsync();

TaskCompletionSource clientCompleted = new(TaskCreationOptions.RunContinuationsAsynchronously);

Task serverTask = Task.Run(async () =>
{
try
{
await server.AcceptConnectionAsync(async connection =>
{
if (!useSsl)
{
byte[] http2GoAwayHttp11RequiredBytes = new byte[17] { 0, 0, 8, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13 };

await connection.SendResponseAsync(http2GoAwayHttp11RequiredBytes);

await clientCompleted.Task.WaitAsync(TestHelper.PassingTestTimeout);
}
});
}
catch (Exception ex)
{
_output.WriteLine($"Ignoring exception {ex}");
}
});

Task clientTask = Task.Run(async () =>
{
using HttpClient client = CreateHttpClient();

HttpRequestMessage request = CreateRequest(HttpMethod.Connect, server.Address, UseVersion, exactVersion: true);
request.Headers.Protocol = "foo";

Exception ex = await Assert.ThrowsAnyAsync<Exception>(() => client.SendAsync(request));
clientCompleted.SetResult();

if (useSsl)
{
Assert.Equal(false, ex.Data["HTTP2_ENABLED"]);
}
});

await new[] { serverTask, clientTask }.WhenAllOrAnyFailed().WaitAsync(TestHelper.PassingTestTimeout);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
<Compile Include="HttpClientHandlerTest.AltSvc.cs" />
<Compile Include="SocketsHttpHandlerTest.Cancellation.cs" />
<Compile Include="SocketsHttpHandlerTest.Http1KeepAlive.cs" />
<Compile Include="SocketsHttpHandlerTest.Http2ExtendedConnect.cs" />
<Compile Include="SocketsHttpHandlerTest.Http2FlowControl.cs" />
<Compile Include="SocketsHttpHandlerTest.Http2KeepAlivePing.cs" />
<Compile Include="HttpClientHandlerTest.Connect.cs" />
Expand Down