Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions awscrt/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,13 @@ class OnConnectionSetupData:
This is None if the connection failed before receiving an HTTP response.
"""

# TODO: hook this up in C
# handshake_response_body: bytes = None
# """The HTTP response body, if you're interested.
handshake_response_body: bytes = None
"""The HTTP response body, if you're interested.

# This is only present if the server sent a full HTTP response rejecting the handshake.
# It is not present if the connection succeeded,
# or the connection failed for other reasons.
# """
This is only present if the server sent a full HTTP response rejecting the handshake.
It is not present if the connection succeeded,
or the connection failed for other reasons.
"""


@dataclass
Expand Down Expand Up @@ -442,21 +441,18 @@ def _on_connection_setup(
error_code,
websocket_binding,
handshake_response_status,
handshake_response_headers):
handshake_response_headers,
handshake_response_body):

cbdata = OnConnectionSetupData()
if error_code:
cbdata.exception = awscrt.exceptions.from_code(error_code)
else:
cbdata.websocket = WebSocket(websocket_binding)

if handshake_response_status != -1:
cbdata.handshake_response_status = handshake_response_status

if handshake_response_headers is not None:
cbdata.handshake_response_headers = handshake_response_headers

# TODO: get C to pass handshake_response_body
cbdata.handshake_response_status = handshake_response_status
cbdata.handshake_response_headers = handshake_response_headers
cbdata.handshake_response_body = handshake_response_body

# Do not let exceptions from the user's callback bubble up any further.
try:
Expand Down
2 changes: 1 addition & 1 deletion crt/aws-c-mqtt
59 changes: 36 additions & 23 deletions source/websocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@
static const char *s_websocket_capsule_name = "aws_websocket";

static void s_websocket_on_connection_setup(
struct aws_websocket *websocket,
int error_code,
int handshake_response_status,
const struct aws_http_header *handshake_response_header_array,
size_t num_handshake_response_headers,
const struct aws_websocket_on_connection_setup_data *setup,
void *user_data);

static void s_websocket_on_connection_shutdown(struct aws_websocket *websocket, int error_code, void *user_data);
Expand Down Expand Up @@ -177,15 +173,11 @@ PyObject *aws_py_websocket_client_connect(PyObject *self, PyObject *args) {
* that we can't actually check (and so may not actually work).
*/
static void s_websocket_on_connection_setup(
struct aws_websocket *websocket,
int error_code,
int handshake_response_status,
const struct aws_http_header *handshake_response_header_array,
size_t num_handshake_response_headers,
const struct aws_websocket_on_connection_setup_data *setup,
void *user_data) {

/* sanity check: websocket XOR error_code is set. both cannot be set. both cannot be unset */
AWS_FATAL_ASSERT((websocket != NULL) ^ (error_code != 0));
AWS_FATAL_ASSERT((setup->websocket != NULL) ^ (setup->error_code != 0));

/* userdata is _WebSocketCore */
PyObject *websocket_core_py = user_data;
Expand All @@ -194,17 +186,26 @@ static void s_websocket_on_connection_setup(
PyGILState_STATE state = PyGILState_Ensure();

PyObject *websocket_binding_py = NULL;
if (websocket) {
websocket_binding_py = PyCapsule_New(websocket, s_websocket_capsule_name, s_websocket_capsule_destructor);
if (setup->websocket) {
websocket_binding_py =
PyCapsule_New(setup->websocket, s_websocket_capsule_name, s_websocket_capsule_destructor);
AWS_FATAL_ASSERT(websocket_binding_py && "capsule allocation failed");
}

/* Any of the handshake_response variables could be NULL */

PyObject *status_code_py = NULL;
if (setup->handshake_response_status != NULL) {
status_code_py = PyLong_FromLong(*setup->handshake_response_status);
AWS_FATAL_ASSERT(status_code_py && "status code allocation failed");
}

PyObject *headers_py = NULL;
if (num_handshake_response_headers > 0) {
headers_py = PyList_New((Py_ssize_t)num_handshake_response_headers);
if (setup->handshake_response_header_array != NULL) {
headers_py = PyList_New((Py_ssize_t)setup->num_handshake_response_headers);
AWS_FATAL_ASSERT(headers_py && "header list allocation failed");
for (size_t i = 0; i < num_handshake_response_headers; ++i) {
const struct aws_http_header *header_i = &handshake_response_header_array[i];
for (size_t i = 0; i < setup->num_handshake_response_headers; ++i) {
const struct aws_http_header *header_i = &setup->handshake_response_header_array[i];
PyObject *tuple_py = PyTuple_New(2);
AWS_FATAL_ASSERT(tuple_py && "header tuple allocation failed");

Expand All @@ -220,14 +221,24 @@ static void s_websocket_on_connection_setup(
}
}

PyObject *body_py = NULL;
if (setup->handshake_response_body != NULL) {
/* AWS APIs are fine with NULL as the address of a 0-length array,
* but python APIs requires that it be non-NULL */
const char *ptr = setup->handshake_response_body->ptr ? (const char *)setup->handshake_response_body->ptr : "";
body_py = PyBytes_FromStringAndSize(ptr, (Py_ssize_t)setup->handshake_response_body->len);
AWS_FATAL_ASSERT(body_py && "response body allocation failed");
}

PyObject *result = PyObject_CallMethod(
websocket_core_py,
"_on_connection_setup",
"(iOiO)",
error_code,
websocket_binding_py ? websocket_binding_py : Py_None,
handshake_response_status,
headers_py ? headers_py : Py_None);
"(iOOOO)",
/* i */ setup->error_code,
/* O */ websocket_binding_py ? websocket_binding_py : Py_None,
/* O */ status_code_py ? status_code_py : Py_None,
/* O */ headers_py ? headers_py : Py_None,
/* O */ body_py ? body_py : Py_None);

if (result) {
Py_DECREF(result);
Expand All @@ -240,10 +251,12 @@ static void s_websocket_on_connection_setup(
}

Py_XDECREF(websocket_binding_py);
Py_XDECREF(status_code_py);
Py_XDECREF(headers_py);
Py_XDECREF(body_py);

/* If setup failed, there will be no further callbacks, so release _WebSocketCore */
if (error_code != 0) {
if (setup->error_code != 0) {
Py_DECREF(websocket_core_py);
}

Expand Down
6 changes: 6 additions & 0 deletions test/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def test_connect(self):
self.assertEqual(101, setup_data.handshake_response_status)
# check for response header we know should be there
self.assertIn(("Upgrade", "websocket"), setup_data.handshake_response_headers)
# a successful handshake response has no body
self.assertIsNone(setup_data.handshake_response_body)

# now close the WebSocket
setup_data.websocket.close()
Expand Down Expand Up @@ -287,6 +289,7 @@ def test_connect_failure_without_response(self):
# nothing responded, so there should be no "handshake response"
self.assertIsNone(setup_data.handshake_response_status)
self.assertIsNone(setup_data.handshake_response_headers)
self.assertIsNone(setup_data.handshake_response_body)

# ensure that on_connection_shutdown does NOT fire
sleep(0.5)
Expand Down Expand Up @@ -317,6 +320,9 @@ def test_connect_failure_with_response(self):
# check the HTTP response data
self.assertGreaterEqual(setup_data.handshake_response_status, 400)
self.assertIsNotNone(setup_data.handshake_response_headers)
self.assertIsNotNone(setup_data.handshake_response_body)
# check that body is a valid string
self.assertGreater(len(setup_data.handshake_response_body.decode()), 0)

def test_exception_in_setup_callback_closes_websocket(self):
with WebSocketServer(self.host, self.port) as server:
Expand Down