Skip to content

Commit 1ca3726

Browse files
committed
Added socket::set_blocking(bool)
NOTE: May need different implementation on Windows(?)
1 parent 3b4f65b commit 1ca3726

File tree

5 files changed

+156
-89
lines changed

5 files changed

+156
-89
lines changed

include/sockpp/platform.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
#include <unistd.h>
9494
#include <sys/socket.h>
9595
#include <arpa/inet.h>
96+
#include <fcntl.h>
9697
#include <netdb.h>
9798
#include <signal.h>
9899
#include <errno.h>

include/sockpp/socket.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,10 @@ class socket
337337
* @return The address of the remote peer, if this socket is connected.
338338
*/
339339
sock_address_any peer_address() const;
340+
/**
341+
* Puts the socket into nonblocking (false) or blocking (true) I/O mode.
342+
*/
343+
virtual bool set_blocking(bool blocking);
340344
/**
341345
* Gets the value of a socket option.
342346
*

include/sockpp/tls_socket.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,21 +100,17 @@ namespace sockpp {
100100
// I/O primitives must be reimplemented in subclasses:
101101

102102
virtual ssize_t read(void *buf, size_t n) override = 0;
103-
104103
virtual ioresult read_r(void *buf, size_t n) override = 0;
105-
106104
virtual bool read_timeout(const std::chrono::microseconds& to) override = 0;
107-
108105
virtual ssize_t write(const void *buf, size_t n) override = 0;
109-
110106
virtual ioresult write_r(const void *buf, size_t n) override = 0;
111-
112107
virtual bool write_timeout(const std::chrono::microseconds& to) override = 0;
113108

114109
virtual ssize_t write(const std::vector<iovec> &ranges) override {
115110
return ranges.empty() ? 0 : write(ranges[0].iov_base, ranges[0].iov_len);
116111
}
117112

113+
virtual bool set_blocking(bool blocking) override = 0;
118114

119115
virtual void close() override {
120116
if (stream_) {

src/mbedtls_context.cpp

Lines changed: 132 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171

7272
// TODO: Better logging(?)
73-
#define log(FMT,...) fprintf(stderr, FMT "\n", ## __VA_ARGS__)
73+
#define log(FMT,...) fprintf(stderr, "TLS: " FMT "\n", ## __VA_ARGS__)
7474

7575

7676
namespace sockpp {
@@ -123,30 +123,32 @@ namespace sockpp {
123123
return;
124124
}
125125

126-
if (check_mbed_ret(mbedtls_ssl_setup(&ssl_, context_.ssl_config_.get()),
126+
if (check_mbed_setup(mbedtls_ssl_setup(&ssl_, context_.ssl_config_.get()),
127127
"mbedtls_ssl_setup"))
128128
return;
129-
if (!hostname.empty() && check_mbed_ret(mbedtls_ssl_set_hostname(&ssl_, hostname.c_str()),
129+
if (!hostname.empty() && check_mbed_setup(mbedtls_ssl_set_hostname(&ssl_, hostname.c_str()),
130130
"mbedtls_ssl_set_hostname"))
131131
return;
132132

133-
mbedtls_ssl_set_bio(&ssl_, this,
134-
[](void *ctx, const uint8_t *buf, size_t len) {
135-
return ((mbedtls_socket*)ctx)->bio_send(buf, len); },
136-
nullptr,
137-
[](void *ctx, uint8_t *buf, size_t len, uint32_t timeout) {
138-
return ((mbedtls_socket*)ctx)->bio_recv_timeout(buf, len, timeout); });
139-
open_ = true;
133+
#if defined(_WIN32)
134+
// Winsock does not allow us to tell if a socket is nonblocking, so assume it isn't
135+
bool blocking = true;
136+
#else
137+
int flags = fcntl(stream().handle(), F_GETFL, 0);
138+
bool blocking = (flags < 0 || (flags & O_NONBLOCK) == 0);
139+
#endif
140+
setup_bio(blocking);
140141

141142
// Run the TLS handshake:
142143
int status;
143144
do {
145+
open_ = true; // temporarily, so BIO methods won't fail
144146
status = mbedtls_ssl_handshake(&ssl_);
147+
open_ = false;
145148
} while (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE
146149
|| status == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS);
147-
if (check_mbed_ret(status, "mbedtls_ssl_handshake") != 0) {
150+
if (check_mbed_setup(status, "mbedtls_ssl_handshake") != 0)
148151
return;
149-
}
150152

151153
uint32_t verify_flags = mbedtls_ssl_get_verify_result(&ssl_);
152154
if (verify_flags != 0 && verify_flags != uint32_t(-1)
@@ -155,10 +157,25 @@ namespace sockpp {
155157
mbedtls_x509_crt_verify_info(vrfy_buf, sizeof( vrfy_buf ), "", verify_flags);
156158
log("Cert verify failed: %s", vrfy_buf );
157159
clear(MBEDTLS_ERR_X509_CERT_VERIFY_FAILED);
158-
open_ = false;
159160
reset();
160161
return;
161162
}
163+
open_ = true;
164+
}
165+
166+
167+
void setup_bio(bool blocking) {
168+
mbedtls_ssl_send_t *f_send = [](void *ctx, const uint8_t *buf, size_t len) {
169+
return ((mbedtls_socket*)ctx)->bio_send(buf, len); };
170+
mbedtls_ssl_recv_t *f_recv = nullptr;
171+
mbedtls_ssl_recv_timeout_t *f_recv_timeout = nullptr;
172+
if (blocking)
173+
f_recv_timeout = [](void *ctx, uint8_t *buf, size_t len, uint32_t timeout) {
174+
return ((mbedtls_socket*)ctx)->bio_recv_timeout(buf, len, timeout); };
175+
else
176+
f_recv = [](void *ctx, uint8_t *buf, size_t len) {
177+
return ((mbedtls_socket*)ctx)->bio_recv(buf, len); };
178+
mbedtls_ssl_set_bio(&ssl_, this, f_send, f_recv, f_recv_timeout);
162179
}
163180

164181

@@ -169,18 +186,18 @@ namespace sockpp {
169186
}
170187

171188

172-
int check_mbed_ret(int ret, const char *fn) {
173-
if (ret != 0) {
174-
log_mbed_ret(ret, fn);
175-
clear(ret); // sets last_error
176-
reset(); // marks me as closed/invalid
177-
stream().close();
189+
virtual void close() override {
190+
if (open_) {
191+
mbedtls_ssl_close_notify(&ssl_);
178192
open_ = false;
179193
}
180-
return ret;
194+
tls_socket::close();
181195
}
182196

183197

198+
// -------- certificate / trust API
199+
200+
184201
uint32_t peer_certificate_status() override {
185202
return mbedtls_ssl_get_verify_result(&ssl_);
186203
}
@@ -218,40 +235,6 @@ namespace sockpp {
218235
// -------- stream_socket I/O
219236

220237

221-
static int translate_mbed_err(int mbedErr) {
222-
switch (mbedErr) {
223-
case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
224-
return 0;
225-
case MBEDTLS_ERR_SSL_WANT_READ:
226-
case MBEDTLS_ERR_SSL_WANT_WRITE:
227-
return EWOULDBLOCK;
228-
case MBEDTLS_ERR_NET_CONN_RESET:
229-
return ECONNRESET;
230-
case MBEDTLS_ERR_NET_RECV_FAILED:
231-
case MBEDTLS_ERR_NET_SEND_FAILED:
232-
return EIO;
233-
default:
234-
return mbedErr;
235-
}
236-
}
237-
238-
239-
inline ssize_t check_mbed_io(int mbedResult) {
240-
int result = translate_mbed_err(mbedResult);
241-
if (result < 0) {
242-
clear(result); // sets last_error
243-
result = -1;
244-
}
245-
return result;
246-
}
247-
248-
249-
static inline ioresult ioresult_from_mbed(int mbedResult) {
250-
mbedResult = translate_mbed_err(mbedResult);
251-
return mbedResult < 0 ? ioresult(0, mbedResult) : ioresult(mbedResult, 0);
252-
}
253-
254-
255238
ssize_t read(void *buf, size_t length) override {
256239
return check_mbed_io( mbedtls_ssl_read(&ssl_, (uint8_t*)buf, length) );
257240
}
@@ -268,11 +251,15 @@ namespace sockpp {
268251

269252

270253
ssize_t write(const void *buf, size_t length) override {
254+
if (length == 0)
255+
return 0;
271256
return check_mbed_io( mbedtls_ssl_write(&ssl_, (const uint8_t*)buf, length) );
272257
}
273258

274259

275260
ioresult write_r(const void *buf, size_t length) override {
261+
if (length == 0)
262+
return {};
276263
return ioresult_from_mbed( mbedtls_ssl_write(&ssl_, (const uint8_t*)buf, length) );
277264
}
278265

@@ -282,62 +269,123 @@ namespace sockpp {
282269
}
283270

284271

285-
// -------- mbedTLS BIO callbacks
272+
bool set_blocking(bool blocking) override {
273+
bool ok = stream().set_blocking(blocking);
274+
if (ok)
275+
setup_bio(blocking);
276+
return ok;
277+
}
286278

287279

288-
template <bool reading>
289-
static int bio_return_value(ioresult result) {
290-
if (result.count >= 0)
291-
return (int)result.count;
292-
switch (result.error) {
293-
case EPIPE:
294-
case ECONNRESET:
295-
return MBEDTLS_ERR_NET_CONN_RESET;
296-
case EINTR:
297-
#if defined(EAGAIN)
298-
case EAGAIN:
299-
#endif
300-
#if defined(EWOULDBLOCK) && EWOULDBLOCK != EAGAIN
301-
case EWOULDBLOCK:
302-
#endif
303-
return reading ? MBEDTLS_ERR_SSL_WANT_READ
304-
: MBEDTLS_ERR_SSL_WANT_WRITE;
305-
default:
306-
return reading ? MBEDTLS_ERR_NET_RECV_FAILED
307-
: MBEDTLS_ERR_NET_SEND_FAILED;
308-
}
309-
}
280+
// -------- mbedTLS BIO callbacks
310281

311282

312283
int bio_send(const void* buf, size_t length) {
313284
if (!open_)
314-
return 0;
285+
return MBEDTLS_ERR_NET_CONN_RESET;
315286
return bio_return_value<false>(stream().write_r(buf, length));
316287
}
317288

318289

290+
int bio_recv(void* buf, size_t length) {
291+
if (!open_)
292+
return MBEDTLS_ERR_NET_CONN_RESET;
293+
return bio_return_value<true>(stream().read_r(buf, length));
294+
}
295+
296+
319297
int bio_recv_timeout(void* buf, size_t length, uint32_t timeout) {
320298
if (!open_)
321-
return 0;
299+
return MBEDTLS_ERR_NET_CONN_RESET;
322300
if (timeout > 0)
323301
stream().read_timeout(chrono::milliseconds(timeout));
324302

325-
int n = bio_return_value<true>(stream().read_r(buf, length));
303+
int n = bio_recv(buf, length);
326304

327305
if (timeout > 0)
328306
stream().read_timeout(chrono::hours(1000)); //FIXME: How do I turn off a timeout?
329307
return n;
330308
}
331309

332310

333-
virtual void close() override {
334-
if (open_) {
335-
mbedtls_ssl_close_notify(&ssl_);
311+
// -------- error handling
312+
313+
314+
// Translates mbedTLS error code to POSIX (errno)
315+
static int translate_mbed_err(int mbedErr) {
316+
switch (mbedErr) {
317+
case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
318+
return 0;
319+
case MBEDTLS_ERR_SSL_WANT_READ:
320+
case MBEDTLS_ERR_SSL_WANT_WRITE:
321+
log(">>> mbedtls_socket returning EWOULDBLOCK");
322+
return EWOULDBLOCK;
323+
case MBEDTLS_ERR_NET_CONN_RESET:
324+
return ECONNRESET;
325+
case MBEDTLS_ERR_NET_RECV_FAILED:
326+
case MBEDTLS_ERR_NET_SEND_FAILED:
327+
return EIO;
328+
default:
329+
return mbedErr;
330+
}
331+
}
332+
333+
334+
// Handles an mbedTLS error return value during setup, closing me on error
335+
int check_mbed_setup(int ret, const char *fn) {
336+
if (ret != 0) {
337+
log_mbed_ret(ret, fn);
338+
clear(ret); // sets last_error
339+
reset(); // marks me as closed/invalid
340+
stream().close();
336341
open_ = false;
337342
}
338-
tls_socket::close();
343+
return ret;
339344
}
340345

346+
347+
// Handles an mbedTLS read/write return value, storing any error in last_error
348+
inline ssize_t check_mbed_io(int mbedResult) {
349+
int result = translate_mbed_err(mbedResult);
350+
if (result < 0) {
351+
clear(result); // sets last_error
352+
result = -1;
353+
}
354+
return result;
355+
}
356+
357+
358+
// Handles an mbedTLS read/write return value, converting it to an ioresult.
359+
static inline ioresult ioresult_from_mbed(int mbedResult) {
360+
mbedResult = translate_mbed_err(mbedResult);
361+
return mbedResult < 0 ? ioresult(0, mbedResult) : ioresult(mbedResult, 0);
362+
}
363+
364+
365+
// Translates ioresult to an mbedTLS error code to return from a BIO function.
366+
template <bool reading>
367+
static int bio_return_value(ioresult result) {
368+
if (result.count >= 0)
369+
return (int)result.count;
370+
switch (result.error) {
371+
case EPIPE:
372+
case ECONNRESET:
373+
return MBEDTLS_ERR_NET_CONN_RESET;
374+
case EINTR:
375+
case EWOULDBLOCK:
376+
#if defined(EAGAIN) && EAGAIN != EWOULDBLOCK // these are usually synonyms
377+
case EAGAIN:
378+
#endif
379+
log(">>> BIO returning MBEDTLS_ERR_SSL_WANT_%s", reading ?"READ":"WRITE");
380+
return reading ? MBEDTLS_ERR_SSL_WANT_READ
381+
: MBEDTLS_ERR_SSL_WANT_WRITE;
382+
default:
383+
return reading ? MBEDTLS_ERR_NET_RECV_FAILED
384+
: MBEDTLS_ERR_NET_SEND_FAILED;
385+
}
386+
}
387+
388+
341389
};
342390

343391

src/socket.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,24 @@ sock_address_any socket::peer_address() const
194194
return sock_address_any(addrStore, len);
195195
}
196196

197+
// --------------------------------------------------------------------------
198+
// Puts the socket into nonblocking or blocking I/O mode.
199+
200+
bool socket::set_blocking(bool blocking) {
201+
#if defined(_WIN32)
202+
u_long mode = !blocking;
203+
return check_ret_bool(::ioctlsocket(handle_, FIONBIO, &mode));
204+
#else
205+
int flags = check_ret(::fcntl(handle_, F_GETFL, 0));
206+
if (flags < 0)
207+
return false;
208+
int newFlags = blocking ? (flags & ~O_NONBLOCK) : (flags | O_NONBLOCK);
209+
if (newFlags == flags)
210+
return true;
211+
return check_ret_bool(::fcntl(handle_, F_SETFL, newFlags));
212+
#endif
213+
}
214+
197215
// --------------------------------------------------------------------------
198216

199217
bool socket::get_option(int level, int optname, void* optval, socklen_t* optlen) const

0 commit comments

Comments
 (0)