Skip to content

Commit 57ee5c9

Browse files
committed
feat(transport_ws): add support for per-message compression handshakes
1 parent ab14938 commit 57ee5c9

File tree

2 files changed

+263
-4
lines changed

2 files changed

+263
-4
lines changed

components/tcp_transport/include/esp_transport_ws.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ typedef enum ws_transport_opcodes {
2424
WS_TRANSPORT_OPCODES_CLOSE = 0x08,
2525
WS_TRANSPORT_OPCODES_PING = 0x09,
2626
WS_TRANSPORT_OPCODES_PONG = 0x0a,
27+
WS_TRANSPORT_OPCODES_COMPRESSED = 0x40,
2728
WS_TRANSPORT_OPCODES_FIN = 0x80,
2829
WS_TRANSPORT_OPCODES_NONE = 0x100, /*!< not a valid opcode to indicate no message previously received
2930
* from the API esp_transport_ws_get_read_opcode() */
@@ -48,6 +49,13 @@ typedef struct {
4849
* If false, only user frames are propagated, control frames are handled
4950
* automatically during read operations
5051
*/
52+
bool per_msg_compress; /*!< Hint the server to enable per-message compression (RFC7692) */
53+
int per_msg_client_deflate_window_bit; /*!< Hint the server Per-message deflate window bit 8 to 15; or leave 0 to let server decide */
54+
int per_msg_server_deflate_window_bit; /*!< Hint the server Per-message deflate window bit 8 to 15; or leave 0 to let server decide */
55+
bool per_msg_server_no_ctx_takeover; /*!< Hint the server to reset the compression stream on every WS frame on server side
56+
* True for a safer transfer, false for better performance */
57+
bool per_msg_client_no_ctx_takeover; /*!< Hint the server to reset the compression stream on every WS frame on client side
58+
* True for a safer transfer, false for better performance */
5159
} esp_transport_ws_config_t;
5260

5361
/**
@@ -184,6 +192,78 @@ int esp_transport_ws_send_raw(esp_transport_handle_t t, ws_transport_opcodes_t o
184192
*/
185193
bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t);
186194

195+
/**
196+
* @brief Returns the RSV1 flag (permessage-deflate) of the last read frame
197+
*
198+
* @param[in] t The transport handle
199+
*
200+
* @return
201+
* - true if the last read frame was compressed
202+
* - false otherwise
203+
*/
204+
bool esp_transport_ws_get_rsv1_flag(esp_transport_handle_t t);
205+
206+
/**
207+
* @brief Get per-message compression flag
208+
*
209+
* @param[in] t The transport handle
210+
*
211+
* @return
212+
* - true if per-message compression is enabled
213+
* - false if per-message compression is disabled
214+
*/
215+
bool esp_transport_ws_get_per_msg_compress(esp_transport_handle_t t);
216+
217+
/**
218+
* @brief Get client deflate window bit for per-message compression
219+
*
220+
* @param[in] t The transport handle
221+
*
222+
* @return
223+
* - client deflate window bit
224+
*/
225+
int esp_transport_ws_get_per_msg_client_deflate_window_bit(esp_transport_handle_t t);
226+
227+
/**
228+
* @brief Get server deflate window bit for per-message compression
229+
*
230+
* @param[in] t The transport handle
231+
*
232+
* @return
233+
* - server deflate window bit
234+
*/
235+
int esp_transport_ws_get_per_msg_server_deflate_window_bit(esp_transport_handle_t t);
236+
237+
/**
238+
* @brief Get server no context takeover flag for per-message compression
239+
*
240+
* If this is returned to be true, then the server-to-client's compression handle should be reset
241+
* on every frame transfer. If this is false, then the server-to-client's compression handle
242+
* should not be reset over the lifespan of this esp_transport_handle_t.
243+
*
244+
* @param[in] t The transport handle
245+
*
246+
* @return
247+
* - true if server no context takeover is enabled
248+
* - false if server no context takeover is disabled
249+
*/
250+
bool esp_transport_ws_get_per_msg_server_no_ctx_takeover(esp_transport_handle_t t);
251+
252+
/**
253+
* @brief Get client no context takeover flag for per-message compression
254+
*
255+
* If this is returned to be true, then the client-to-server's compression handle should be reset
256+
* on every frame transfer. If this is false, then the client-to-server's compression handle
257+
* should not be reset over the lifespan of this esp_transport_handle_t.
258+
*
259+
* @param[in] t The transport handle
260+
*
261+
* @return
262+
* - true if client no context takeover is enabled
263+
* - false if client no context takeover is disabled
264+
*/
265+
bool esp_transport_ws_get_per_msg_client_no_ctx_takeover(esp_transport_handle_t t);
266+
187267
/**
188268
* @brief Returns the HTTP status code of the websocket handshake
189269
*

components/tcp_transport/transport_ws.c

Lines changed: 183 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ static const char *TAG = "transport_ws";
2424

2525
#define WS_BUFFER_SIZE CONFIG_WS_BUFFER_SIZE
2626
#define WS_FIN 0x80
27+
#define WS_COMPRESSED 0x40
2728
#define WS_OPCODE_CONT 0x00
2829
#define WS_OPCODE_TEXT 0x01
2930
#define WS_OPCODE_BINARY 0x02
@@ -56,6 +57,7 @@ typedef struct {
5657
int payload_len; /*!< Total length of the payload */
5758
int bytes_remaining; /*!< Bytes left to read of the payload */
5859
bool header_received; /*!< Flag to indicate that a new message header was received */
60+
bool compressed; /*!< Per-message deflate compress flag (RSV1) */
5961
} ws_transport_frame_state_t;
6062

