Skip to content

Commit

Permalink
[net] Use scoped_ptr<> consistently in ClientSocketFactory and relate…
Browse files Browse the repository at this point in the history
…d code

This will make it easier to modify ClientSocketFactory et al. to support
reprioritization. This also fixes a few latent memory leaks in tests.

Make SocketStream use a ClientSocketHandle instead of
just a StreamSocket.

Rename {set,release}_socket() to {Set,Pass}Socket().

BUG=166689
TBR=eroman@chromium.org, rsleevi@chromium.org, sergeyu@chromium.org

Review URL: https://codereview.chromium.org/22995002

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@217707 0039d316-1c4b-4281-b951-d872f2087c98
  • Loading branch information
akalin@chromium.org committed Aug 15, 2013
1 parent 582a857 commit 18ccfdb
Show file tree
Hide file tree
Showing 69 changed files with 727 additions and 605 deletions.
17 changes: 7 additions & 10 deletions chrome/browser/net/network_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,28 +189,25 @@ bool NetworkStats::DoConnect(int result) {
return false;
}

net::DatagramClientSocket* udp_socket =
scoped_ptr<net::DatagramClientSocket> udp_socket =
socket_factory_->CreateDatagramClientSocket(
net::DatagramSocket::DEFAULT_BIND,
net::RandIntCallback(),
NULL,
net::NetLog::Source());
if (!udp_socket) {
TestPhaseComplete(SOCKET_CREATE_FAILED, net::ERR_INVALID_ARGUMENT);
return false;
}
DCHECK(!socket_.get());
socket_.reset(udp_socket);
DCHECK(udp_socket);
DCHECK(!socket_);
socket_ = udp_socket.Pass();

const net::IPEndPoint& endpoint = addresses_.front();
int rv = udp_socket->Connect(endpoint);
int rv = socket_->Connect(endpoint);
if (rv < 0) {
TestPhaseComplete(CONNECT_FAILED, rv);
return false;
}

udp_socket->SetSendBufferSize(kMaxUdpSendBufferSize);
udp_socket->SetReceiveBufferSize(kMaxUdpReceiveBufferSize);
socket_->SetSendBufferSize(kMaxUdpSendBufferSize);
socket_->SetReceiveBufferSize(kMaxUdpReceiveBufferSize);
return ConnectComplete(rv);
}

