Skip to content

Commit 10e5b93

Browse files
committed
feat(transport_ws): add support for per-message compression handshakes
1 parent ab14938 commit 10e5b93

File tree

2 files changed

+265
-4
lines changed

2 files changed

+265
-4
lines changed

components/tcp_transport/include/esp_transport_ws.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ typedef enum ws_transport_opcodes {
2424
WS_TRANSPORT_OPCODES_CLOSE = 0x08,
2525
WS_TRANSPORT_OPCODES_PING = 0x09,
2626
WS_TRANSPORT_OPCODES_PONG = 0x0a,
27+
WS_TRANSPORT_OPCODES_COMPRESSED = 0x40,
2728
WS_TRANSPORT_OPCODES_FIN = 0x80,
2829
WS_TRANSPORT_OPCODES_NONE = 0x100, /*!< not a valid opcode to indicate no message previously received
2930
* from the API esp_transport_ws_get_read_opcode() */
@@ -48,6 +49,13 @@ typedef struct {
4849
* If false, only user frames are propagated, control frames are handled
4950
* automatically during read operations
5051
*/
52+
bool per_msg_compress; /*!< Hint the server to enable per-message compression (RFC7692) */
53+
int per_msg_client_deflate_window_bit; /*!< Hint the server Per-message deflate window bit 8 to 15; or leave 0 to let server decide */
54+
int per_msg_server_deflate_window_bit; /*!< Hint the server Per-message deflate window bit 8 to 15; or leave 0 to let server decide */
55+
bool per_msg_server_no_ctx_takeover; /*!< Hint the server to reset the compression stream on every WS frame on server side
56+
* True for a safer transfer, false for better performance */
57+
bool per_msg_client_no_ctx_takeover; /*!< Hint the server to reset the compression stream on every WS frame on client side
58+
* True for a safer transfer, false for better performance */
5159
} esp_transport_ws_config_t;
5260

5361
/**
@@ -184,6 +192,78 @@ int esp_transport_ws_send_raw(esp_transport_handle_t t, ws_transport_opcodes_t o
184192
*/
185193
bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t);
186194

195+
/**
196+
* @brief Returns the RSV1 flag (permessage-deflate) of the last read frame
197+
*
198+
* @param[in] t The transport handle
199+
*
200+
* @return
201+
* - true if the last read frame was compressed
202+
* - false otherwise
203+
*/
204+
bool esp_transport_ws_get_rsv1_flag(esp_transport_handle_t t);
205+
206+
/**
207+
* @brief Get per-message compression flag
208+
*
209+
* @param[in] t The transport handle
210+
*
211+
* @return
212+
* - true if per-message compression is enabled
213+
* - false if per-message compression is disabled
214+
*/
215+
bool esp_transport_ws_get_per_msg_compress(esp_transport_handle_t t);
216+
217+
/**
218+
* @brief Get client deflate window bit for per-message compression
219+
*
220+
* @param[in] t The transport handle
221+
*
222+
* @return
223+
* - client deflate window bit
224+
*/
225+
int esp_transport_ws_get_per_msg_client_deflate_window_bit(esp_transport_handle_t t);
226+
227+
/**
228+
* @brief Get server deflate window bit for per-message compression
229+
*
230+
* @param[in] t The transport handle
231+
*
232+
* @return
233+
* - server deflate window bit
234+
*/
235+
int esp_transport_ws_get_per_msg_server_deflate_window_bit(esp_transport_handle_t t);
236+
237+
/**
238+
* @brief Get server no context takeover flag for per-message compression
239+
*
240+
* If this is returned to be true, then the server-to-client's compression handle should be reset
241+
* on every frame transfer. If this is false, then the server-to-client's compression handle
242+
* should not be reset over the lifespan of this esp_transport_handle_t.
243+
*
244+
* @param[in] t The transport handle
245+
*
246+
* @return
247+
* - true if server no context takeover is enabled
248+
* - false if server no context takeover is disabled
249+
*/
250+
bool esp_transport_ws_get_per_msg_server_no_ctx_takeover(esp_transport_handle_t t);
251+
252+
/**
253+
* @brief Get client no context takeover flag for per-message compression
254+
*
255+
* If this is returned to be true, then the client-to-server's compression handle should be reset
256+
* on every frame transfer. If this is false, then the client-to-server's compression handle
257+
* should not be reset over the lifespan of this esp_transport_handle_t.
258+
*
259+
* @param[in] t The transport handle
260+
*
261+
* @return
262+
* - true if client no context takeover is enabled
263+
* - false if client no context takeover is disabled
264+
*/
265+
bool esp_transport_ws_get_per_msg_client_no_ctx_takeover(esp_transport_handle_t t);
266+
187267
/**
188268
* @brief Returns the HTTP status code of the websocket handshake
189269
*

components/tcp_transport/transport_ws.c

Lines changed: 185 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ static const char *TAG = "transport_ws";
2424

2525
#define WS_BUFFER_SIZE CONFIG_WS_BUFFER_SIZE
2626
#define WS_FIN 0x80
27+
#define WS_COMPRESSED 0x40
2728
#define WS_OPCODE_CONT 0x00
2829
#define WS_OPCODE_TEXT 0x01
2930
#define WS_OPCODE_BINARY 0x02
@@ -56,6 +57,7 @@ typedef struct {
5657
int payload_len; /*!< Total length of the payload */
5758
int bytes_remaining; /*!< Bytes left to read of the payload */
5859
bool header_received; /*!< Flag to indicate that a new message header was received */
60+
bool compressed; /*!< Per-message deflate compress flag (RSV1) */
5961
} ws_transport_frame_state_t;
6062

