Skip to content

Commit

Permalink
Raw SocketDescriptor variant of UnixDomainServerSocket::Accept
Browse files Browse the repository at this point in the history
The Mojo code on CrOS needs to accept inbound connections
on a unix domain socket, and then 'promote' the resulting
sockets to Mojo MessagePipes. This really requires access
to the underying file descriptor, so provide a mechanism
to accept a connection and get back a SocketDescriptor.

BUG=407782
TEST=UnixDomain*SocketTest
R=mmenke@chromium.org

Review URL: https://codereview.chromium.org/509133002

Cr-Commit-Position: refs/heads/master@{#293172}
  • Loading branch information
cmasone authored and Commit bot committed Sep 3, 2014
1 parent 292bdba commit ca100d5
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 38 deletions.
63 changes: 37 additions & 26 deletions net/socket/socket_libevent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ int SocketLibevent::AdoptConnectedSocket(SocketDescriptor socket,
return OK;
}

SocketDescriptor SocketLibevent::ReleaseConnectedSocket() {
StopWatchingAndCleanUp();
SocketDescriptor socket_fd = socket_fd_;
socket_fd_ = kInvalidSocket;
return socket_fd;
}

int SocketLibevent::Bind(const SockaddrStorage& address) {
DCHECK(thread_checker_.CalledOnValidThread());
DCHECK_NE(kInvalidSocket, socket_fd_);
Expand Down Expand Up @@ -326,38 +333,13 @@ bool SocketLibevent::HasPeerAddress() const {
void SocketLibevent::Close() {
DCHECK(thread_checker_.CalledOnValidThread());

bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
ok = read_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
ok = write_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
StopWatchingAndCleanUp();

if (socket_fd_ != kInvalidSocket) {
if (IGNORE_EINTR(close(socket_fd_)) < 0)
PLOG(ERROR) << "close() returned an error, errno=" << errno;
socket_fd_ = kInvalidSocket;
}

if (!accept_callback_.is_null()) {
accept_socket_ = NULL;
accept_callback_.Reset();
}

if (!read_callback_.is_null()) {
read_buf_ = NULL;
read_buf_len_ = 0;
read_callback_.Reset();
}

if (!write_callback_.is_null()) {
write_buf_ = NULL;
write_buf_len_ = 0;
write_callback_.Reset();
}

waiting_connect_ = false;
peer_address_.reset();
}

void SocketLibevent::OnFileCanReadWithoutBlocking(int fd) {
Expand Down Expand Up @@ -468,4 +450,33 @@ void SocketLibevent::WriteCompleted() {
base::ResetAndReturn(&write_callback_).Run(rv);
}

void SocketLibevent::StopWatchingAndCleanUp() {
bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
ok = read_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
ok = write_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);

if (!accept_callback_.is_null()) {
accept_socket_ = NULL;
accept_callback_.Reset();
}

if (!read_callback_.is_null()) {
read_buf_ = NULL;
read_buf_len_ = 0;
read_callback_.Reset();
}

if (!write_callback_.is_null()) {
write_buf_ = NULL;
write_buf_len_ = 0;
write_callback_.Reset();
}

waiting_connect_ = false;
peer_address_.reset();
}

} // namespace net
7 changes: 6 additions & 1 deletion net/socket/socket_libevent.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class IPEndPoint;

// Socket class to provide asynchronous read/write operations on top of the
// posix socket api. It supports AF_INET, AF_INET6, and AF_UNIX addresses.
class SocketLibevent : public base::MessageLoopForIO::Watcher {
class NET_EXPORT_PRIVATE SocketLibevent
: public base::MessageLoopForIO::Watcher {
public:
SocketLibevent();
virtual ~SocketLibevent();
Expand All @@ -34,6 +35,8 @@ class SocketLibevent : public base::MessageLoopForIO::Watcher {
// Takes ownership of |socket|.
int AdoptConnectedSocket(SocketDescriptor socket,
const SockaddrStorage& peer_address);
// Releases ownership of |socket_fd_| to caller.
SocketDescriptor ReleaseConnectedSocket();

int Bind(const SockaddrStorage& address);

Expand Down Expand Up @@ -93,6 +96,8 @@ class SocketLibevent : public base::MessageLoopForIO::Watcher {
int DoWrite(IOBuffer* buf, int buf_len);
void WriteCompleted();

void StopWatchingAndCleanUp();

SocketDescriptor socket_fd_;

base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_;
Expand Down
9 changes: 9 additions & 0 deletions net/socket/unix_domain_client_socket_posix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,13 @@ int UnixDomainClientSocket::SetSendBufferSize(int32 size) {
return ERR_NOT_IMPLEMENTED;
}

SocketDescriptor UnixDomainClientSocket::ReleaseConnectedSocket() {
DCHECK(socket_);
DCHECK(socket_->IsConnected());

SocketDescriptor socket_fd = socket_->ReleaseConnectedSocket();
socket_.reset();
return socket_fd;
}

} // namespace net
6 changes: 6 additions & 0 deletions net/socket/unix_domain_client_socket_posix.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "net/base/completion_callback.h"
#include "net/base/net_export.h"
#include "net/base/net_log.h"
#include "net/socket/socket_descriptor.h"
#include "net/socket/stream_socket.h"

namespace net {
Expand Down Expand Up @@ -63,6 +64,11 @@ class NET_EXPORT UnixDomainClientSocket : public StreamSocket {
virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
virtual int SetSendBufferSize(int32 size) OVERRIDE;

// Releases ownership of underlying SocketDescriptor to caller.
// Internal state is reset so that this object can be used again.
// Socket must be connected in order to release it.
SocketDescriptor ReleaseConnectedSocket();

private:
const std::string socket_path_;
const bool use_abstract_namespace_;
Expand Down
50 changes: 50 additions & 0 deletions net/socket/unix_domain_client_socket_posix_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
#include "base/files/file_path.h"
#include "base/files/scoped_temp_dir.h"
#include "base/memory/scoped_ptr.h"
#include "base/posix/eintr_wrapper.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/base/test_completion_callback.h"
#include "net/socket/socket_libevent.h"
#include "net/socket/unix_domain_server_socket_posix.h"
#include "testing/gtest/include/gtest/gtest.h"

Expand Down Expand Up @@ -148,6 +150,54 @@ TEST_F(UnixDomainClientSocketTest, Connect) {
EXPECT_TRUE(accepted_socket->IsConnected());
}

TEST_F(UnixDomainClientSocketTest, ConnectWithSocketDescriptor) {
const bool kUseAbstractNamespace = false;

UnixDomainServerSocket server_socket(CreateAuthCallback(true),
kUseAbstractNamespace);
EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));

SocketDescriptor accepted_socket_fd = kInvalidSocket;
TestCompletionCallback accept_callback;
EXPECT_EQ(ERR_IO_PENDING,
server_socket.AcceptSocketDescriptor(&accepted_socket_fd,
accept_callback.callback()));
EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
EXPECT_FALSE(client_socket.IsConnected());

EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
EXPECT_TRUE(client_socket.IsConnected());
// Server has not yet been notified of the connection.
EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

EXPECT_EQ(OK, accept_callback.WaitForResult());
EXPECT_NE(kInvalidSocket, accepted_socket_fd);

SocketDescriptor client_socket_fd = client_socket.ReleaseConnectedSocket();
EXPECT_NE(kInvalidSocket, client_socket_fd);

// Now, re-wrap client_socket_fd in a UnixDomainClientSocket and try a read
// to be sure it hasn't gotten accidentally closed.
SockaddrStorage addr;
ASSERT_TRUE(UnixDomainClientSocket::FillAddress(socket_path_, false, &addr));
scoped_ptr<SocketLibevent> adopter(new SocketLibevent);
adopter->AdoptConnectedSocket(client_socket_fd, addr);
UnixDomainClientSocket rewrapped_socket(adopter.Pass());
EXPECT_TRUE(rewrapped_socket.IsConnected());

// Try to read data.
const int kReadDataSize = 10;
scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadDataSize));
TestCompletionCallback read_callback;
EXPECT_EQ(ERR_IO_PENDING,
rewrapped_socket.Read(
read_buffer.get(), kReadDataSize, read_callback.callback()));

EXPECT_EQ(0, IGNORE_EINTR(close(accepted_socket_fd)));
}

TEST_F(UnixDomainClientSocketTest, ConnectWithAbstractNamespace) {
const bool kUseAbstractNamespace = true;

Expand Down
53 changes: 44 additions & 9 deletions net/socket/unix_domain_server_socket_posix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@

namespace net {

namespace {

// Intended for use as SetterCallbacks in Accept() helper methods.
void SetStreamSocket(scoped_ptr<StreamSocket>* socket,
scoped_ptr<SocketLibevent> accepted_socket) {
socket->reset(new UnixDomainClientSocket(accepted_socket.Pass()));
}

void SetSocketDescriptor(SocketDescriptor* socket,
scoped_ptr<SocketLibevent> accepted_socket) {
*socket = accepted_socket->ReleaseConnectedSocket();
}

} // anonymous namespace

UnixDomainServerSocket::UnixDomainServerSocket(
const AuthCallback& auth_callback,
bool use_abstract_namespace)
Expand Down Expand Up @@ -95,6 +110,23 @@ int UnixDomainServerSocket::GetLocalAddress(IPEndPoint* address) const {
int UnixDomainServerSocket::Accept(scoped_ptr<StreamSocket>* socket,
const CompletionCallback& callback) {
DCHECK(socket);

SetterCallback setter_callback = base::Bind(&SetStreamSocket, socket);
return DoAccept(setter_callback, callback);
}

int UnixDomainServerSocket::AcceptSocketDescriptor(
SocketDescriptor* socket,
const CompletionCallback& callback) {
DCHECK(socket);

SetterCallback setter_callback = base::Bind(&SetSocketDescriptor, socket);
return DoAccept(setter_callback, callback);
}

int UnixDomainServerSocket::DoAccept(const SetterCallback& setter_callback,
const CompletionCallback& callback) {
DCHECK(!setter_callback.is_null());
DCHECK(!callback.is_null());
DCHECK(listen_socket_);
DCHECK(!accept_socket_);
Expand All @@ -103,38 +135,41 @@ int UnixDomainServerSocket::Accept(scoped_ptr<StreamSocket>* socket,
int rv = listen_socket_->Accept(
&accept_socket_,
base::Bind(&UnixDomainServerSocket::AcceptCompleted,
base::Unretained(this), socket, callback));
base::Unretained(this),
setter_callback,
callback));
if (rv != OK)
return rv;
if (AuthenticateAndGetStreamSocket(socket))
if (AuthenticateAndGetStreamSocket(setter_callback))
return OK;
// Accept another socket because authentication error should be transparent
// to the caller.
}
}

void UnixDomainServerSocket::AcceptCompleted(scoped_ptr<StreamSocket>* socket,
const CompletionCallback& callback,
int rv) {
void UnixDomainServerSocket::AcceptCompleted(
const SetterCallback& setter_callback,
const CompletionCallback& callback,
int rv) {
if (rv != OK) {
callback.Run(rv);
return;
}

if (AuthenticateAndGetStreamSocket(socket)) {
if (AuthenticateAndGetStreamSocket(setter_callback)) {
callback.Run(OK);
return;
}

// Accept another socket because authentication error should be transparent
// to the caller.
rv = Accept(socket, callback);
rv = DoAccept(setter_callback, callback);
if (rv != ERR_IO_PENDING)
callback.Run(rv);
}

bool UnixDomainServerSocket::AuthenticateAndGetStreamSocket(
scoped_ptr<StreamSocket>* socket) {
const SetterCallback& setter_callback) {
DCHECK(accept_socket_);

Credentials credentials;
Expand All @@ -144,7 +179,7 @@ bool UnixDomainServerSocket::AuthenticateAndGetStreamSocket(
return false;
}

socket->reset(new UnixDomainClientSocket(accept_socket_.Pass()));
setter_callback.Run(accept_socket_.Pass());
return true;
}

Expand Down
16 changes: 14 additions & 2 deletions net/socket/unix_domain_server_socket_posix.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,23 @@ class NET_EXPORT UnixDomainServerSocket : public ServerSocket {
virtual int Accept(scoped_ptr<StreamSocket>* socket,
const CompletionCallback& callback) OVERRIDE;

// Accepts an incoming connection on |listen_socket_|, but passes back
// a raw SocketDescriptor instead of a StreamSocket.
int AcceptSocketDescriptor(SocketDescriptor* socket_descriptor,
const CompletionCallback& callback);

private:
void AcceptCompleted(scoped_ptr<StreamSocket>* socket,
// A callback to wrap the setting of the out-parameter to Accept().
// This allows the internal machinery of that call to be implemented in
// a manner that's agnostic to the caller's desired output.
typedef base::Callback<void(scoped_ptr<SocketLibevent>)> SetterCallback;

int DoAccept(const SetterCallback& setter_callback,
const CompletionCallback& callback);
void AcceptCompleted(const SetterCallback& setter_callback,
const CompletionCallback& callback,
int rv);
bool AuthenticateAndGetStreamSocket(scoped_ptr<StreamSocket>* socket);
bool AuthenticateAndGetStreamSocket(const SetterCallback& setter_callback);

scoped_ptr<SocketLibevent> listen_socket_;
const AuthCallback auth_callback_;
Expand Down

0 comments on commit ca100d5

Please sign in to comment.