Skip to content

Commit

Permalink
tcp_transport/esp_tls: Use common TCP transport to reduce code duplic…
Browse files Browse the repository at this point in the history
…ation

For high level review of the changes.
  • Loading branch information
david-cermak committed Feb 16, 2021
1 parent 391d7bf commit 2c28fff
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 356 deletions.
4 changes: 2 additions & 2 deletions components/esp-tls/esp_tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ static int esp_tls_low_level_conn(const char *hostname, int hostlen, int port, c
switch (tls->conn_state) {
case ESP_TLS_INIT:
tls->sockfd = -1;
if (cfg != NULL) {
if (cfg != NULL && cfg->is_plain_tcp == false) {
#ifdef CONFIG_ESP_TLS_USING_MBEDTLS
mbedtls_net_init(&tls->server_fd);
#endif
Expand All @@ -286,7 +286,7 @@ static int esp_tls_low_level_conn(const char *hostname, int hostlen, int port, c
ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ESP_TLS_ERR_TYPE_ESP, esp_ret);
return -1;
}
if (!cfg) {
if (tls->is_tls == false) {
tls->read = tcp_read;
tls->write = tcp_write;
ESP_LOGD(TAG, "non-tls connection established");
Expand Down
1 change: 1 addition & 0 deletions components/esp-tls/esp_tls.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ typedef struct esp_tls_cfg {
bundle for server verification, must be enabled in menuconfig */

void *ds_data; /*!< Pointer for digital signature peripheral context */
bool is_plain_tcp;
} esp_tls_cfg_t;

#ifdef CONFIG_ESP_TLS_SERVER
Expand Down
1 change: 0 additions & 1 deletion components/tcp_transport/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
idf_component_register(SRCS "transport.c"
"transport_ssl.c"
"transport_tcp.c"
"transport_ws.c"
"transport_utils.c"
INCLUDE_DIRS "include"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

typedef int (*get_socket_func)(esp_transport_handle_t t);

struct transport_esp_tls;

/**
* Transport layer structure, which will provide functions, basic properties for transport types
*/
Expand All @@ -40,6 +42,7 @@ struct esp_transport_item_t {
struct esp_transport_error_s* error_handle; /*!< Error handle (based on esp-tls error handle)
* extended with transport's specific errors */
esp_transport_keep_alive_t *keep_alive_cfg; /*!< TCP keep-alive config */
struct transport_esp_tls *foundation_transport;

STAILQ_ENTRY(esp_transport_item_t) next;
};
Expand Down Expand Up @@ -86,4 +89,6 @@ int esp_transport_get_socket(esp_transport_handle_t t);
*/
void esp_transport_capture_errno(esp_transport_handle_t t, int sock_errno);

struct transport_esp_tls* esp_transport_init_foundation(void);

#endif //_ESP_TRANSPORT_INTERNAL_H_
7 changes: 6 additions & 1 deletion components/tcp_transport/transport.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "esp_transport.h"
#include "esp_transport_internal.h"
#include "esp_transport_utils.h"
#include "esp_tls_errors.h"

static const char *TAG = "TRANSPORT";

Expand All @@ -43,12 +42,15 @@ struct esp_transport_error_s {
*/
STAILQ_HEAD(esp_transport_list_t, esp_transport_item_t);

struct transport_esp_tls;

/**
* Internal transport structure holding list of transports and other data common to all transports
*/
typedef struct esp_transport_internal {
struct esp_transport_list_t list; /*!< List of transports */
struct esp_transport_error_s* error_handle; /*!< Pointer to the transport error container */
struct transport_esp_tls *foundation_transport;
} esp_transport_internal_t;

static esp_transport_handle_t esp_transport_get_default_parent(esp_transport_handle_t t)
Expand All @@ -65,6 +67,7 @@ esp_transport_list_handle_t esp_transport_list_init(void)
ESP_TRANSPORT_MEM_CHECK(TAG, transport, return NULL);
STAILQ_INIT(&transport->list);
transport->error_handle = calloc(1, sizeof(struct esp_transport_error_s));
transport->foundation_transport = esp_transport_init_foundation();
return transport;
}

Expand All @@ -79,6 +82,7 @@ esp_err_t esp_transport_list_add(esp_transport_list_handle_t h, esp_transport_ha
STAILQ_INSERT_TAIL(&h->list, t, next);
// Each transport in a list to share the same error tracker
t->error_handle = h->error_handle;
t->foundation_transport = h->foundation_transport;
return ESP_OK;
}

Expand All @@ -103,6 +107,7 @@ esp_err_t esp_transport_list_destroy(esp_transport_list_handle_t h)
{
esp_transport_list_clean(h);
free(h->error_handle);
free(h->foundation_transport); // TODO: make it destroy foundation
free(h);
return ESP_OK;
}
Expand Down
89 changes: 60 additions & 29 deletions components/tcp_transport/transport_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ typedef enum {
/**
* mbedtls specific transport data
*/
typedef struct {
typedef struct transport_esp_tls {
esp_tls_t *tls;
esp_tls_cfg_t cfg;
bool ssl_initialized;
Expand All @@ -48,7 +48,7 @@ static int ssl_close(esp_transport_handle_t t);

static int ssl_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (ssl->conn_state == TRANS_SSL_INIT) {
ssl->cfg.timeout_ms = timeout_ms;
ssl->cfg.non_block = true;
Expand All @@ -67,7 +67,7 @@ static int ssl_connect_async(esp_transport_handle_t t, const char *host, int por

static int ssl_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;

ssl->cfg.timeout_ms = timeout_ms;
ssl->ssl_initialized = true;
Expand All @@ -83,9 +83,29 @@ static int ssl_connect(esp_transport_handle_t t, const char *host, int port, int
return 0;
}

static int tcp_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
transport_ssl_t *ssl = t->foundation_transport;

ssl->cfg.timeout_ms = timeout_ms;
ssl->cfg.is_plain_tcp = true;
ssl->ssl_initialized = true;
ssl->tls = esp_tls_init();
if (esp_tls_conn_new_sync(host, strlen(host), port, &ssl->cfg, ssl->tls) <= 0) {
ESP_LOGE(TAG, "Failed to open a new connection");
esp_transport_set_errors(t, ssl->tls->error_handle);
esp_tls_conn_destroy(ssl->tls);
ssl->tls = NULL;
return -1;
}

return 0;
}


static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
int ret = -1;
int remain = 0;
struct timeval timeout;
Expand Down Expand Up @@ -114,7 +134,7 @@ static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms)

static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
int ret = -1;
struct timeval timeout;
fd_set writeset;
Expand All @@ -138,7 +158,7 @@ static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms)
static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
{
int poll, ret;
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;

if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) {
ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms);
Expand All @@ -155,7 +175,7 @@ static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int
static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{
int poll, ret;
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;

if ((poll = esp_transport_poll_read(t, timeout_ms)) <= 0) {
return poll;
Expand All @@ -178,7 +198,7 @@ static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout
static int ssl_close(esp_transport_handle_t t)
{
int ret = -1;
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (ssl->ssl_initialized) {
ret = esp_tls_conn_destroy(ssl->tls);
ssl->conn_state = TRANS_SSL_INIT;
Expand All @@ -189,31 +209,31 @@ static int ssl_close(esp_transport_handle_t t)

static int ssl_destroy(esp_transport_handle_t t)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
esp_transport_close(t);
free(ssl);
return 0;
}

void esp_transport_ssl_enable_global_ca_store(esp_transport_handle_t t)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.use_global_ca_store = true;
}
}

void esp_transport_ssl_set_psk_key_hint(esp_transport_handle_t t, const psk_hint_key_t* psk_hint_key)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.psk_hint_key = psk_hint_key;
}
}

void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, int len)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.cacert_pem_buf = (void *)data;
ssl->cfg.cacert_pem_bytes = len + 1;
Expand All @@ -222,7 +242,7 @@ void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data,

void esp_transport_ssl_set_cert_data_der(esp_transport_handle_t t, const char *data, int len)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.cacert_buf = (void *)data;
ssl->cfg.cacert_bytes = len;
Expand All @@ -231,7 +251,7 @@ void esp_transport_ssl_set_cert_data_der(esp_transport_handle_t t, const char *d

void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char *data, int len)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.clientcert_pem_buf = (void *)data;
ssl->cfg.clientcert_pem_bytes = len + 1;
Expand All @@ -240,7 +260,7 @@ void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char

void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const char *data, int len)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.clientcert_buf = (void *)data;
ssl->cfg.clientcert_bytes = len;
Expand All @@ -249,7 +269,7 @@ void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const

void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char *data, int len)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.clientkey_pem_buf = (void *)data;
ssl->cfg.clientkey_pem_bytes = len + 1;
Expand All @@ -258,7 +278,7 @@ void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char

void esp_transport_ssl_set_client_key_password(esp_transport_handle_t t, const char *password, int password_len)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.clientkey_password = (void *)password;
ssl->cfg.clientkey_password_len = password_len;
Expand All @@ -267,7 +287,7 @@ void esp_transport_ssl_set_client_key_password(esp_transport_handle_t t, const c

void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.clientkey_buf = (void *)data;
ssl->cfg.clientkey_bytes = len;
Expand All @@ -276,23 +296,23 @@ void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const c

void esp_transport_ssl_set_alpn_protocol(esp_transport_handle_t t, const char **alpn_protos)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.alpn_protos = alpn_protos;
}
}

void esp_transport_ssl_skip_common_name_check(esp_transport_handle_t t)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.skip_common_name = true;
}
}

void esp_transport_ssl_use_secure_element(esp_transport_handle_t t)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) {
ssl->cfg.use_secure_element = true;
}
Expand All @@ -311,8 +331,8 @@ static int ssl_get_socket(esp_transport_handle_t t)

