Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/support/errno_handling.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file errno_handling.h
* \brief Common error number handling functions for socket.h and pipe.h
*/
#ifndef TVM_SUPPORT_ERRNO_HANDLING_H_
#define TVM_SUPPORT_ERRNO_HANDLING_H_
#include <errno.h>

#include "ssize.h"

namespace tvm {
namespace support {
/*!
* \brief Call a function and retry if an EINTR error is encountered.
*
* Socket operations can return EINTR when the interrupt handler
* is registered by the execution environment(e.g. python).
* We should retry if there is no KeyboardInterrupt recorded in
* the environment.
*
* \note This function is needed to avoid rare interrupt event
* in long running server code.
*
* \param func The function to retry.
* \return The return code returned by function f or error_value on retry failure.
*/
template <typename FuncType, typename GetErrorCodeFuncType>
inline ssize_t RetryCallOnEINTR(FuncType func, GetErrorCodeFuncType fgeterrorcode) {
ssize_t ret = func();
// common path
if (ret != -1) return ret;
// less common path
do {
if (fgeterrorcode() == EINTR) {
// Call into env check signals to see if there are
// environment specific(e.g. python) signal exceptions.
// This function will throw an exception if there is
// if the process received a signal that requires TVM to return immediately (e.g. SIGINT).
runtime::EnvCheckSignals();
} else {
// other errors
return ret;
}
ret = func();
} while (ret == -1);
return ret;
}
} // namespace support
} // namespace tvm
#endif // TVM_SUPPORT_ERRNO_HANDLING_H_
42 changes: 31 additions & 11 deletions src/support/pipe.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <cstdlib>
#include <cstring>
#endif
#include "errno_handling.h"

