8
8
#include "lwip/netdb.h"
9
9
#include "lwip/dns.h"
10
10
11
+ #include "mbedtls/platform.h"
12
+ #include "mbedtls/net_sockets.h"
13
+ #include "mbedtls/esp_debug.h"
14
+ #include "mbedtls/ssl.h"
15
+ #include "mbedtls/entropy.h"
16
+ #include "mbedtls/ctr_drbg.h"
17
+ #include "mbedtls/error.h"
18
+ #include "mbedtls/certs.h"
19
+ #include "esp_crt_bundle.h"
20
+
11
21
#define LOG (fmt , ...) DMESG("SOCK: " fmt, ##__VA_ARGS__)
12
22
#if 1
13
23
#define LOGV (...) ((void)0)
19
29
20
30
static xQueueHandle sock_cmds ;
21
31
static xQueueHandle sock_events ;
32
+ static uint8_t sockbuf [128 ];
33
+
34
+ typedef struct {
35
+ mbedtls_ssl_context ssl ;
36
+ mbedtls_ctr_drbg_context ctr_drbg ; // rng
37
+ mbedtls_ssl_config conf ;
38
+ mbedtls_entropy_context entropy ;
39
+ mbedtls_net_context server_fd ;
40
+ bool is_tls ;
41
+ bool is_connected ;
42
+
43
+ #if 0
44
+ mbedtls_x509_crt cacert ;
45
+ mbedtls_x509_crt * cacert_ptr ;
46
+ mbedtls_x509_crt clientcert ;
47
+ mbedtls_pk_context clientkey ;
48
+ #endif
49
+ } sock_state_t ;
50
+
51
+ static sock_state_t _tls ;
22
52
23
53
typedef struct {
24
54
unsigned ev ;
@@ -49,6 +79,10 @@ static void push_event(unsigned event, const void *data, unsigned size) {
49
79
xQueueSend (sock_events , & evt , 20 );
50
80
}
51
81
82
+ static bool needs_io (int ret ) {
83
+ return (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE );
84
+ }
85
+
52
86
static void push_error (const char * msg ) {
53
87
push_event (JD_CONN_EV_ERROR , msg , strlen (msg ));
54
88
}
@@ -67,16 +101,34 @@ static void raise_error(const char *msg) {
67
101
}
68
102
}
69
103
70
- static int sock_create_and_connect (const char * hostname , const char * port_num ) {
104
+ static int mbedtls_print_error_msg (const char * fn , int error ) {
105
+ LOG ("%s returned -%x" , fn , - error );
106
+ mbedtls_strerror (error , (char * )sockbuf , sizeof (sockbuf ));
107
+ LOG (" %s" , sockbuf );
108
+ raise_error (fn );
109
+ return -1 ;
110
+ }
111
+
112
+ static int sock_create_and_connect (const char * hostname , int port ) {
71
113
struct addrinfo hints = {
72
114
.ai_family = AF_UNSPEC ,
73
115
.ai_socktype = SOCK_STREAM ,
74
116
};
75
117
struct addrinfo * result ;
76
118
77
- int s = getaddrinfo (hostname , port_num , & hints , & result );
119
+ bool is_tls = false;
120
+
121
+ if (port < 0 ) {
122
+ port = - port ;
123
+ is_tls = true;
124
+ }
125
+
126
+ char portbuf [10 ];
127
+ jd_sprintf (portbuf , sizeof (portbuf ), "%d" , port );
128
+
129
+ int s = getaddrinfo (hostname , portbuf , & hints , & result );
78
130
if (s ) {
79
- LOG ("getaddrinfo %s:%s : %d" , hostname , port_num , s );
131
+ LOG ("getaddrinfo %s:%d : %d" , hostname , port , s );
80
132
push_error ("can't resolve host" );
81
133
return -1 ;
82
134
}
@@ -94,7 +146,7 @@ static int sock_create_and_connect(const char *hostname, const char *port_num) {
94
146
}
95
147
96
148
if (rp -> ai_next == NULL ) {
97
- LOG ("connect %s:%s : %s" , hostname , port_num , strerror (errno ));
149
+ LOG ("connect %s:%d : %s" , hostname , port , strerror (errno ));
98
150
push_error ("can't connect" );
99
151
}
100
152
@@ -106,18 +158,72 @@ static int sock_create_and_connect(const char *hostname, const char *port_num) {
106
158
if (sockfd < 0 )
107
159
return sockfd ;
108
160
109
- LOG ("connected to %s:%s" , hostname , port_num );
110
- return sockfd ;
161
+ LOG ("connected to %s:%d" , hostname , port );
162
+
163
+ sock_fd = sockfd ;
164
+
165
+ if (!is_tls )
166
+ return 0 ;
167
+
168
+ sock_state_t * tls = & _tls ;
169
+ mbedtls_ssl_init (& tls -> ssl );
170
+ mbedtls_ctr_drbg_init (& tls -> ctr_drbg );
171
+ mbedtls_ssl_config_init (& tls -> conf );
172
+ mbedtls_entropy_init (& tls -> entropy );
173
+ tls -> is_tls = true;
174
+ tls -> server_fd .fd = sockfd ;
175
+
176
+ int ret ;
177
+
178
+ if ((ret = mbedtls_ssl_set_hostname (& tls -> ssl , hostname )) != 0 )
179
+ return mbedtls_print_error_msg ("mbedtls_ssl_set_hostname" , ret );
180
+
181
+ if ((ret = mbedtls_ssl_config_defaults (& tls -> conf , MBEDTLS_SSL_IS_CLIENT ,
182
+ MBEDTLS_SSL_TRANSPORT_STREAM ,
183
+ MBEDTLS_SSL_PRESET_DEFAULT )) != 0 )
184
+ return mbedtls_print_error_msg ("mbedtls_ssl_config_defaults" , ret );
185
+
186
+ mbedtls_ssl_conf_authmode (& tls -> conf , MBEDTLS_SSL_VERIFY_REQUIRED );
187
+ esp_crt_bundle_attach (& tls -> conf );
188
+
189
+ if ((ret = mbedtls_ctr_drbg_seed (& tls -> ctr_drbg , mbedtls_entropy_func , & tls -> entropy , NULL ,
190
+ 0 )) != 0 )
191
+ return mbedtls_print_error_msg ("mbedtls_ctr_drbg_seed" , ret );
192
+
193
+ mbedtls_ssl_conf_rng (& tls -> conf , mbedtls_ctr_drbg_random , & tls -> ctr_drbg );
194
+
195
+ // 2-warn 3-debug 4-verbose
196
+ // mbedtls_esp_enable_debug_log(&tls->conf, 3);
197
+
198
+ if ((ret = mbedtls_ssl_setup (& tls -> ssl , & tls -> conf )) != 0 )
199
+ return mbedtls_print_error_msg ("mbedtls_ssl_setup" , ret );
200
+
201
+ mbedtls_ssl_set_bio (& tls -> ssl , & tls -> server_fd , mbedtls_net_send , mbedtls_net_recv , NULL );
202
+
203
+ for (;;) {
204
+ ret = mbedtls_ssl_handshake (& tls -> ssl );
205
+ if (ret == 0 )
206
+ break ;
207
+
208
+ if (!needs_io (ret ))
209
+ return mbedtls_print_error_msg ("mbedtls_ssl_handshake" , ret );
210
+
211
+ vTaskDelay (10 );
212
+ }
213
+
214
+ mbedtls_net_set_nonblock (& tls -> server_fd );
215
+
216
+ LOG ("TLS handshake completed with %s:%d" , hostname , port );
217
+
218
+ return 0 ;
111
219
}
112
220
113
221
static void process_open (sock_cmd_t * cmd ) {
114
222
jd_tcpsock_close ();
115
- char * port_num = jd_sprintf_a ("%d" , cmd -> open .port );
116
- int r = sock_create_and_connect (cmd -> open .hostname , port_num );
117
- jd_free (port_num );
223
+ int r = sock_create_and_connect (cmd -> open .hostname , cmd -> open .port );
118
224
jd_free (cmd -> open .hostname );
119
- if (r > 0 ) {
120
- sock_fd = r ;
225
+ if (r == 0 ) {
226
+ _tls . is_connected = true ;
121
227
push_event (JD_CONN_EV_OPEN , NULL , 0 );
122
228
}
123
229
}
@@ -135,10 +241,36 @@ static int forced_write(int fd, const void *buf, size_t nbytes) {
135
241
return numread ;
136
242
}
137
243
244
+ static int sock_mbedtls_write (sock_state_t * tls , const uint8_t * data , size_t datalen ) {
245
+ JD_ASSERT (datalen < MBEDTLS_SSL_OUT_CONTENT_LEN );
246
+
247
+ size_t written = 0 ;
248
+ size_t write_len = datalen ;
249
+ while (written < datalen ) {
250
+ ssize_t ret = mbedtls_ssl_write (& tls -> ssl , data + written , write_len );
251
+ if (ret <= 0 ) {
252
+ if (ret != 0 && !needs_io (ret )) {
253
+ return mbedtls_print_error_msg ("mbedtls_ssl_write" , ret );
254
+ } else {
255
+ vTaskDelay (5 );
256
+ }
257
+ }
258
+ written += ret ;
259
+ write_len = datalen - written ;
260
+ }
261
+ return written ;
262
+ }
263
+
138
264
static void process_write (sock_cmd_t * cmd ) {
139
265
unsigned size = cmd -> write .size ;
140
266
uint8_t * buf = cmd -> write .data ;
141
- if (sock_fd ) {
267
+
268
+ sock_state_t * tls = & _tls ;
269
+
270
+ if (tls -> is_tls ) {
271
+ LOGV ("wrTLS %u b" , size );
272
+ sock_mbedtls_write (tls , buf , size );
273
+ } else if (sock_fd ) {
142
274
LOGV ("wr %u b" , size );
143
275
if (forced_write (sock_fd , buf , size ) != (int )size )
144
276
raise_error ("write error" );
@@ -151,6 +283,17 @@ void jd_tcpsock_close(void) {
151
283
close (sock_fd );
152
284
sock_fd = 0 ;
153
285
}
286
+
287
+ sock_state_t * tls = & _tls ;
288
+ tls -> is_connected = false;
289
+ if (tls -> is_tls ) {
290
+ // mbedtls_ssl_session_reset(&tls->ssl);
291
+ mbedtls_entropy_free (& tls -> entropy );
292
+ mbedtls_ssl_config_free (& tls -> conf );
293
+ mbedtls_ctr_drbg_free (& tls -> ctr_drbg );
294
+ mbedtls_ssl_free (& tls -> ssl );
295
+ tls -> is_tls = 0 ;
296
+ }
154
297
}
155
298
156
299
int jd_tcpsock_new (const char * hostname , int port ) {
@@ -200,6 +343,8 @@ static void worker_main(void *arg) {
200
343
}
201
344
202
345
void jd_tcpsock_init (void ) {
346
+ esp_log_level_set ("mbedtls" , ESP_LOG_DEBUG );
347
+
203
348
// The main task is at priority 1, so we're higher priority (run "more often").
204
349
// Timer task runs at much higher priority (~20).
205
350
unsigned stack_size = 4096 ;
@@ -210,32 +355,53 @@ void jd_tcpsock_init(void) {
210
355
}
211
356
212
357
void jd_tcpsock_process (void ) {
213
- static uint8_t sockbuf [128 ];
214
-
215
358
sock_event_t evt ;
216
359
while (xQueueReceive (sock_events , & evt , 0 )) {
217
360
jd_tcpsock_on_event (evt .ev , evt .data , evt .size );
218
361
}
219
362
220
- if (!sock_fd )
363
+ sock_state_t * tls = & _tls ;
364
+
365
+ if (!tls -> is_connected )
221
366
return ;
222
367
223
368
for (;;) {
224
- int r = recv (sock_fd , sockbuf , sizeof (sockbuf ), MSG_DONTWAIT );
225
- if (r == 0 ) {
226
- raise_error (NULL );
227
- return ;
228
- }
229
- if (r > 0 ) {
230
- LOGV ("rd %d" , r );
231
- jd_tcpsock_on_event (JD_CONN_EV_MESSAGE , sockbuf , r );
369
+ if (tls -> is_tls ) {
370
+ int ret = mbedtls_ssl_read (& tls -> ssl , sockbuf , sizeof (sockbuf ));
371
+
372
+ if (needs_io (ret ))
373
+ return ;
374
+
375
+ if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || ret == 0 ) {
376
+ raise_error (NULL );
377
+ return ;
378
+ }
379
+
380
+ if (ret < 0 ) {
381
+ mbedtls_print_error_msg ("mbedtls_ssl_read" , ret );
382
+ return ;
383
+ }
384
+
385
+ LOGV ("rdTLS %d" , ret );
386
+ jd_tcpsock_on_event (JD_CONN_EV_MESSAGE , sockbuf , ret );
232
387
} else {
233
- if (errno == EAGAIN || errno == EWOULDBLOCK )
234
- break ;
235
- else {
236
- raise_error ("recv error" );
388
+ int r = recv (sock_fd , sockbuf , sizeof (sockbuf ), MSG_DONTWAIT );
389
+ if (r == 0 ) {
390
+ raise_error (NULL );
237
391
return ;
238
392
}
393
+
394
+ if (r > 0 ) {
395
+ LOGV ("rd %d" , r );
396
+ jd_tcpsock_on_event (JD_CONN_EV_MESSAGE , sockbuf , r );
397
+ } else {
398
+ if (errno == EAGAIN || errno == EWOULDBLOCK )
399
+ break ;
400
+ else {
401
+ raise_error ("recv error" );
402
+ return ;
403
+ }
404
+ }
239
405
}
240
406
}
241
407
}
0 commit comments