Skip to content

Commit

Permalink
Fix a use-after-free bug of the failure message.
Browse files Browse the repository at this point in the history
If the WebSocket connection fails, it is possible for
WebSocketHandshakeStreamCreateHelper to have a pointer to a
WebSocketBasicHandshakeStream that has been deleted. So it is not safe
to store the failure message in WebSocketBasicHandshakeStream.

Instead, store it in StreamRequestImpl where it is guaranteed to stay
alive until the handshake completes.

BUG=379645

Review URL: https://codereview.chromium.org/368533002

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@281239 0039d316-1c4b-4281-b951-d872f2087c98
  • Loading branch information
ricea@chromium.org committed Jul 3, 2014
1 parent 1e042d7 commit 8aba017
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 57 deletions.
4 changes: 0 additions & 4 deletions net/http/http_stream_factory_impl_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ class MockWebSocketHandshakeStream : public WebSocketHandshakeStreamBase {
return scoped_ptr<WebSocketStream>();
}

virtual std::string GetFailureMessage() const OVERRIDE {
return std::string();
}

private:
const StreamType type_;
};
Expand Down
4 changes: 0 additions & 4 deletions net/url_request/url_request_http_job_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,10 +253,6 @@ class FakeWebSocketHandshakeStream : public WebSocketHandshakeStreamBase {
return scoped_ptr<WebSocketStream>();
}

virtual std::string GetFailureMessage() const OVERRIDE {
return std::string();
}

private:
bool initialize_stream_was_called_;
};
Expand Down
51 changes: 29 additions & 22 deletions net/websockets/websocket_basic_handshake_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "base/basictypes.h"
#include "base/bind.h"
#include "base/containers/hash_tables.h"
#include "base/logging.h"
#include "base/metrics/histogram.h"
#include "base/metrics/sparse_histogram.h"
#include "base/stl_util.h"
Expand Down Expand Up @@ -345,12 +346,17 @@ WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
WebSocketStream::ConnectDelegate* connect_delegate,
bool using_proxy,
std::vector<std::string> requested_sub_protocols,
std::vector<std::string> requested_extensions)
std::vector<std::string> requested_extensions,
std::string* failure_message)
: state_(connection.release(), using_proxy),
connect_delegate_(connect_delegate),
http_response_info_(NULL),
requested_sub_protocols_(requested_sub_protocols),
requested_extensions_(requested_extensions) {}
requested_extensions_(requested_extensions),
failure_message_(failure_message) {
DCHECK(connect_delegate);
DCHECK(failure_message);
}

WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}

Expand Down Expand Up @@ -526,10 +532,6 @@ void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
handshake_challenge_for_testing_.reset(new std::string(key));
}

std::string WebSocketBasicHandshakeStream::GetFailureMessage() const {
return failure_message_;
}

void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
const CompletionCallback& callback,
int result) {
Expand Down Expand Up @@ -580,24 +582,24 @@ int WebSocketBasicHandshakeStream::ValidateResponse(int rv) {
// Reporting "Unexpected response code: 200" in this case is not
// helpful, so use a different error message.
if (headers->GetHttpVersion() == HttpVersion(0, 9)) {
failure_message_ =
"Error during WebSocket handshake: Invalid status line";
set_failure_message(
"Error during WebSocket handshake: Invalid status line");
} else {
failure_message_ = base::StringPrintf(
set_failure_message(base::StringPrintf(
"Error during WebSocket handshake: Unexpected response code: %d",
headers->response_code());
headers->response_code()));
}
OnFinishOpeningHandshake();
return ERR_INVALID_RESPONSE;
}
} else {
if (rv == ERR_EMPTY_RESPONSE) {
failure_message_ =
"Connection closed before receiving a handshake response";
set_failure_message(
"Connection closed before receiving a handshake response");
return rv;
}
failure_message_ =
std::string("Error during WebSocket handshake: ") + ErrorToString(rv);
set_failure_message(std::string("Error during WebSocket handshake: ") +
ErrorToString(rv));
OnFinishOpeningHandshake();
return rv;
}
Expand All @@ -606,24 +608,29 @@ int WebSocketBasicHandshakeStream::ValidateResponse(int rv) {
int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
const HttpResponseHeaders* headers) {
extension_params_.reset(new WebSocketExtensionParams);
if (ValidateUpgrade(headers, &failure_message_) &&
ValidateSecWebSocketAccept(headers,
handshake_challenge_response_,
&failure_message_) &&
ValidateConnection(headers, &failure_message_) &&
std::string failure_message;
if (ValidateUpgrade(headers, &failure_message) &&
ValidateSecWebSocketAccept(
headers, handshake_challenge_response_, &failure_message) &&
ValidateConnection(headers, &failure_message) &&
ValidateSubProtocol(headers,
requested_sub_protocols_,
&sub_protocol_,
&failure_message_) &&
&failure_message) &&
ValidateExtensions(headers,
requested_extensions_,
&extensions_,
&failure_message_,
&failure_message,
extension_params_.get())) {
return OK;
}
failure_message_ = "Error during WebSocket handshake: " + failure_message_;
set_failure_message("Error during WebSocket handshake: " + failure_message);
return ERR_INVALID_RESPONSE;
}

