Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 49 additions & 21 deletions clickhouse/base/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,53 @@ ssize_t Poll(struct pollfd* fds, int nfds, int timeout) noexcept {
#endif
}

#ifndef INVALID_SOCKET
const SOCKET INVALID_SOCKET = -1;
#endif

void CloseSocket(SOCKET socket) {
if (socket == INVALID_SOCKET)
return;

#if defined(_win_)
closesocket(socket);
#else
close(socket);
#endif
}

struct SocketRAIIWrapper {
SOCKET socket = INVALID_SOCKET;

~SocketRAIIWrapper() {
CloseSocket(socket);
}

SOCKET operator*() const {
return socket;
}

SOCKET release() {
auto result = socket;
socket = INVALID_SOCKET;

return result;
}
};

SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params) {
int last_err = 0;
for (auto res = addr.Info(); res != nullptr; res = res->ai_next) {
SOCKET s(socket(res->ai_family, res->ai_socktype, res->ai_protocol));
SocketRAIIWrapper s{socket(res->ai_family, res->ai_socktype, res->ai_protocol)};

if (s == -1) {
if (*s == INVALID_SOCKET) {
continue;
}

SetNonBlock(s, true);
SetTimeout(s, timeout_params);
SetNonBlock(*s, true);
SetTimeout(*s, timeout_params);

if (connect(s, res->ai_addr, (int)res->ai_addrlen) != 0) {
if (connect(*s, res->ai_addr, (int)res->ai_addrlen) != 0) {
int err = getSocketErrorCode();
if (
err == EINPROGRESS || err == EAGAIN || err == EWOULDBLOCK
Expand All @@ -165,7 +199,7 @@ SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& time
#endif
) {
pollfd fd;
fd.fd = s;
fd.fd = *s;
fd.events = POLLOUT;
fd.revents = 0;
ssize_t rval = Poll(&fd, 1, 5000);
Expand All @@ -175,18 +209,18 @@ SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& time
}
if (rval > 0) {
socklen_t len = sizeof(err);
getsockopt(s, SOL_SOCKET, SO_ERROR, (char*)&err, &len);
getsockopt(*s, SOL_SOCKET, SO_ERROR, (char*)&err, &len);

if (!err) {
SetNonBlock(s, false);
return s;
SetNonBlock(*s, false);
return s.release();
}
last_err = err;
}
}
} else {
SetNonBlock(s, false);
return s;
SetNonBlock(*s, false);
return s.release();
}
}
if (last_err > 0) {
Expand Down Expand Up @@ -265,15 +299,15 @@ Socket::Socket(const NetworkAddress & addr)
Socket::Socket(Socket&& other) noexcept
: handle_(other.handle_)
{
other.handle_ = -1;
other.handle_ = INVALID_SOCKET;
}

Socket& Socket::operator=(Socket&& other) noexcept {
if (this != &other) {
Close();

handle_ = other.handle_;
other.handle_ = -1;
other.handle_ = INVALID_SOCKET;
}

return *this;
Expand All @@ -284,14 +318,8 @@ Socket::~Socket() {
}

void Socket::Close() {
if (handle_ != -1) {
#if defined(_win_)
closesocket(handle_);
#else
close(handle_);
#endif
handle_ = -1;
}
CloseSocket(handle_);
handle_ = INVALID_SOCKET;
}

void Socket::SetTcpKeepAlive(int idle, int intvl, int cnt) noexcept {
Expand Down