Skip to content

Commit

Permalink
Fix bssl client/server's error-handling.
Browse files Browse the repository at this point in the history
Rather than printing the SSL_ERROR_* constants, print the actual error.
This should be a bit more understandable. Debugging this also uncovered
some other issues on Windows:

- We were mixing up C runtime and Winsock errors, which are separate in
  Windows.

- The thread local implementation interferes with WSAGetLastError due to
  a quirk of TlsGetValue. This could affect other Windows consumers.
  (Chromium uses a custom BIO, so it isn't affected.)

- SocketSetNonBlocking also interferes with WSAGetLastError.

- Listen for FD_CLOSE along with FD_READ. Connection close does not
  signal FD_READ. (The select loop only barely works on Windows anyway
  due to issues with stdin and line buffering, but if we take stdin out
  of the equation, FD_CLOSE can be tested.)

Change-Id: If991259915acc96606a314fbe795fe6ea1e295e8
Reviewed-on: https://boringssl-review.googlesource.com/28125
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
  • Loading branch information
davidben authored and CQ bot account: commit-bot@chromium.org committed May 7, 2018
1 parent e30fac6 commit e7ca8a5
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 37 deletions.
23 changes: 23 additions & 0 deletions crypto/err/err_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@

#include "./internal.h"

#if defined(OPENSSL_WINDOWS)
OPENSSL_MSVC_PRAGMA(warning(push, 3))
#include <windows.h>
OPENSSL_MSVC_PRAGMA(warning(pop))
#else
#include <errno.h>
#endif


TEST(ErrTest, Overflow) {
for (unsigned i = 0; i < ERR_NUM_ERRORS*2; i++) {
Expand Down Expand Up @@ -212,3 +220,18 @@ TEST(ErrTest, SaveAndRestore) {
EXPECT_EQ(0u, ERR_get_error());
}
}

// Querying the error queue should not affect the OS error.
#if defined(OPENSSL_WINDOWS)
TEST(ErrTest, PreservesLastError) {
SetLastError(ERROR_INVALID_FUNCTION);
ERR_get_error();
EXPECT_EQ(ERROR_INVALID_FUNCTION, GetLastError());
}
#else
TEST(ErrTest, PreservesErrno) {
errno = EINVAL;
ERR_get_error();
EXPECT_EQ(EINVAL, errno);
}
#endif
22 changes: 20 additions & 2 deletions crypto/thread_win.c
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,31 @@ PIMAGE_TLS_CALLBACK p_thread_callback_boringssl = thread_local_destructor;

#endif // _WIN64

static void **get_thread_locals(void) {
// |TlsGetValue| clears the last error even on success, so that callers may
// distinguish it successfully returning NULL or failing. It is documented to
// never fail if the argument is a valid index from |TlsAlloc|, so we do not
// need to handle this.
//
// However, this error-mangling behavior interferes with the caller's use of
// |GetLastError|. In particular |SSL_get_error| queries the error queue to
// determine whether the caller should look at the OS's errors. To avoid
// destroying state, save and restore the Windows error.
//
// https://msdn.microsoft.com/en-us/library/windows/desktop/ms686812(v=vs.85).aspx
DWORD last_error = GetLastError();
void **ret = TlsGetValue(g_thread_local_key);
SetLastError(last_error);
return ret;
}

void *CRYPTO_get_thread_local(thread_local_data_t index) {
CRYPTO_once(&g_thread_local_init_once, thread_local_init);
if (g_thread_local_failed) {
return NULL;
}

void **pointers = TlsGetValue(g_thread_local_key);
void **pointers = get_thread_locals();
if (pointers == NULL) {
return NULL;
}
Expand All @@ -211,7 +229,7 @@ int CRYPTO_set_thread_local(thread_local_data_t index, void *value,
return 0;
}

void **pointers = TlsGetValue(g_thread_local_key);
void **pointers = get_thread_locals();
if (pointers == NULL) {
pointers = OPENSSL_malloc(sizeof(void *) * NUM_OPENSSL_THREAD_LOCALS);
if (pointers == NULL) {
Expand Down
19 changes: 8 additions & 11 deletions tool/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ static int NewSessionCallback(SSL *ssl, SSL_SESSION *session) {
if (!PEM_write_bio_SSL_SESSION(session_out.get(), session) ||
BIO_flush(session_out.get()) <= 0) {
fprintf(stderr, "Error while saving session:\n");
ERR_print_errors_cb(PrintErrorCallback, stderr);
ERR_print_errors_fp(stderr);
return 0;
}
}
Expand Down Expand Up @@ -221,8 +221,7 @@ static bool WaitForSession(SSL *ssl, int sock) {
if (ssl_err == SSL_ERROR_WANT_READ) {
continue;
}
fprintf(stderr, "Error while reading: %d\n", ssl_err);
ERR_print_errors_cb(PrintErrorCallback, stderr);
PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret);
return false;
}
}
Expand Down Expand Up @@ -267,14 +266,14 @@ static bool DoConnection(SSL_CTX *ctx,
"rb"));
if (!in) {
fprintf(stderr, "Error reading session\n");
ERR_print_errors_cb(PrintErrorCallback, stderr);
ERR_print_errors_fp(stderr);
return false;
}
bssl::UniquePtr<SSL_SESSION> session(PEM_read_bio_SSL_SESSION(in.get(),
nullptr, nullptr, nullptr));
if (!session) {
fprintf(stderr, "Error reading session\n");
ERR_print_errors_cb(PrintErrorCallback, stderr);
ERR_print_errors_fp(stderr);
return false;
}
SSL_set_session(ssl.get(), session.get());
Expand All @@ -294,8 +293,7 @@ static bool DoConnection(SSL_CTX *ctx,
int ret = SSL_connect(ssl.get());
if (ret != 1) {
int ssl_err = SSL_get_error(ssl.get(), ret);
fprintf(stderr, "Error while connecting: %d\n", ssl_err);
ERR_print_errors_cb(PrintErrorCallback, stderr);
PrintSSLError(stderr, "Error while connecting", ssl_err, ret);
return false;
}

Expand All @@ -315,8 +313,7 @@ static bool DoConnection(SSL_CTX *ctx,
int ssl_ret = SSL_write(ssl.get(), early_data.data(), ed_size);
if (ssl_ret <= 0) {
int ssl_err = SSL_get_error(ssl.get(), ssl_ret);
fprintf(stderr, "Error while writing: %d\n", ssl_err);
ERR_print_errors_cb(PrintErrorCallback, stderr);
PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret);
return false;
} else if (ssl_ret != ed_size) {
fprintf(stderr, "Short write from SSL_write.\n");
Expand Down Expand Up @@ -500,7 +497,7 @@ bool Client(const std::vector<std::string> &args) {
if (!session_out) {
fprintf(stderr, "Error while opening %s:\n",
args_map["-session-out"].c_str());
ERR_print_errors_cb(PrintErrorCallback, stderr);
ERR_print_errors_fp(stderr);
return false;
}
}
Expand All @@ -513,7 +510,7 @@ bool Client(const std::vector<std::string> &args) {
if (!SSL_CTX_load_verify_locations(
ctx.get(), args_map["-root-certs"].c_str(), nullptr)) {
fprintf(stderr, "Failed to load root certificates.\n");
ERR_print_errors_cb(PrintErrorCallback, stderr);
ERR_print_errors_fp(stderr);
return false;
}
SSL_CTX_set_verify(ctx.get(), SSL_VERIFY_PEER, nullptr);
Expand Down
6 changes: 2 additions & 4 deletions tool/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,7 @@ static bool HandleWWW(SSL *ssl) {
SSL_read(ssl, request + request_len, sizeof(request) - request_len);
if (ssl_ret <= 0) {
int ssl_err = SSL_get_error(ssl, ssl_ret);
fprintf(stderr, "Error while reading: %d\n", ssl_err);
ERR_print_errors_cb(PrintErrorCallback, stderr);
PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret);
return false;
}
request_len += static_cast<size_t>(ssl_ret);
Expand Down Expand Up @@ -342,8 +341,7 @@ bool Server(const std::vector<std::string> &args) {
int ret = SSL_accept(ssl.get());
if (ret != 1) {
int ssl_err = SSL_get_error(ssl.get(), ret);
fprintf(stderr, "Error while connecting: %d\n", ssl_err);
ERR_print_errors_cb(PrintErrorCallback, stderr);
PrintSSLError(stderr, "Error while connecting", ssl_err, ret);
result = false;
continue;
}
Expand Down
80 changes: 61 additions & 19 deletions tool/transport_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,33 @@ static void SplitHostPort(std::string *out_hostname, std::string *out_port,
}
}

static std::string GetLastSocketErrorString() {
#if defined(OPENSSL_WINDOWS)
int error = WSAGetLastError();
char *buffer;
DWORD len = FormatMessageA(
FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER, 0, error, 0,
reinterpret_cast<char *>(&buffer), 0, nullptr);
if (len == 0) {
char buf[256];
snprintf(buf, sizeof(buf), "unknown error (0x%x)", error);
return buf;
}
std::string ret(buffer, len);
LocalFree(buffer);
return ret;
#else
return strerror(errno);
#endif
}

static void PrintSocketError(const char *function) {
// On Windows, |perror| and |errno| are part of the C runtime, while sockets
// are separate, so we must print errors manually.
std::string error = GetLastSocketErrorString();
fprintf(stderr, "%s: %s\n", function, error.c_str());
}

// Connect sets |*out_sock| to be a socket connected to the destination given
// in |hostname_and_port|, which should be of the form "www.example.com:123".
// It returns true on success and false otherwise.
Expand Down Expand Up @@ -121,7 +148,7 @@ bool Connect(int *out_sock, const std::string &hostname_and_port) {
*out_sock =
socket(result->ai_family, result->ai_socktype, result->ai_protocol);
if (*out_sock < 0) {
perror("socket");
PrintSocketError("socket");
goto out;
}

Expand All @@ -145,7 +172,7 @@ bool Connect(int *out_sock, const std::string &hostname_and_port) {
}

if (connect(*out_sock, result->ai_addr, result->ai_addrlen) != 0) {
perror("connect");
PrintSocketError("connect");
goto out;
}
ok = true;
Expand Down Expand Up @@ -188,18 +215,18 @@ bool Listener::Init(const std::string &port) {

server_sock_ = socket(addr.sin6_family, SOCK_STREAM, 0);
if (server_sock_ < 0) {
perror("socket");
PrintSocketError("socket");
return false;
}

if (setsockopt(server_sock_, SOL_SOCKET, SO_REUSEADDR, (const char *)&enable,
sizeof(enable)) < 0) {
perror("setsockopt");
PrintSocketError("setsockopt");
return false;
}

if (bind(server_sock_, (struct sockaddr *)&addr, sizeof(addr)) != 0) {
perror("connect");
PrintSocketError("connect");
return false;
}

Expand Down Expand Up @@ -350,7 +377,7 @@ static bool SocketSelect(int sock, bool stdin_open, bool *socket_ready,
#else
WSAEVENT socket_handle = WSACreateEvent();
if (socket_handle == WSA_INVALID_EVENT ||
WSAEventSelect(sock, socket_handle, FD_READ) != 0) {
WSAEventSelect(sock, socket_handle, FD_READ | FD_CLOSE) != 0) {
WSACloseEvent(socket_handle);
return false;
}
Expand Down Expand Up @@ -379,11 +406,26 @@ static bool SocketSelect(int sock, bool stdin_open, bool *socket_ready,
#endif
}

// PrintErrorCallback is a callback function from OpenSSL's
// |ERR_print_errors_cb| that writes errors to a given |FILE*|.
int PrintErrorCallback(const char *str, size_t len, void *ctx) {
fwrite(str, len, 1, reinterpret_cast<FILE*>(ctx));
return 1;
void PrintSSLError(FILE *file, const char *msg, int ssl_err, int ret) {
switch (ssl_err) {
case SSL_ERROR_SSL:
fprintf(file, "%s: %s\n", msg, ERR_reason_error_string(ERR_peek_error()));
break;
case SSL_ERROR_SYSCALL:
if (ret == 0) {
fprintf(file, "%s: peer closed connection\n", msg);
} else {
std::string error = GetLastSocketErrorString();
fprintf(file, "%s: %s\n", msg, error.c_str());
}
break;
case SSL_ERROR_ZERO_RETURN:
fprintf(file, "%s: received close_notify\n", msg);
break;
default:
fprintf(file, "%s: unknown error type (%d)\n", msg, ssl_err);
}
ERR_print_errors_fp(file);
}

bool TransferData(SSL *ssl, int sock) {
Expand Down Expand Up @@ -427,19 +469,20 @@ bool TransferData(SSL *ssl, int sock) {
}
#endif
int ssl_ret = SSL_write(ssl, buffer, n);
if (!SocketSetNonBlocking(sock, true)) {
return false;
}

if (ssl_ret <= 0) {
int ssl_err = SSL_get_error(ssl, ssl_ret);
fprintf(stderr, "Error while writing: %d\n", ssl_err);
ERR_print_errors_cb(PrintErrorCallback, stderr);
PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret);
return false;
} else if (ssl_ret != n) {
fprintf(stderr, "Short write from SSL_write.\n");
return false;
}

// Note we handle errors before restoring the non-blocking state. On
// Windows, |SocketSetNonBlocking| internally clears the last error.
if (!SocketSetNonBlocking(sock, true)) {
return false;
}
}

if (socket_ready) {
Expand All @@ -451,8 +494,7 @@ bool TransferData(SSL *ssl, int sock) {
if (ssl_err == SSL_ERROR_WANT_READ) {
continue;
}
fprintf(stderr, "Error while reading: %d\n", ssl_err);
ERR_print_errors_cb(PrintErrorCallback, stderr);
PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret);
return false;
} else if (ssl_ret == 0) {
return true;
Expand Down
5 changes: 4 additions & 1 deletion tool/transport_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ void PrintConnectionInfo(BIO *bio, const SSL *ssl);

bool SocketSetNonBlocking(int sock, bool is_non_blocking);

int PrintErrorCallback(const char *str, size_t len, void *ctx);
// PrintSSLError prints information about the most recent SSL error to stderr.
// |ssl_err| must be the output of |SSL_get_error| and the |SSL| object must be
// connected to socket from |Connect|.
void PrintSSLError(FILE *file, const char *msg, int ssl_err, int ret);

bool TransferData(SSL *ssl, int sock);

Expand Down

0 comments on commit e7ca8a5

Please sign in to comment.