Expand Down
10 changes: 6 additions & 4 deletions content/browser/renderer_host/p2p/socket_host_tcp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ void P2PSocketHostTcpBase::OnConnected(int result) {
StartTls();
} else {
if (IsPseudoTlsClientSocket(type_)) {
socket_.reset(new jingle_glue::FakeSSLClientSocket(socket_.release()));
scoped_ptr<net::StreamSocket> transport_socket = socket_.Pass();
socket_.reset(
new jingle_glue::FakeSSLClientSocket(transport_socket.Pass()));
}

// If we are not doing TLS, we are ready to send data now.
Expand All @@ -155,7 +157,7 @@ void P2PSocketHostTcpBase::StartTls() {

scoped_ptr<net::ClientSocketHandle> socket_handle(
new net::ClientSocketHandle());
socket_handle->set_socket(socket_.release());
socket_handle->SetSocket(socket_.Pass());

net::SSLClientSocketContext context;
context.cert_verifier = url_context_->GetURLRequestContext()->cert_verifier();
Expand All @@ -171,8 +173,8 @@ void P2PSocketHostTcpBase::StartTls() {
net::ClientSocketFactory::GetDefaultFactory();
DCHECK(socket_factory);

socket_.reset(socket_factory->CreateSSLClientSocket(
socket_handle.release(), dest_host_port_pair, ssl_config, context));
socket_ = socket_factory->CreateSSLClientSocket(
socket_handle.Pass(), dest_host_port_pair, ssl_config, context);
int status = socket_->Connect(
base::Bind(&P2PSocketHostTcpBase::ProcessTlsConnectDone,
base::Unretained(this)));
Expand Down
8 changes: 4 additions & 4 deletions content/browser/renderer_host/pepper/pepper_tcp_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,16 @@ void PepperTCPSocket::SSLHandshake(
connection_state_ = SSL_HANDSHAKE_IN_PROGRESS;
// TODO(raymes,rsleevi): Use trusted/untrusted certificates when connecting.

net::ClientSocketHandle* handle = new net::ClientSocketHandle();
handle->set_socket(socket_.release());
scoped_ptr<net::ClientSocketHandle> handle(new net::ClientSocketHandle());
handle->SetSocket(socket_.Pass());
net::ClientSocketFactory* factory =
net::ClientSocketFactory::GetDefaultFactory();
net::HostPortPair host_port_pair(server_name, server_port);
net::SSLClientSocketContext ssl_context;
ssl_context.cert_verifier = manager_->GetCertVerifier();
ssl_context.transport_security_state = manager_->GetTransportSecurityState();
socket_.reset(factory->CreateSSLClientSocket(
handle, host_port_pair, manager_->ssl_config(), ssl_context));
socket_ = factory->CreateSSLClientSocket(
handle.Pass(), host_port_pair, manager_->ssl_config(), ssl_context);
if (!socket_) {
LOG(WARNING) << "Failed to create an SSL client socket.";
OnSSLHandshakeCompleted(net::ERR_UNEXPECTED);
Expand Down
10 changes: 5 additions & 5 deletions jingle/glue/chrome_async_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ bool ChromeAsyncSocket::Connect(const talk_base::SocketAddress& address) {

net::HostPortPair dest_host_port_pair(address.hostname(), address.port());

transport_socket_.reset(
transport_socket_ =
resolving_client_socket_factory_->CreateTransportClientSocket(
dest_host_port_pair));
dest_host_port_pair);
int status = transport_socket_->Connect(
base::Bind(&ChromeAsyncSocket::ProcessConnectDone,
weak_ptr_factory_.GetWeakPtr()));
Expand Down Expand Up @@ -404,10 +404,10 @@ bool ChromeAsyncSocket::StartTls(const std::string& domain_name) {
DCHECK(transport_socket_.get());
scoped_ptr<net::ClientSocketHandle> socket_handle(
new net::ClientSocketHandle());
socket_handle->set_socket(transport_socket_.release());
transport_socket_.reset(
socket_handle->SetSocket(transport_socket_.Pass());
transport_socket_ =
resolving_client_socket_factory_->CreateSSLClientSocket(
socket_handle.release(), net::HostPortPair(domain_name, 443)));
socket_handle.Pass(), net::HostPortPair(domain_name, 443));
int status = transport_socket_->Connect(
base::Bind(&ChromeAsyncSocket::ProcessSSLConnectDone,
weak_ptr_factory_.GetWeakPtr()));
Expand Down
8 changes: 4 additions & 4 deletions jingle/glue/chrome_async_socket_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,20 @@ class MockXmppClientSocketFactory : public ResolvingClientSocketFactory {
}

// ResolvingClientSocketFactory implementation.
virtual net::StreamSocket* CreateTransportClientSocket(
virtual scoped_ptr<net::StreamSocket> CreateTransportClientSocket(
const net::HostPortPair& host_and_port) OVERRIDE {
return mock_client_socket_factory_->CreateTransportClientSocket(
address_list_, NULL, net::NetLog::Source());
}

virtual net::SSLClientSocket* CreateSSLClientSocket(
net::ClientSocketHandle* transport_socket,
virtual scoped_ptr<net::SSLClientSocket> CreateSSLClientSocket(
scoped_ptr<net::ClientSocketHandle> transport_socket,
const net::HostPortPair& host_and_port) OVERRIDE {
net::SSLClientSocketContext context;
context.cert_verifier = cert_verifier_.get();
context.transport_security_state = transport_security_state_.get();
return mock_client_socket_factory_->CreateSSLClientSocket(
transport_socket, host_and_port, ssl_config_, context);
transport_socket.Pass(), host_and_port, ssl_config_, context);
}

private:
Expand Down
4 changes: 2 additions & 2 deletions jingle/glue/fake_ssl_client_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ base::StringPiece FakeSSLClientSocket::GetSslServerHello() {
}

FakeSSLClientSocket::FakeSSLClientSocket(
net::StreamSocket* transport_socket)
: transport_socket_(transport_socket),
scoped_ptr<net::StreamSocket> transport_socket)
: transport_socket_(transport_socket.Pass()),
next_handshake_state_(STATE_NONE),
handshake_completed_(false),
write_buf_(NewDrainableIOBufferWithSize(arraysize(kSslClientHello))),
Expand Down
3 changes: 1 addition & 2 deletions jingle/glue/fake_ssl_client_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ namespace jingle_glue {

class FakeSSLClientSocket : public net::StreamSocket {
public:
// Takes ownership of |transport_socket|.
explicit FakeSSLClientSocket(net::StreamSocket* transport_socket);
explicit FakeSSLClientSocket(scoped_ptr<net::StreamSocket> transport_socket);

virtual ~FakeSSLClientSocket();

Expand Down
7 changes: 4 additions & 3 deletions jingle/glue/fake_ssl_client_socket_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class FakeSSLClientSocketTest : public testing::Test {

virtual ~FakeSSLClientSocketTest() {}

net::StreamSocket* MakeClientSocket() {
scoped_ptr<net::StreamSocket> MakeClientSocket() {
return mock_client_socket_factory_.CreateTransportClientSocket(
net::AddressList(), NULL, net::NetLog::Source());
}
Expand Down Expand Up @@ -269,7 +269,7 @@ class FakeSSLClientSocketTest : public testing::Test {
};

TEST_F(FakeSSLClientSocketTest, PassThroughMethods) {
MockClientSocket* mock_client_socket = new MockClientSocket();
scoped_ptr<MockClientSocket> mock_client_socket(new MockClientSocket());
const int kReceiveBufferSize = 10;
const int kSendBufferSize = 20;
net::IPEndPoint ip_endpoint(net::IPAddressNumber(net::kIPv4AddressSize), 80);
Expand All @@ -284,7 +284,8 @@ TEST_F(FakeSSLClientSocketTest, PassThroughMethods) {
EXPECT_CALL(*mock_client_socket, SetOmniboxSpeculation());

// Takes ownership of |mock_client_socket|.
FakeSSLClientSocket fake_ssl_client_socket(mock_client_socket);
FakeSSLClientSocket fake_ssl_client_socket(
mock_client_socket.PassAs<net::StreamSocket>());
fake_ssl_client_socket.SetReceiveBufferSize(kReceiveBufferSize);
fake_ssl_client_socket.SetSendBufferSize(kSendBufferSize);
EXPECT_EQ(kPeerAddress,
Expand Down
7 changes: 4 additions & 3 deletions jingle/glue/resolving_client_socket_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifndef JINGLE_GLUE_RESOLVING_CLIENT_SOCKET_FACTORY_H_
#define JINGLE_GLUE_RESOLVING_CLIENT_SOCKET_FACTORY_H_

#include "base/memory/scoped_ptr.h"

namespace net {
class ClientSocketHandle;
Expand All @@ -23,11 +24,11 @@ class ResolvingClientSocketFactory {
public:
virtual ~ResolvingClientSocketFactory() { }
// Method to create a transport socket using a HostPortPair.
virtual net::StreamSocket* CreateTransportClientSocket(
virtual scoped_ptr<net::StreamSocket> CreateTransportClientSocket(
const net::HostPortPair& host_and_port) = 0;

virtual net::SSLClientSocket* CreateSSLClientSocket(
net::ClientSocketHandle* transport_socket,
virtual scoped_ptr<net::SSLClientSocket> CreateSSLClientSocket(
scoped_ptr<net::ClientSocketHandle> transport_socket,
const net::HostPortPair& host_and_port) = 0;
};

Expand Down
26 changes: 16 additions & 10 deletions jingle/glue/xmpp_client_socket_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "jingle/glue/fake_ssl_client_socket.h"
#include "jingle/glue/proxy_resolving_client_socket.h"
#include "net/socket/client_socket_factory.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/ssl_client_socket.h"
#include "net/url_request/url_request_context.h"
#include "net/url_request/url_request_context_getter.h"
Expand All @@ -28,20 +29,25 @@ XmppClientSocketFactory::XmppClientSocketFactory(

XmppClientSocketFactory::~XmppClientSocketFactory() {}

net::StreamSocket* XmppClientSocketFactory::CreateTransportClientSocket(
scoped_ptr<net::StreamSocket>
XmppClientSocketFactory::CreateTransportClientSocket(
const net::HostPortPair& host_and_port) {
// TODO(akalin): Use socket pools.
net::StreamSocket* transport_socket = new ProxyResolvingClientSocket(
NULL,
request_context_getter_,
ssl_config_,
host_and_port);
scoped_ptr<net::StreamSocket> transport_socket(
new ProxyResolvingClientSocket(
NULL,
request_context_getter_,
ssl_config_,
host_and_port));
return (use_fake_ssl_client_socket_ ?
new FakeSSLClientSocket(transport_socket) : transport_socket);
scoped_ptr<net::StreamSocket>(
new FakeSSLClientSocket(transport_socket.Pass())) :
transport_socket.Pass());
}

net::SSLClientSocket* XmppClientSocketFactory::CreateSSLClientSocket(
net::ClientSocketHandle* transport_socket,
scoped_ptr<net::SSLClientSocket>
XmppClientSocketFactory::CreateSSLClientSocket(
scoped_ptr<net::ClientSocketHandle> transport_socket,
const net::HostPortPair& host_and_port) {
net::SSLClientSocketContext context;
context.cert_verifier =
Expand All @@ -52,7 +58,7 @@ net::SSLClientSocket* XmppClientSocketFactory::CreateSSLClientSocket(
// TODO(rkn): context.server_bound_cert_service is NULL because the
// ServerBoundCertService class is not thread safe.
return client_socket_factory_->CreateSSLClientSocket(
transport_socket, host_and_port, ssl_config_, context);
transport_socket.Pass(), host_and_port, ssl_config_, context);
}


Expand Down
6 changes: 3 additions & 3 deletions jingle/glue/xmpp_client_socket_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ class XmppClientSocketFactory : public ResolvingClientSocketFactory {
virtual ~XmppClientSocketFactory();

// ResolvingClientSocketFactory implementation.
virtual net::StreamSocket* CreateTransportClientSocket(
virtual scoped_ptr<net::StreamSocket> CreateTransportClientSocket(
const net::HostPortPair& host_and_port) OVERRIDE;

virtual net::SSLClientSocket* CreateSSLClientSocket(
net::ClientSocketHandle* transport_socket,
virtual scoped_ptr<net::SSLClientSocket> CreateSSLClientSocket(
scoped_ptr<net::ClientSocketHandle> transport_socket,
const net::HostPortPair& host_and_port) OVERRIDE;

private:
Expand Down
16 changes: 9 additions & 7 deletions net/dns/address_sorter_posix_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "net/base/net_util.h"
#include "net/base/test_completion_callback.h"
#include "net/socket/client_socket_factory.h"
#include "net/socket/ssl_client_socket.h"
#include "net/socket/stream_socket.h"
#include "net/udp/datagram_client_socket.h"
#include "testing/gtest/include/gtest/gtest.h"

Expand Down Expand Up @@ -90,27 +92,27 @@ class TestSocketFactory : public ClientSocketFactory {
TestSocketFactory() {}
virtual ~TestSocketFactory() {}

virtual DatagramClientSocket* CreateDatagramClientSocket(
virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType,
const RandIntCallback&,
NetLog*,
const NetLog::Source&) OVERRIDE {
return new TestUDPClientSocket(&mapping_);
return scoped_ptr<DatagramClientSocket>(new TestUDPClientSocket(&mapping_));
}
virtual StreamSocket* CreateTransportClientSocket(
virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList&,
NetLog*,
const NetLog::Source&) OVERRIDE {
NOTIMPLEMENTED();
return NULL;
return scoped_ptr<StreamSocket>();
}
virtual SSLClientSocket* CreateSSLClientSocket(
ClientSocketHandle*,
virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
scoped_ptr<ClientSocketHandle>,
const HostPortPair&,
const SSLConfig&,
const SSLClientSocketContext&) OVERRIDE {
NOTIMPLEMENTED();
return NULL;
return scoped_ptr<SSLClientSocket>();
}
virtual void ClearSSLSessionCache() OVERRIDE {
NOTIMPLEMENTED();
Expand Down
24 changes: 14 additions & 10 deletions net/dns/dns_session_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "net/dns/dns_protocol.h"
#include "net/dns/dns_socket_pool.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/ssl_client_socket.h"
#include "net/socket/stream_socket.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace net {
Expand All @@ -24,26 +26,26 @@ class TestClientSocketFactory : public ClientSocketFactory {
public:
virtual ~TestClientSocketFactory();

virtual DatagramClientSocket* CreateDatagramClientSocket(
virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
net::NetLog* net_log,
const net::NetLog::Source& source) OVERRIDE;

virtual StreamSocket* CreateTransportClientSocket(
virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog*, const NetLog::Source&) OVERRIDE {
NOTIMPLEMENTED();
return NULL;
return scoped_ptr<StreamSocket>();
}

virtual SSLClientSocket* CreateSSLClientSocket(
ClientSocketHandle* transport_socket,
virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE {
NOTIMPLEMENTED();
return NULL;
return scoped_ptr<SSLClientSocket>();
}

virtual void ClearSSLSessionCache() OVERRIDE {
Expand Down Expand Up @@ -179,7 +181,8 @@ bool DnsSessionTest::ExpectEvent(const PoolEvent& expected) {
return true;
}

DatagramClientSocket* TestClientSocketFactory::CreateDatagramClientSocket(
scoped_ptr<DatagramClientSocket>
TestClientSocketFactory::CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
net::NetLog* net_log,
Expand All @@ -188,9 +191,10 @@ DatagramClientSocket* TestClientSocketFactory::CreateDatagramClientSocket(
// simplest SocketDataProvider with no data supplied.
SocketDataProvider* data_provider = new StaticSocketDataProvider();
data_providers_.push_back(data_provider);
MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log);
data_provider->set_socket(socket);
return socket;
scoped_ptr<MockUDPClientSocket> socket(
new MockUDPClientSocket(data_provider, net_log));
data_provider->set_socket(socket.get());
return socket.PassAs<DatagramClientSocket>();
}

TestClientSocketFactory::~TestClientSocketFactory() {
Expand Down
4 changes: 2 additions & 2 deletions net/dns/dns_socket_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket(
scoped_ptr<DatagramClientSocket> socket;

NetLog::Source no_source;
socket.reset(socket_factory_->CreateDatagramClientSocket(
kBindType, base::Bind(&base::RandInt), net_log_, no_source));
socket = socket_factory_->CreateDatagramClientSocket(
kBindType, base::Bind(&base::RandInt), net_log_, no_source);

if (socket.get()) {
int rv = socket->Connect((*nameservers_)[server_index]);
Expand Down
Loading

0 comments on commit 18ccfdb

Please sign in to comment.