void esp_transport_ssl_set_ds_data(esp_transport_handle_t t, void *ds_data)
{
transport_ssl_t *ssl = esp_transport_get_context_data(t);
if (t && ssl) {
transport_ssl_t *ssl = t->foundation_transport;
if (t && ssl) { // TODO: check t NULL first!
ssl->cfg.ds_data = ds_data;
}
}
Expand All @@ -328,14 +348,25 @@ void esp_transport_ssl_set_keep_alive(esp_transport_handle_t t, esp_transport_ke
esp_transport_handle_t esp_transport_ssl_init(void)
{
esp_transport_handle_t t = esp_transport_init();
transport_ssl_t *ssl = calloc(1, sizeof(transport_ssl_t));
ESP_TRANSPORT_MEM_CHECK(TAG, ssl, {
esp_transport_destroy(t);
return NULL;
});
esp_transport_set_context_data(t, ssl);
esp_transport_set_context_data(t, NULL);
esp_transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
esp_transport_set_async_connect_func(t, ssl_connect_async);
t->_get_socket = ssl_get_socket;
return t;
}

struct transport_esp_tls* esp_transport_init_foundation(void)
{
transport_ssl_t *ssl = calloc(1, sizeof(transport_ssl_t));
return ssl;
}

esp_transport_handle_t esp_transport_tcp_init(void)
{
esp_transport_handle_t t = esp_transport_init();
esp_transport_set_context_data(t, NULL);
esp_transport_set_func(t, tcp_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
esp_transport_set_async_connect_func(t, ssl_connect_async); // TODO: tcp_connect_async()
t->_get_socket = ssl_get_socket;
return t;
}
Loading

0 comments on commit 2c28fff

Please sign in to comment.