6163
typedef struct {
@@ -75,6 +77,11 @@ typedef struct {
7577
char *redir_host;
7678
char *response_header;
7779
size_t response_header_len;
80+
bool per_msg_compress;
81+
int per_msg_client_deflate_window_bit;
82+
int per_msg_server_deflate_window_bit;
83+
bool per_msg_server_no_ctx_takeover;
84+
bool per_msg_client_no_ctx_takeover;
7885
} transport_ws_t;
7986

8087
/**
@@ -201,6 +208,72 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
201208
#endif
202209

203210
size_t outlen = 0;
211+
char extension_header[168] = { 0 };
212+
if (ws->per_msg_compress) {
213+
int offset = 0;
214+
int ext_ret = snprintf(extension_header, sizeof(extension_header), "Sec-WebSocket-Extensions: permessage-deflate");
215+
if (ext_ret <= 0) {
216+
ESP_LOGE(TAG, "Failed to write header to permessage-deflate");
217+
return -1;
218+
}
219+
220+
offset += ext_ret;
221+
222+
if (ws->per_msg_client_no_ctx_takeover) {
223+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; client_no_context_takeover");
224+
if (ext_ret <= 0) {
225+
ESP_LOGE(TAG, "Failed to write header to permessage-deflate client_no_context_takeover");
226+
return -1;
227+
}
228+
229+
offset += ext_ret;
230+
}
231+
232+
if (ws->per_msg_server_no_ctx_takeover) {
233+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; server_no_context_takeover");
234+
if (ext_ret <= 0) {
235+
ESP_LOGE(TAG, "Failed to write header to permessage-deflate server_no_context_takeover");
236+
return -1;
237+
}
238+
239+
offset += ext_ret;
240+
}
241+
242+
// If this is 0 then it means to let server decide the client window bit
243+
if (ws->per_msg_client_deflate_window_bit != 0) {
244+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; client_max_window_bits=%d", ws->per_msg_client_deflate_window_bit);
245+
} else {
246+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; client_max_window_bits");
247+
}
248+
249+
if (ext_ret <= 0) {
250+
ESP_LOGE(TAG, "Failed to write header to permessage-deflate client_max_window_bits");
251+
return -1;
252+
}
253+
254+
offset += ext_ret;
255+
256+
// If this is 0 then it means to let server decide the server window bit
257+
if (ws->per_msg_server_deflate_window_bit != 0) {
258+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; server_max_window_bits=%d", ws->per_msg_server_deflate_window_bit);
259+
260+
if (ext_ret <= 0) {
261+
ESP_LOGE(TAG, "Failed to write header to permessage-deflate server_max_window_bits");
262+
return -1;
263+
}
264+
265+
offset += ext_ret;
266+
}
267+
268+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "\r\n");
269+
if (ext_ret <= 0) {
270+
ESP_LOGE(TAG, "Failed to concat permessage-deflate header");
271+
return -1;
272+
}
273+
274+
extension_header[sizeof(extension_header) - 1] = '\0';
275+
}
276+
204277
esp_crypto_base64_encode(client_key, sizeof(client_key), &outlen, random_key, sizeof(random_key));
205278
int len = snprintf(ws->buffer, WS_BUFFER_SIZE,
206279
"GET %s HTTP/1.1\r\n"
@@ -209,10 +282,12 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
209282
"User-Agent: %s\r\n"
210283
"Upgrade: websocket\r\n"
211284
"Sec-WebSocket-Version: 13\r\n"
212-
"Sec-WebSocket-Key: %s\r\n",
285+
"Sec-WebSocket-Key: %s\r\n"
286+
"%s", // For "Sec-WebSocket-Extensions"
213287
ws->path,
214288
host, port, user_agent_ptr,
215-
client_key);
289+
client_key,
290+
extension_header);
216291
if (len <= 0 || len >= WS_BUFFER_SIZE) {
217292
ESP_LOGE(TAG, "Error in request generation, desired request len: %d, buffer size: %d", len, WS_BUFFER_SIZE);
218293
return -1;
@@ -306,6 +381,9 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
306381
}
307382
header_cursor += strlen("\r\n");
308383

384+
// If compression was requested, we need to check server response
385+
bool pmd_negotiated = false;
386+
309387
while(header_cursor < delim_ptr){
310388
const char * end_of_line = strnstr(header_cursor, "\r\n", header_len - (header_cursor - ws->buffer));
311389
if(!end_of_line){
@@ -332,6 +410,53 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
332410
server_key = header_cursor + header_sec_websocket_accept_len;
333411
server_key_len = line_len - header_sec_websocket_accept_len;
334412
}
413+
// Check for Sec-WebSocket-Extensions header
414+
else if (ws->per_msg_compress && line_len >= strlen("Sec-WebSocket-Extensions: ") && !strncasecmp(header_cursor, "Sec-WebSocket-Extensions: ", strlen("Sec-WebSocket-Extensions: "))) {
415+
const char* ext_params = header_cursor + strlen("Sec-WebSocket-Extensions: ");
416+
int ext_params_len = line_len - strlen("Sec-WebSocket-Extensions: ");
417+
ESP_LOGD(TAG, "Found Sec-WebSocket-Extensions: %.*s", ext_params_len, ext_params);
418+
419+
if (strcasestr(ext_params, "permessage-deflate")) {
420+
pmd_negotiated = true;
421+
422+
// Server must agree to context takeover settings
423+
if (!strcasestr(ext_params, "server_no_context_takeover")) {
424+
ws->per_msg_server_no_ctx_takeover = false;
425+
}
426+
if (!strcasestr(ext_params, "client_no_context_takeover")) {
427+
ws->per_msg_client_no_ctx_takeover = false;
428+
}
429+
430+
const char *smwb_str = "server_max_window_bits=";
431+
const char *found = strcasestr(ext_params, smwb_str);
432+
if (found) {
433+
char *endptr;
434+
long smwb = strtol(found + strlen(smwb_str), &endptr, 10);
435+
if (smwb < 8 || smwb > 15) {
436+
ESP_LOGE(TAG, "compression: Server Max Window Bits is invalid: %ld", smwb);
437+
return -1;
438+
}
439+
440+
ws->per_msg_server_deflate_window_bit = (int)smwb;
441+
} else {
442+
ws->per_msg_server_deflate_window_bit = 15;
443+
}
444+
445+
const char *cmwb_str = "client_max_window_bits=";
446+
found = strcasestr(ext_params, cmwb_str);
447+
if (found) {
448+
char *endptr;
449+
long cmwb = strtol(found + strlen(cmwb_str), &endptr, 10);
450+
451+
if (cmwb < 8 || cmwb > 15) {
452+
ESP_LOGE(TAG, "compression: Client Max Window Bits is invalid: %ld", cmwb);
453+
return -1;
454+
}
455+
456+
ws->per_msg_client_deflate_window_bit = (int)cmwb;
457+
}
458+
}
459+
}
335460
else if (ws->header_hook) {
336461
ws->header_hook(ws->header_user_context, header_cursor, line_len);
337462
}
@@ -349,6 +474,10 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
349474
header_cursor += strlen("\r\n");
350475
}
351476

477+
if (ws->per_msg_compress && !pmd_negotiated) {
478+
ws->per_msg_compress = false;
479+
}
480+
352481
if (WS_HTTP_TEMPORARY_REDIRECT(ws->http_status_code) || WS_HTTP_PERMANENT_REDIRECT(ws->http_status_code)) {
353482
if (location == NULL || location_len <= 0) {
354483
ESP_LOGE(TAG, "Location header not found");
@@ -575,6 +704,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t
575704
ws->frame_state.header_received = true;
576705
ws->frame_state.fin = (*data_ptr & 0x80) != 0;
577706
ws->frame_state.opcode = (*data_ptr & 0x0F);
707+
ws->frame_state.compressed = (*data_ptr & 0x40) != 0; // RSV1 bit in the header
578708
data_ptr ++;
579709
mask = ((*data_ptr >> 7) & 0x01);
580710
payload_len = (*data_ptr & 0x7F);
@@ -979,14 +1109,65 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp
9791109
}
9801110

9811111
ws->propagate_control_frames = config->propagate_control_frames;
1112+
ws->per_msg_compress = config->per_msg_compress;
1113+
ws->per_msg_client_no_ctx_takeover = config->per_msg_client_no_ctx_takeover;
1114+
ws->per_msg_server_no_ctx_takeover = config->per_msg_server_no_ctx_takeover;
1115+
1116+
if (config->per_msg_client_deflate_window_bit < 8 || config->per_msg_client_deflate_window_bit > 15) {
1117+
ws->per_msg_client_deflate_window_bit = 0;
1118+
} else {
1119+
ws->per_msg_client_deflate_window_bit = config->per_msg_client_deflate_window_bit;
1120+
}
1121+
1122+
if (config->per_msg_server_deflate_window_bit < 8 || config->per_msg_server_deflate_window_bit > 15) {
1123+
ws->per_msg_server_deflate_window_bit = 0;
1124+
} else {
1125+
ws->per_msg_server_deflate_window_bit = config->per_msg_server_deflate_window_bit;
1126+
}
9821127

9831128
return err;
9841129
}
9851130

9861131
bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t)
9871132
{
988-
transport_ws_t *ws = esp_transport_get_context_data(t);
989-
return ws->frame_state.fin;
1133+
transport_ws_t *ws = esp_transport_get_context_data(t);
1134+
return ws->frame_state.fin;
1135+
}
1136+
1137+
bool esp_transport_ws_get_rsv1_flag(esp_transport_handle_t t)
1138+
{
1139+
transport_ws_t *ws = esp_transport_get_context_data(t);
1140+
return ws->frame_state.compressed;
1141+
}
1142+
1143+
bool esp_transport_ws_get_per_msg_compress(esp_transport_handle_t t)
1144+
{
1145+
transport_ws_t *ws = esp_transport_get_context_data(t);
1146+
return ws->per_msg_compress;
1147+
}
1148+
1149+
int esp_transport_ws_get_per_msg_client_deflate_window_bit(esp_transport_handle_t t)
1150+
{
1151+
transport_ws_t *ws = esp_transport_get_context_data(t);
1152+
return ws->per_msg_client_deflate_window_bit;
1153+
}
1154+
1155+
int esp_transport_ws_get_per_msg_server_deflate_window_bit(esp_transport_handle_t t)
1156+
{
1157+
transport_ws_t *ws = esp_transport_get_context_data(t);
1158+
return ws->per_msg_server_deflate_window_bit;
1159+
}
1160+
1161+
bool esp_transport_ws_get_per_msg_server_no_ctx_takeover(esp_transport_handle_t t)
1162+
{
1163+
transport_ws_t *ws = esp_transport_get_context_data(t);
1164+
return ws->per_msg_server_no_ctx_takeover && ws->per_msg_compress;
1165+
}
1166+
1167+
bool esp_transport_ws_get_per_msg_client_no_ctx_takeover(esp_transport_handle_t t)
1168+
{
1169+
transport_ws_t *ws = esp_transport_get_context_data(t);
1170+
return ws->per_msg_client_no_ctx_takeover && ws->per_msg_compress;
9901171
}
9911172

9921173
int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t)

0 commit comments

Comments
 (0)