diff --git a/net/websockets/websocket_basic_handshake_stream.cc b/net/websockets/websocket_basic_handshake_stream.cc index b5b06ccf763eb9..ff755c919534f5 100644 --- a/net/websockets/websocket_basic_handshake_stream.cc +++ b/net/websockets/websocket_basic_handshake_stream.cc @@ -8,9 +8,7 @@ #include #include #include -#include #include -#include #include "base/base64.h" #include "base/bind.h" @@ -179,8 +177,8 @@ WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( false /* http_09_on_non_default_ports_enabled */), connect_delegate_(connect_delegate), http_response_info_(nullptr), - requested_sub_protocols_(requested_sub_protocols), - requested_extensions_(requested_extensions), + requested_sub_protocols_(std::move(requested_sub_protocols)), + requested_extensions_(std::move(requested_extensions)), stream_request_(request), websocket_endpoint_lock_manager_(websocket_endpoint_lock_manager) { DCHECK(connect_delegate); @@ -188,6 +186,12 @@ WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( } WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() { + // Some members are "stolen" by RenewStreamForAuth() and should not be touched + // here. Particularly |connect_delegate_|, |stream_request_|, and + // |websocket_endpoint_lock_manager_|. + + // TODO(ricea): What's the right thing to do here if we renewed the stream for + // auth? Currently we record it as INCOMPLETE. RecordHandshakeResult(result_); } @@ -362,8 +366,24 @@ void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { } HttpStream* WebSocketBasicHandshakeStream::RenewStreamForAuth() { - // Return null because we don't support renewing the stream. - return nullptr; + if (!base::FeatureList::IsEnabled(kWebSocketHandshakeReuseConnection)) + return nullptr; + + DCHECK(IsResponseBodyComplete()); + DCHECK(!parser()->IsMoreDataBuffered()); + // The HttpStreamParser object still has a pointer to the connection. Just to + // be extra-sure it doesn't touch the connection again, delete it here rather + // than leaving it until the destructor is called. + state_.DeleteParser(); + + auto handshake_stream = std::make_unique( + state_.ReleaseConnection(), connect_delegate_, state_.using_proxy(), + std::move(requested_sub_protocols_), std::move(requested_extensions_), + stream_request_, websocket_endpoint_lock_manager_); + + stream_request_->OnBasicHandshakeStreamCreated(handshake_stream.get()); + + return handshake_stream.release(); } std::unique_ptr WebSocketBasicHandshakeStream::Upgrade() { diff --git a/net/websockets/websocket_basic_handshake_stream.h b/net/websockets/websocket_basic_handshake_stream.h index ff73bd4554afd9..27e6f03500265a 100644 --- a/net/websockets/websocket_basic_handshake_stream.h +++ b/net/websockets/websocket_basic_handshake_stream.h @@ -122,7 +122,7 @@ class NET_EXPORT_PRIVATE WebSocketBasicHandshakeStream final // Owned by another object. // |connect_delegate| will live during the lifetime of this object. - WebSocketStream::ConnectDelegate* connect_delegate_; + WebSocketStream::ConnectDelegate* const connect_delegate_; // This is stored in SendRequest() for use by ReadResponseHeaders(). HttpResponseInfo* http_response_info_; diff --git a/net/websockets/websocket_stream_test.cc b/net/websockets/websocket_stream_test.cc index 4afdd58205647b..1b972d39dcff11 100644 --- a/net/websockets/websocket_stream_test.cc +++ b/net/websockets/websocket_stream_test.cc @@ -18,6 +18,7 @@ #include "base/run_loop.h" #include "base/strings/stringprintf.h" #include "base/test/histogram_tester.h" +#include "base/test/scoped_feature_list.h" #include "base/timer/mock_timer.h" #include "base/timer/timer.h" #include "net/base/net_errors.h" @@ -104,6 +105,12 @@ class WebSocketStreamCreateTest : public TestWithParam, base::RunLoop().RunUntilIdle(); } + // Normally it's easier to use CreateAndConnectRawExpectations() instead. This + // method is only needed when multiple sockets are involved. + void AddRawExpectations(std::unique_ptr socket_data) { + url_request_context_host_.AddRawExpectations(std::move(socket_data)); + } + void AddSSLData() { auto ssl_data = std::make_unique(ASYNC, OK); ssl_data->ssl_info.cert = @@ -269,7 +276,7 @@ class WebSocketStreamCreateTest : public TestWithParam, auto socket_data = std::make_unique(reads_, writes_); socket_data->set_connect_data(MockConnect(SYNCHRONOUS, OK)); - url_request_context_host_.AddRawExpectations(std::move(socket_data)); + AddRawExpectations(std::move(socket_data)); // Send first request. This makes sure server's // spdy::SETTINGS_ENABLE_CONNECT_PROTOCOL advertisement is read. @@ -346,7 +353,7 @@ class WebSocketStreamCreateTest : public TestWithParam, std::unique_ptr socket_data) { ASSERT_EQ(BASIC_HANDSHAKE_STREAM, stream_type_); - url_request_context_host_.AddRawExpectations(std::move(socket_data)); + AddRawExpectations(std::move(socket_data)); CreateAndConnectStream(GURL(url), sub_protocols, Origin(), SiteForCookies(), additional_headers, std::move(timer_)); } @@ -465,10 +472,9 @@ class WebSocketStreamCreateBasicAuthTest : public WebSocketStreamCreateTest { void CreateAndConnectAuthHandshake(const std::string& url, const std::string& base64_user_pass, const std::string& response2) { - url_request_context_host_.AddRawExpectations( - helper_.BuildSocketData1(kUnauthorizedResponse)); + AddRawExpectations(helper_.BuildSocketData1(kUnauthorizedResponse)); - static const char request2format[] = + static constexpr char request2format[] = "GET / HTTP/1.1\r\n" "Host: www.example.org\r\n" "Connection: Upgrade\r\n" @@ -1428,7 +1434,7 @@ TEST_P(WebSocketStreamCreateTest, SelfSignedCertificateSuccess) { std::move(ssl_socket_data)); url_request_context_host_.AddSSLSocketDataProvider( std::make_unique(ASYNC, OK)); - url_request_context_host_.AddRawExpectations(BuildNullSocketData()); + AddRawExpectations(BuildNullSocketData()); CreateAndConnectStandard("wss://www.example.org/", NoSubProtocols(), {}, {}, {}); // WaitUntilConnectDone doesn't work in this case. @@ -1470,12 +1476,70 @@ TEST_P(WebSocketStreamCreateBasicAuthTest, FailureIncorrectPasswordInUrl) { EXPECT_TRUE(response_info_); } +TEST_P(WebSocketStreamCreateBasicAuthTest, SuccessfulConnectionReuse) { + base::test::ScopedFeatureList scoped_feature_list; + scoped_feature_list.InitAndEnableFeature( + WebSocketBasicHandshakeStream ::kWebSocketHandshakeReuseConnection); + + std::string request1 = + "GET / HTTP/1.1\r\n" + "Host: www.example.org\r\n" + "Connection: Keep-Alive, Upgrade\r\n" + "Pragma: no-cache\r\n" + "Cache-Control: no-cache\r\n" + "Upgrade: websocket\r\n" + "Origin: http://www.example.org\r\n" + "Sec-WebSocket-Version: 13\r\n" + "User-Agent:\r\n" + "Accept-Encoding: gzip, deflate\r\n" + "Accept-Language: en-us,fr\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; " + "client_max_window_bits\r\n" + "\r\n"; + std::string response1 = kUnauthorizedResponse; + std::string request2 = + "GET / HTTP/1.1\r\n" + "Host: www.example.org\r\n" + "Connection: Keep-Alive, Upgrade\r\n" + "Pragma: no-cache\r\n" + "Cache-Control: no-cache\r\n" + "Authorization: Basic Zm9vOmJhcg==\r\n" + "Upgrade: websocket\r\n" + "Origin: http://www.example.org\r\n" + "Sec-WebSocket-Version: 13\r\n" + "User-Agent:\r\n" + "Accept-Encoding: gzip, deflate\r\n" + "Accept-Language: en-us,fr\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; " + "client_max_window_bits\r\n" + "\r\n"; + std::string response2 = WebSocketStandardResponse(std::string()); + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, 0, request1.c_str()), + MockWrite(SYNCHRONOUS, 2, request2.c_str()), + }; + MockRead reads[3] = { + MockRead(SYNCHRONOUS, 1, response1.c_str()), + MockRead(SYNCHRONOUS, 3, response2.c_str()), + MockRead(SYNCHRONOUS, ERR_IO_PENDING, 4), + }; + CreateAndConnectRawExpectations("ws://foo:bar@www.example.org/", + NoSubProtocols(), HttpRequestHeaders(), + BuildSocketData(reads, writes)); + WaitUntilConnectDone(); + EXPECT_FALSE(has_failed()); + EXPECT_TRUE(stream_); + ASSERT_TRUE(response_info_); + EXPECT_EQ(101, response_info_->headers->response_code()); +} + // Digest auth has the same connection semantics as Basic auth, so we can // generally assume that whatever works for Basic auth will also work for // Digest. There's just one test here, to confirm that it works at all. TEST_P(WebSocketStreamCreateDigestAuthTest, DigestPasswordInUrl) { - url_request_context_host_.AddRawExpectations( - helper_.BuildSocketData1(kUnauthorizedResponse)); + AddRawExpectations(helper_.BuildSocketData1(kUnauthorizedResponse)); CreateAndConnectRawExpectations( "ws://FooBar:pass@www.example.org/", NoSubProtocols(),