7070
7171
7272//  TODO: Better logging(?)
73- #define  log (FMT,...) fprintf(stderr, FMT " \n " 
73+ #define  log (FMT,...) fprintf(stderr, " TLS:  "   FMT " \n " 
7474
7575
7676namespace  sockpp  {
@@ -123,30 +123,32 @@ namespace sockpp {
123123                return ;
124124            }
125125
126-             if  (check_mbed_ret (mbedtls_ssl_setup (&ssl_, context_.ssl_config_ .get ()),
126+             if  (check_mbed_setup (mbedtls_ssl_setup (&ssl_, context_.ssl_config_ .get ()),
127127                               " mbedtls_ssl_setup" 
128128                return ;
129-             if  (!hostname.empty () && check_mbed_ret (mbedtls_ssl_set_hostname (&ssl_, hostname.c_str ()),
129+             if  (!hostname.empty () && check_mbed_setup (mbedtls_ssl_set_hostname (&ssl_, hostname.c_str ()),
130130                                                    " mbedtls_ssl_set_hostname" 
131131                return ;
132132
133-             mbedtls_ssl_set_bio (&ssl_, this ,
134-                                 [](void  *ctx, const  uint8_t  *buf, size_t  len) {
135-                                     return  ((mbedtls_socket*)ctx)->bio_send (buf, len); },
136-                                 nullptr ,
137-                                 [](void  *ctx, uint8_t  *buf, size_t  len, uint32_t  timeout) {
138-                                     return  ((mbedtls_socket*)ctx)->bio_recv_timeout (buf, len, timeout); });
139-             open_ = true ;
133+ #if  defined(_WIN32)
134+             //  Winsock does not allow us to tell if a socket is nonblocking, so assume it isn't
135+             bool  blocking = true ;
136+ #else 
137+             int  flags = fcntl (stream ().handle (), F_GETFL, 0 );
138+             bool  blocking = (flags < 0  || (flags & O_NONBLOCK) == 0 );
139+ #endif 
140+             setup_bio (blocking);
140141
141142            //  Run the TLS handshake:
142143            int  status;
143144            do  {
145+                 open_ = true ;  //  temporarily, so BIO methods won't fail
144146                status = mbedtls_ssl_handshake (&ssl_);
147+                 open_ = false ;
145148            } while  (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE
146149                            || status == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS);
147-             if  (check_mbed_ret (status, " mbedtls_ssl_handshake" 0 ) { 
150+             if  (check_mbed_setup (status, " mbedtls_ssl_handshake" 0 )
148151                return ;
149-             }
150152
151153            uint32_t  verify_flags = mbedtls_ssl_get_verify_result (&ssl_);
152154            if  (verify_flags != 0  && verify_flags != uint32_t (-1 )
@@ -155,10 +157,25 @@ namespace sockpp {
155157                mbedtls_x509_crt_verify_info (vrfy_buf, sizeof ( vrfy_buf ), " " 
156158                log (" Cert verify failed: %s" 
157159                clear (MBEDTLS_ERR_X509_CERT_VERIFY_FAILED);
158-                 open_ = false ;
159160                reset ();
160161                return ;
161162            }
163+             open_ = true ;
164+         }
165+ 
166+ 
167+         void  setup_bio (bool  blocking) {
168+             mbedtls_ssl_send_t  *f_send = [](void  *ctx, const  uint8_t  *buf, size_t  len) {
169+                 return  ((mbedtls_socket*)ctx)->bio_send (buf, len); };
170+             mbedtls_ssl_recv_t  *f_recv = nullptr ;
171+             mbedtls_ssl_recv_timeout_t  *f_recv_timeout = nullptr ;
172+             if  (blocking)
173+                 f_recv_timeout = [](void  *ctx, uint8_t  *buf, size_t  len, uint32_t  timeout) {
174+                     return  ((mbedtls_socket*)ctx)->bio_recv_timeout (buf, len, timeout); };
175+             else 
176+                 f_recv = [](void  *ctx, uint8_t  *buf, size_t  len) {
177+                     return  ((mbedtls_socket*)ctx)->bio_recv (buf, len); };
178+             mbedtls_ssl_set_bio (&ssl_, this , f_send, f_recv, f_recv_timeout);
162179        }
163180
164181
@@ -169,18 +186,18 @@ namespace sockpp {
169186        }
170187
171188
172-         int  check_mbed_ret (int  ret, const  char  *fn) {
173-             if  (ret != 0 ) {
174-                 log_mbed_ret (ret, fn);
175-                 clear (ret); //  sets last_error
176-                 reset (); //  marks me as closed/invalid
177-                 stream ().close ();
189+         virtual  void  close () override  {
190+             if  (open_) {
191+                 mbedtls_ssl_close_notify (&ssl_);
178192                open_ = false ;
179193            }
180-             return  ret ;
194+             tls_socket::close () ;
181195        }
182196
183197
198+         //  -------- certificate / trust API
199+ 
200+ 
184201        uint32_t  peer_certificate_status () override  {
185202            return  mbedtls_ssl_get_verify_result (&ssl_);
186203        }
@@ -218,40 +235,6 @@ namespace sockpp {
218235        //  -------- stream_socket I/O
219236
220237
221-         static  int  translate_mbed_err (int  mbedErr) {
222-             switch  (mbedErr) {
223-                 case  MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
224-                     return  0 ;
225-                 case  MBEDTLS_ERR_SSL_WANT_READ:
226-                 case  MBEDTLS_ERR_SSL_WANT_WRITE:
227-                     return  EWOULDBLOCK;
228-                 case  MBEDTLS_ERR_NET_CONN_RESET:
229-                     return  ECONNRESET;
230-                 case  MBEDTLS_ERR_NET_RECV_FAILED:
231-                 case  MBEDTLS_ERR_NET_SEND_FAILED:
232-                     return  EIO;
233-                 default :
234-                     return  mbedErr;
235-             }
236-         }
237- 
238- 
239-         inline  ssize_t  check_mbed_io (int  mbedResult) {
240-             int  result = translate_mbed_err (mbedResult);
241-             if  (result < 0 ) {
242-                 clear (result);     //  sets last_error
243-                 result = -1 ;
244-             }
245-             return  result;
246-         }
247- 
248- 
249-         static  inline  ioresult ioresult_from_mbed (int  mbedResult) {
250-             mbedResult = translate_mbed_err (mbedResult);
251-             return  mbedResult < 0  ? ioresult (0 , mbedResult) : ioresult (mbedResult, 0 );
252-         }
253- 
254- 
255238        ssize_t  read (void  *buf, size_t  length) override  {
256239            return  check_mbed_io ( mbedtls_ssl_read (&ssl_, (uint8_t *)buf, length) );
257240        }
@@ -268,11 +251,15 @@ namespace sockpp {
268251
269252
270253        ssize_t  write (const  void  *buf, size_t  length) override  {
254+             if  (length == 0 )
255+                 return  0 ;
271256            return  check_mbed_io ( mbedtls_ssl_write (&ssl_, (const  uint8_t *)buf, length) );
272257        }
273258
274259
275260        ioresult write_r (const  void  *buf, size_t  length) override  {
261+             if  (length == 0 )
262+                 return  {};
276263            return  ioresult_from_mbed ( mbedtls_ssl_write (&ssl_, (const  uint8_t *)buf, length) );
277264        }
278265
@@ -282,62 +269,123 @@ namespace sockpp {
282269        }
283270
284271
285-         //  -------- mbedTLS BIO callbacks
272+         bool  set_blocking (bool  blocking) override  {
273+             bool  ok = stream ().set_blocking (blocking);
274+             if  (ok)
275+                 setup_bio (blocking);
276+             return  ok;
277+         }
286278
287279
288-         template  <bool  reading>
289-         static  int  bio_return_value (ioresult result) {
290-             if  (result.count  >= 0 )
291-                 return  (int )result.count ;
292-             switch  (result.error ) {
293-                 case  EPIPE:
294-                 case  ECONNRESET:
295-                     return  MBEDTLS_ERR_NET_CONN_RESET;
296-                 case  EINTR:
297- #if  defined(EAGAIN)
298-                 case  EAGAIN:
299- #endif 
300- #if  defined(EWOULDBLOCK) && EWOULDBLOCK != EAGAIN
301-                 case  EWOULDBLOCK:
302- #endif 
303-                     return  reading ? MBEDTLS_ERR_SSL_WANT_READ
304-                                    : MBEDTLS_ERR_SSL_WANT_WRITE;
305-                 default :
306-                     return  reading ? MBEDTLS_ERR_NET_RECV_FAILED
307-                                    : MBEDTLS_ERR_NET_SEND_FAILED;
308-             }
309-         }
280+         //  -------- mbedTLS BIO callbacks
310281
311282
312283        int  bio_send (const  void * buf, size_t  length) {
313284            if  (!open_)
314-                 return  0 ;
285+                 return  MBEDTLS_ERR_NET_CONN_RESET ;
315286            return  bio_return_value<false >(stream ().write_r (buf, length));
316287        }
317288
318289
290+         int  bio_recv (void * buf, size_t  length) {
291+             if  (!open_)
292+                 return  MBEDTLS_ERR_NET_CONN_RESET;
293+             return  bio_return_value<true >(stream ().read_r (buf, length));
294+         }
295+ 
296+ 
319297        int  bio_recv_timeout (void * buf, size_t  length, uint32_t  timeout) {
320298            if  (!open_)
321-                 return  0 ;
299+                 return  MBEDTLS_ERR_NET_CONN_RESET ;
322300            if  (timeout > 0 )
323301                stream ().read_timeout (chrono::milliseconds (timeout));
324302
325-             int  n = bio_return_value< true >( stream (). read_r ( buf, length) );
303+             int  n = bio_recv ( buf, length);
326304
327305            if  (timeout > 0 )
328306                stream ().read_timeout (chrono::hours (1000 ));   // FIXME: How do I turn off a timeout?
329307            return  n;
330308        }
331309
332310
333-         virtual  void  close () override  {
334-             if  (open_) {
335-                 mbedtls_ssl_close_notify (&ssl_);
311+         //  -------- error handling
312+ 
313+ 
314+         //  Translates mbedTLS error code to POSIX (errno)
315+         static  int  translate_mbed_err (int  mbedErr) {
316+             switch  (mbedErr) {
317+                 case  MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
318+                     return  0 ;
319+                 case  MBEDTLS_ERR_SSL_WANT_READ:
320+                 case  MBEDTLS_ERR_SSL_WANT_WRITE:
321+                     log (" >>> mbedtls_socket returning EWOULDBLOCK" 
322+                     return  EWOULDBLOCK;
323+                 case  MBEDTLS_ERR_NET_CONN_RESET:
324+                     return  ECONNRESET;
325+                 case  MBEDTLS_ERR_NET_RECV_FAILED:
326+                 case  MBEDTLS_ERR_NET_SEND_FAILED:
327+                     return  EIO;
328+                 default :
329+                     return  mbedErr;
330+             }
331+         }
332+ 
333+ 
334+         //  Handles an mbedTLS error return value during setup, closing me on error
335+         int  check_mbed_setup (int  ret, const  char  *fn) {
336+             if  (ret != 0 ) {
337+                 log_mbed_ret (ret, fn);
338+                 clear (ret); //  sets last_error
339+                 reset (); //  marks me as closed/invalid
340+                 stream ().close ();
336341                open_ = false ;
337342            }
338-             tls_socket::close () ;
343+             return  ret ;
339344        }
340345
346+ 
347+         //  Handles an mbedTLS read/write return value, storing any error in last_error
348+         inline  ssize_t  check_mbed_io (int  mbedResult) {
349+             int  result = translate_mbed_err (mbedResult);
350+             if  (result < 0 ) {
351+                 clear (result);     //  sets last_error
352+                 result = -1 ;
353+             }
354+             return  result;
355+         }
356+ 
357+ 
358+         //  Handles an mbedTLS read/write return value, converting it to an ioresult.
359+         static  inline  ioresult ioresult_from_mbed (int  mbedResult) {
360+             mbedResult = translate_mbed_err (mbedResult);
361+             return  mbedResult < 0  ? ioresult (0 , mbedResult) : ioresult (mbedResult, 0 );
362+         }
363+ 
364+ 
365+         //  Translates ioresult to an mbedTLS error code to return from a BIO function.
366+         template  <bool  reading>
367+         static  int  bio_return_value (ioresult result) {
368+             if  (result.count  >= 0 )
369+                 return  (int )result.count ;
370+             switch  (result.error ) {
371+                 case  EPIPE:
372+                 case  ECONNRESET:
373+                     return  MBEDTLS_ERR_NET_CONN_RESET;
374+                 case  EINTR:
375+                 case  EWOULDBLOCK:
376+ #if  defined(EAGAIN) && EAGAIN != EWOULDBLOCK    //  these are usually synonyms
377+                 case  EAGAIN:
378+ #endif 
379+                     log (" >>> BIO returning MBEDTLS_ERR_SSL_WANT_%s" " READ" " WRITE" 
380+                     return  reading ? MBEDTLS_ERR_SSL_WANT_READ
381+                                    : MBEDTLS_ERR_SSL_WANT_WRITE;
382+                 default :
383+                     return  reading ? MBEDTLS_ERR_NET_RECV_FAILED
384+                                    : MBEDTLS_ERR_NET_SEND_FAILED;
385+             }
386+         }
387+ 
388+ 
341389    };
342390
343391
0 commit comments