void WebSocketBasicHandshakeStream::set_failure_message(
const std::string& failure_message) {
*failure_message_ = failure_message;
}

} // namespace net
10 changes: 6 additions & 4 deletions net/websockets/websocket_basic_handshake_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ struct WebSocketExtensionParams;
class NET_EXPORT_PRIVATE WebSocketBasicHandshakeStream
: public WebSocketHandshakeStreamBase {
public:
// |connect_delegate| and |failure_message| must out-live this object.
WebSocketBasicHandshakeStream(
scoped_ptr<ClientSocketHandle> connection,
WebSocketStream::ConnectDelegate* connect_delegate,
bool using_proxy,
std::vector<std::string> requested_sub_protocols,
std::vector<std::string> requested_extensions);
std::vector<std::string> requested_extensions,
std::string* failure_message);

virtual ~WebSocketBasicHandshakeStream();

Expand Down Expand Up @@ -75,8 +77,6 @@ class NET_EXPORT_PRIVATE WebSocketBasicHandshakeStream
// For tests only.
void SetWebSocketKeyForTesting(const std::string& key);

virtual std::string GetFailureMessage() const OVERRIDE;

private:
// A wrapper for the ReadResponseHeaders callback that checks whether or not
// the connection has been accepted.
Expand All @@ -94,6 +94,8 @@ class NET_EXPORT_PRIVATE WebSocketBasicHandshakeStream

HttpStreamParser* parser() const { return state_.parser(); }

void set_failure_message(const std::string& failure_message);

// The request URL.
GURL url_;

Expand Down Expand Up @@ -130,7 +132,7 @@ class NET_EXPORT_PRIVATE WebSocketBasicHandshakeStream
// to avoid including extension-related header files here.
scoped_ptr<WebSocketExtensionParams> extension_params_;

std::string failure_message_;
std::string* failure_message_;