6163
typedef struct {
@@ -75,6 +77,11 @@ typedef struct {
7577
char *redir_host;
7678
char *response_header;
7779
size_t response_header_len;
80+
bool per_msg_compress;
81+
int per_msg_client_deflate_window_bit;
82+
int per_msg_server_deflate_window_bit;
83+
bool per_msg_server_no_ctx_takeover;
84+
bool per_msg_client_no_ctx_takeover;
7885
} transport_ws_t;
7986

8087
/**
@@ -201,6 +208,70 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
201208
#endif
202209

203210
size_t outlen = 0;
211+
char extension_header[168] = { 0 };
212+
if (ws->per_msg_compress) {
213+
int offset = 0;
214+
int ext_ret = snprintf(extension_header, sizeof(extension_header), "Sec-WebSocket-Extensions: permessage-deflate");
215+
if (ext_ret <= 0) {
216+
ESP_LOGE(TAG, "Failed to write header to permessage-deflate");
217+
return -1;
218+
}
219+
220+
if (ws->per_msg_client_no_ctx_takeover) {
221+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; client_no_context_takeover");
222+
if (ext_ret <= 0) {
223+
ESP_LOGE(TAG, "Failed to write header to permessage-deflate client_no_context_takeover");
224+
return -1;
225+
}
226+
227+
offset += ext_ret;
228+
}
229+
230+
if (ws->per_msg_server_no_ctx_takeover) {
231+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; server_no_context_takeover");
232+
if (ext_ret <= 0) {
233+
ESP_LOGE(TAG, "Failed to write header to permessage-deflate server_no_context_takeover");
234+
return -1;
235+
}
236+
237+
offset += ext_ret;
238+
}
239+
240+
// If this is 0 then it means to let server decide the client window bit
241+
if (ws->per_msg_client_deflate_window_bit != 0) {
242+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; client_max_window_bits=%d", ws->per_msg_client_deflate_window_bit);
243+
} else {
244+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; client_max_window_bits");
245+
}
246+
247+
if (ext_ret <= 0) {
248+
ESP_LOGE(TAG, "Failed to write header to permessage-deflate client_max_window_bits");
249+
return -1;
250+
}
251+
252+
offset += ext_ret;
253+
254+
// If this is 0 then it means to let server decide the server window bit
255+
if (ws->per_msg_server_deflate_window_bit != 0) {
256+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; server_max_window_bits=%d", ws->per_msg_server_deflate_window_bit);
257+
}
258+
259+
if (ext_ret <= 0) {
260+
ESP_LOGE(TAG, "Failed to write header to permessage-deflate server_max_window_bits");
261+
return -1;
262+
}
263+
264+
offset += ext_ret;
265+
266+
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "\r\n");
267+
if (ext_ret <= 0) {
268+
ESP_LOGE(TAG, "Failed to concat permessage-deflate header");
269+
return -1;
270+
}
271+
272+
extension_header[sizeof(extension_header) - 1] = '\0';
273+
}
274+
204275
esp_crypto_base64_encode(client_key, sizeof(client_key), &outlen, random_key, sizeof(random_key));
205276
int len = snprintf(ws->buffer, WS_BUFFER_SIZE,
206277
"GET %s HTTP/1.1\r\n"
@@ -209,10 +280,12 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
209280
"User-Agent: %s\r\n"
210281
"Upgrade: websocket\r\n"
211282
"Sec-WebSocket-Version: 13\r\n"
212-
"Sec-WebSocket-Key: %s\r\n",
283+
"Sec-WebSocket-Key: %s\r\n"
284+
"%s", // For "Sec-WebSocket-Extensions"
213285
ws->path,
214286
host, port, user_agent_ptr,
215-
client_key);
287+
client_key,
288+
extension_header);
216289
if (len <= 0 || len >= WS_BUFFER_SIZE) {
217290
ESP_LOGE(TAG, "Error in request generation, desired request len: %d, buffer size: %d", len, WS_BUFFER_SIZE);
218291
return -1;
@@ -306,6 +379,9 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
306379
}
307380
header_cursor += strlen("\r\n");
308381

382+
// If compression was requested, we need to check server response
383+
bool pmd_negotiated = false;
384+
309385
while(header_cursor < delim_ptr){
310386
const char * end_of_line = strnstr(header_cursor, "\r\n", header_len - (header_cursor - ws->buffer));
311387
if(!end_of_line){
@@ -332,6 +408,53 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
332408
server_key = header_cursor + header_sec_websocket_accept_len;
333409
server_key_len = line_len - header_sec_websocket_accept_len;
334410
}
411+
// Check for Sec-WebSocket-Extensions header
412+
else if (ws->per_msg_compress && line_len >= strlen("Sec-WebSocket-Extensions: ") && !strncasecmp(header_cursor, "Sec-WebSocket-Extensions: ", strlen("Sec-WebSocket-Extensions: "))) {
413+
const char* ext_params = header_cursor + strlen("Sec-WebSocket-Extensions: ");
414+
int ext_params_len = line_len - strlen("Sec-WebSocket-Extensions: ");
415+
ESP_LOGD(TAG, "Found Sec-WebSocket-Extensions: %.*s", ext_params_len, ext_params);
416+
417+
if (strcasestr(ext_params, "permessage-deflate")) {
418+
pmd_negotiated = true;
419+
420+
// Server must agree to context takeover settings
421+
if (!strcasestr(ext_params, "server_no_context_takeover")) {
422+
ws->per_msg_server_no_ctx_takeover = false;
423+
}
424+
if (!strcasestr(ext_params, "client_no_context_takeover")) {
425+
ws->per_msg_client_no_ctx_takeover = false;
426+
}
427+
428+
const char *smwb_str = "server_max_window_bits=";
429+
const char *found = strcasestr(ext_params, smwb_str);
430+
if (found) {
431+
char *endptr;
432+
long smwb = strtol(found + strlen(smwb_str), &endptr, 10);
433+
if (smwb < 8 || smwb > 15) {
434+
ESP_LOGE(TAG, "compression: Server Max Window Bits is invalid: %ld", smwb);
435+
return -1;
436+
}
437+
438+
ws->per_msg_server_deflate_window_bit = (int)smwb;
439+
} else {
440+
ws->per_msg_server_deflate_window_bit = 15;
441+
}
442+
443+
const char *cmwb_str = "client_max_window_bits=";
444+
found = strcasestr(ext_params, cmwb_str);
445+
if (found) {
446+
char *endptr;
447+
long cmwb = strtol(found + strlen(cmwb_str), &endptr, 10);
448+
449+
if (cmwb < 8 || cmwb > 15) {
450+
ESP_LOGE(TAG, "compression: Client Max Window Bits is invalid: %ld", cmwb);
451+
return -1;
452+
}
453+
454+
ws->per_msg_client_deflate_window_bit = (int)cmwb;
455+
}
456+
}
457+
}
335458
else if (ws->header_hook) {
336459
ws->header_hook(ws->header_user_context, header_cursor, line_len);
337460
}
@@ -349,6 +472,10 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
349472
header_cursor += strlen("\r\n");
350473
}
351474

475+
if (ws->per_msg_compress && !pmd_negotiated) {
476+
ws->per_msg_compress = false;
477+
}
478+
352479
if (WS_HTTP_TEMPORARY_REDIRECT(ws->http_status_code) || WS_HTTP_PERMANENT_REDIRECT(ws->http_status_code)) {
353480
if (location == NULL || location_len <= 0) {
354481
ESP_LOGE(TAG, "Location header not found");
@@ -575,6 +702,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t
575702
ws->frame_state.header_received = true;
576703
ws->frame_state.fin = (*data_ptr & 0x80) != 0;
577704
ws->frame_state.opcode = (*data_ptr & 0x0F);
705+
ws->frame_state.compressed = (*data_ptr & 0x40) != 0; // RSV1 bit in the header
578706
data_ptr ++;
579707
mask = ((*data_ptr >> 7) & 0x01);
580708
payload_len = (*data_ptr & 0x7F);
@@ -979,14 +1107,65 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp
9791107
}
9801108

9811109
ws->propagate_control_frames = config->propagate_control_frames;
1110+
ws->per_msg_compress = config->per_msg_compress;
1111+
ws->per_msg_client_no_ctx_takeover = config->per_msg_client_no_ctx_takeover;
1112+
ws->per_msg_server_no_ctx_takeover = config->per_msg_server_no_ctx_takeover;
1113+
1114+
if (config->per_msg_client_deflate_window_bit < 8 || config->per_msg_client_deflate_window_bit > 15) {
1115+
ws->per_msg_client_deflate_window_bit = 0;
1116+
} else {
1117+
ws->per_msg_client_deflate_window_bit = config->per_msg_client_deflate_window_bit;
1118+
}
1119+
1120+
if (config->per_msg_server_deflate_window_bit < 8 || config->per_msg_server_deflate_window_bit > 15) {
1121+
ws->per_msg_server_deflate_window_bit = 0;
1122+
} else {
1123+
ws->per_msg_server_deflate_window_bit = config->per_msg_server_deflate_window_bit;
1124+
}
9821125

9831126
return err;
9841127
}
9851128

9861129
bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t)
9871130
{
988-
transport_ws_t *ws = esp_transport_get_context_data(t);
989-
return ws->frame_state.fin;
1131+
transport_ws_t *ws = esp_transport_get_context_data(t);
1132+
return ws->frame_state.fin;
1133+
}
1134+
1135+
bool esp_transport_ws_get_rsv1_flag(esp_transport_handle_t t)
1136+
{
1137+
transport_ws_t *ws = esp_transport_get_context_data(t);
1138+
return ws->frame_state.compressed;
1139+
}
1140+
1141+
bool esp_transport_ws_get_per_msg_compress(esp_transport_handle_t t)
1142+
{
1143+
transport_ws_t *ws = esp_transport_get_context_data(t);
1144+
return ws->per_msg_compress;
1145+
}
1146+
1147+
int esp_transport_ws_get_per_msg_client_deflate_window_bit(esp_transport_handle_t t)
1148+
{
1149+
transport_ws_t *ws = esp_transport_get_context_data(t);
1150+
return ws->per_msg_client_deflate_window_bit;
1151+
}
1152+
1153+
int esp_transport_ws_get_per_msg_server_deflate_window_bit(esp_transport_handle_t t)
1154+
{
1155+
transport_ws_t *ws = esp_transport_get_context_data(t);
1156+
return ws->per_msg_server_deflate_window_bit;
1157+
}
1158+
1159+
bool esp_transport_ws_get_per_msg_server_no_ctx_takeover(esp_transport_handle_t t)
1160+
{
1161+
transport_ws_t *ws = esp_transport_get_context_data(t);
1162+
return ws->per_msg_server_no_ctx_takeover && ws->per_msg_compress;
1163+
}
1164+
1165+
bool esp_transport_ws_get_per_msg_client_no_ctx_takeover(esp_transport_handle_t t)
1166+
{
1167+
transport_ws_t *ws = esp_transport_get_context_data(t);
1168+
return ws->per_msg_client_no_ctx_takeover && ws->per_msg_compress;
9901169
}
9911170

9921171
int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t)

0 commit comments

Comments
 (0)