namespace tvm {
namespace support {
Expand All @@ -52,8 +53,21 @@ class Pipe : public dmlc::Stream {
#endif
/*! \brief destructor */
~Pipe() { Flush(); }

using Stream::Read;
using Stream::Write;

/*!
* \return last error of pipe operation
*/
static int GetLastErrorCode() {
#ifdef _WIN32
return GetLastError();
#else
return errno;
#endif
}

/*!
* \brief reads data from a file descriptor
* \param ptr pointer to a memory buffer
Expand All @@ -63,12 +77,15 @@ class Pipe : public dmlc::Stream {
size_t Read(void* ptr, size_t size) final {
if (size == 0) return 0;
#ifdef _WIN32
DWORD nread;
ICHECK(ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr))
<< "Read Error: " << GetLastError();
auto fread = [&]() {
DWORD nread;
if (!ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr)) return -1;
return nread;
};
DWORD nread = static_cast<DWORD>(RetryCallOnEINTR(fread, GetLastErrorCode));
ICHECK_EQ(static_cast<size_t>(nread), size) << "Read Error: " << GetLastError();
#else
ssize_t nread;
nread = read(handle_, ptr, size);
ssize_t nread = RetryCallOnEINTR([&]() { return read(handle_, ptr, size); }, GetLastErrorCode);
ICHECK_GE(nread, 0) << "Write Error: " << strerror(errno);
#endif
return static_cast<size_t>(nread);
Expand All @@ -82,13 +99,16 @@ class Pipe : public dmlc::Stream {
void Write(const void* ptr, size_t size) final {
if (size == 0) return;
#ifdef _WIN32
DWORD nwrite;
ICHECK(WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr) &&
static_cast<size_t>(nwrite) == size)
<< "Write Error: " << GetLastError();
auto fwrite = [&]() {
DWORD nwrite;
if (!WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr)) return -1;
return nwrite;
};
DWORD nwrite = static_cast<DWORD>(RetryCallOnEINTR(fwrite, GetLastErrorCode));
ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << GetLastError();
#else
ssize_t nwrite;
nwrite = write(handle_, ptr, size);
ssize_t nwrite =
RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode);
ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << strerror(errno);
#endif
}
Expand Down
65 changes: 17 additions & 48 deletions src/support/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
#endif
#else
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
Expand All @@ -56,8 +55,9 @@
#include <unordered_map>
#include <vector>

#include "../support/ssize.h"
#include "../support/utils.h"
#include "errno_handling.h"
#include "ssize.h"
#include "utils.h"

#if defined(_WIN32)
static inline int poll(struct pollfd* pfd, int nfds, int timeout) {
Expand Down Expand Up @@ -310,7 +310,7 @@ class Socket {
/*!
* \return last error of socket operation
*/
static int GetLastError() {
static int GetLastErrorCode() {
#ifdef _WIN32
return WSAGetLastError();
#else
Expand All @@ -319,7 +319,7 @@ class Socket {
}
/*! \return whether last error was would block */
static bool LastErrorWouldBlock() {
int errsv = GetLastError();
int errsv = GetLastErrorCode();
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
#else
Expand Down Expand Up @@ -355,50 +355,14 @@ class Socket {
* \param msg The error message.
*/
static void Error(const char* msg) {
int errsv = GetLastError();
int errsv = GetLastErrorCode();
#ifdef _WIN32
LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv;
#else
LOG(FATAL) << "Socket " << msg << " Error:" << strerror(errsv);
#endif
}

/*!
* \brief Call a function and retry if an EINTR error is encountered.
*
* Socket operations can return EINTR when the interrupt handler
* is registered by the execution environment(e.g. python).
* We should retry if there is no KeyboardInterrupt recorded in
* the environment.
*
* \note This function is needed to avoid rare interrupt event
* in long running server code.
*
* \param func The function to retry.
* \return The return code returned by function f or error_value on retry failure.
*/
template <typename FuncType>
ssize_t RetryCallOnEINTR(FuncType func) {
ssize_t ret = func();
// common path
if (ret != -1) return ret;
// less common path
do {
if (GetLastError() == EINTR) {
// Call into env check signals to see if there are
// environment specific(e.g. python) signal exceptions.
// This function will throw an exception if there is
// if the process received a signal that requires TVM to return immediately (e.g. SIGINT).
runtime::EnvCheckSignals();
} else {
// other errors
return ret;
}
ret = func();
} while (ret == -1);
return ret;
}

protected:
explicit Socket(SockType sockfd) : sockfd(sockfd) {}
};
Expand Down Expand Up @@ -445,7 +409,8 @@ class TCPSocket : public Socket {
* \return The accepted socket connection.
*/
TCPSocket Accept() {
SockType newfd = RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); });
SockType newfd =
RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); }, GetLastErrorCode);
if (newfd == INVALID_SOCKET) {
Socket::Error("Accept");
}
Expand All @@ -459,7 +424,8 @@ class TCPSocket : public Socket {
TCPSocket Accept(SockAddr* addr) {
socklen_t addrlen = sizeof(addr->addr);
SockType newfd = RetryCallOnEINTR(
[&]() { return accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); });
[&]() { return accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); },
GetLastErrorCode);
if (newfd == INVALID_SOCKET) {
Socket::Error("Accept");
}
Expand Down Expand Up @@ -500,7 +466,7 @@ class TCPSocket : public Socket {
ssize_t Send(const void* buf_, size_t len, int flag = 0) {
const char* buf = reinterpret_cast<const char*>(buf_);
return RetryCallOnEINTR(
[&]() { return send(sockfd, buf, static_cast<sock_size_t>(len), flag); });
[&]() { return send(sockfd, buf, static_cast<sock_size_t>(len), flag); }, GetLastErrorCode);
}
/*!
* \brief receive data using the socket
Expand All @@ -513,7 +479,8 @@ class TCPSocket : public Socket {
ssize_t Recv(void* buf_, size_t len, int flags = 0) {
char* buf = reinterpret_cast<char*>(buf_);
return RetryCallOnEINTR(
[&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); });
[&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); },
GetLastErrorCode);
}
/*!
* \brief perform block write that will attempt to send all data out
Expand All @@ -527,7 +494,8 @@ class TCPSocket : public Socket {
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = RetryCallOnEINTR(
[&]() { return send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); });
[&]() { return send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); },
GetLastErrorCode);
if (ret == -1) {
if (LastErrorWouldBlock()) return ndone;
Socket::Error("SendAll");
Expand All @@ -549,7 +517,8 @@ class TCPSocket : public Socket {
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = RetryCallOnEINTR(
[&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); });
[&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); },
GetLastErrorCode);
if (ret == -1) {
if (LastErrorWouldBlock()) {
LOG(FATAL) << "would block";
Expand Down