DISALLOW_COPY_AND_ASSIGN(WebSocketBasicHandshakeStream);
};
Expand Down
5 changes: 4 additions & 1 deletion net/websockets/websocket_channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,9 @@ void WebSocketChannel::OnConnectSuccess(scoped_ptr<WebSocketStream> stream) {
void WebSocketChannel::OnConnectFailure(const std::string& message) {
DCHECK_EQ(CONNECTING, state_);

// Copy the message before we delete its owner.
std::string message_copy = message;

SetState(CLOSED);
stream_request_.reset();

Expand All @@ -581,7 +584,7 @@ void WebSocketChannel::OnConnectFailure(const std::string& message) {
// |this| has been deleted.
return;
}
AllowUnused(event_interface_->OnFailChannel(message));
AllowUnused(event_interface_->OnFailChannel(message_copy));
// |this| has been deleted.
}

Expand Down
4 changes: 0 additions & 4 deletions net/websockets/websocket_handshake_stream_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ class NET_EXPORT WebSocketHandshakeStreamBase : public HttpStreamBase {
// been called.
virtual scoped_ptr<WebSocketStream> Upgrade() = 0;

// Returns a string describing the connection failure information.
// Returns an empty string if there is no failure.
virtual std::string GetFailureMessage() const = 0;

protected:
// As with the destructor, this must be inline.
WebSocketHandshakeStreamBase() {}
Expand Down
18 changes: 11 additions & 7 deletions net/websockets/websocket_handshake_stream_create_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,29 @@ WebSocketHandshakeStreamCreateHelper::WebSocketHandshakeStreamCreateHelper(
const std::vector<std::string>& requested_subprotocols)
: requested_subprotocols_(requested_subprotocols),
stream_(NULL),
connect_delegate_(connect_delegate) {}
connect_delegate_(connect_delegate),
failure_message_(NULL) {
DCHECK(connect_delegate_);
}

WebSocketHandshakeStreamCreateHelper::~WebSocketHandshakeStreamCreateHelper() {}

WebSocketHandshakeStreamBase*
WebSocketHandshakeStreamCreateHelper::CreateBasicStream(
scoped_ptr<ClientSocketHandle> connection,
bool using_proxy) {
DCHECK(failure_message_) << "set_failure_message() must be called";
// The list of supported extensions and parameters is hard-coded.
// TODO(ricea): If more extensions are added, consider a more flexible
// method.
std::vector<std::string> extensions(
1, "permessage-deflate; client_max_window_bits");
return stream_ =
new WebSocketBasicHandshakeStream(connection.Pass(),
connect_delegate_,
using_proxy,
requested_subprotocols_,
extensions);
return stream_ = new WebSocketBasicHandshakeStream(connection.Pass(),
connect_delegate_,
using_proxy,
requested_subprotocols_,
extensions,
failure_message_);
}

// TODO(ricea): Create a WebSocketSpdyHandshakeStream. crbug.com/323852
Expand Down
14 changes: 13 additions & 1 deletion net/websockets/websocket_handshake_stream_create_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace net {
class NET_EXPORT_PRIVATE WebSocketHandshakeStreamCreateHelper
: public WebSocketHandshakeStreamBase::CreateHelper {
public:
// |connect_delegate| must out-live this object.
explicit WebSocketHandshakeStreamCreateHelper(
WebSocketStream::ConnectDelegate* connect_delegate,
const std::vector<std::string>& requested_subprotocols);
Expand All @@ -44,18 +45,29 @@ class NET_EXPORT_PRIVATE WebSocketHandshakeStreamCreateHelper
// Return the WebSocketHandshakeStreamBase object that we created. In the case
// where CreateBasicStream() was called more than once, returns the most
// recent stream, which will be the one on which the handshake succeeded.
// It is not safe to call this if the handshake failed.
WebSocketHandshakeStreamBase* stream() { return stream_; }

// Set a pointer to the std::string into which to write any failure messages
// that are encountered. This method must be called before CreateBasicStream()
// or CreateSpdyStream(). The |failure_message| pointer must remain valid as
// long as this object exists.
void set_failure_message(std::string* failure_message) {
failure_message_ = failure_message;
}

private:
const std::vector<std::string> requested_subprotocols_;

// This is owned by the caller of CreateBaseStream() or
// CreateSpdyStream(). Both the stream and this object will be destroyed
// during the destruction of the URLRequest object associated with the
// handshake.
// handshake. This is only guaranteed to be a valid pointer if the handshake
// succeeded.
WebSocketHandshakeStreamBase* stream_;

WebSocketStream::ConnectDelegate* connect_delegate_;
std::string* failure_message_;

DISALLOW_COPY_AND_ASSIGN(WebSocketHandshakeStreamCreateHelper);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test {
const std::string& extra_response_headers) {
WebSocketHandshakeStreamCreateHelper create_helper(&connect_delegate_,
sub_protocols);
create_helper.set_failure_message(&failure_message_);

scoped_ptr<ClientSocketHandle> socket_handle =
socket_handle_factory_.CreateClientSocketHandle(
Expand Down Expand Up @@ -138,6 +139,7 @@ class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test {

MockClientSocketHandleFactory socket_handle_factory_;
TestConnectDelegate connect_delegate_;
std::string failure_message_;
};

// Confirm that the basic case works as expected.
Expand Down
15 changes: 8 additions & 7 deletions net/websockets/websocket_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class StreamRequestImpl : public WebSocketStreamRequest {
url_request_(url, DEFAULT_PRIORITY, delegate_.get(), context),
connect_delegate_(connect_delegate.Pass()),
create_helper_(create_helper.release()) {
create_helper_->set_failure_message(&failure_message_);
HttpRequestHeaders headers;
headers.SetHeader(websockets::kUpgrade, websockets::kWebSocketLowercase);
headers.SetHeader(HttpRequestHeaders::kConnection, websockets::kUpgrade);
Expand Down Expand Up @@ -117,25 +118,22 @@ class StreamRequestImpl : public WebSocketStreamRequest {
}

void ReportFailure() {
std::string failure_message;
if (create_helper_->stream()) {
failure_message = create_helper_->stream()->GetFailureMessage();
} else {
if (failure_message_.empty()) {
switch (url_request_.status().status()) {
case URLRequestStatus::SUCCESS:
case URLRequestStatus::IO_PENDING:
break;
case URLRequestStatus::CANCELED:
failure_message = "WebSocket opening handshake was canceled";
failure_message_ = "WebSocket opening handshake was canceled";
break;
case URLRequestStatus::FAILED:
failure_message =
failure_message_ =
std::string("Error in connection establishment: ") +
ErrorToString(url_request_.status().error());
break;
}
}
connect_delegate_->OnFailure(failure_message);
connect_delegate_->OnFailure(failure_message_);
}

WebSocketStream::ConnectDelegate* connect_delegate() const {
Expand All @@ -155,6 +153,9 @@ class StreamRequestImpl : public WebSocketStreamRequest {

// Owned by the URLRequest.
WebSocketHandshakeStreamCreateHelper* create_helper_;

// The failure message supplied by WebSocketBasicHandshakeStream, if any.
std::string failure_message_;
};

class SSLErrorCallbacks : public WebSocketEventInterface::SSLErrorCallbacks {
Expand Down
7 changes: 4 additions & 3 deletions net/websockets/websocket_stream_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,12 @@ class WebSocketStreamCreateTest : public ::testing::Test {
scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate(
new TestConnectDelegate(this));
WebSocketStream::ConnectDelegate* delegate = connect_delegate.get();
scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper(
new DeterministicKeyWebSocketHandshakeStreamCreateHelper(
delegate, sub_protocols));
stream_request_ = ::net::CreateAndConnectStreamForTesting(
GURL(socket_url),
scoped_ptr<WebSocketHandshakeStreamCreateHelper>(
new DeterministicKeyWebSocketHandshakeStreamCreateHelper(
delegate, sub_protocols)),
create_helper.Pass(),
url::Origin(origin),
url_request_context_host_.GetURLRequestContext(),
BoundNetLog(),
Expand Down

0 comments on commit 8aba017

Please sign in to comment.