@@ -59,17 +59,18 @@ static size_t s_handler_initial_window_size(struct aws_channel_handler *handler)
5959static size_t s_handler_message_overhead (struct aws_channel_handler * handler );
6060static void s_handler_destroy (struct aws_channel_handler * handler );
6161static struct aws_http_stream * s_new_client_request_stream (const struct aws_http_request_options * options );
62+ static void s_connection_close (struct aws_http_connection * connection_base );
6263static void s_stream_destroy (struct aws_http_stream * stream_base );
6364static void s_stream_update_window (struct aws_http_stream * stream , size_t increment_size );
64- static void s_decoder_on_request (
65+ static int s_decoder_on_request (
6566 enum aws_http_method method_enum ,
6667 const struct aws_byte_cursor * method_str ,
6768 const struct aws_byte_cursor * uri ,
6869 void * user_data );
69- static void s_decoder_on_response (int status_code , void * user_data );
70- static bool s_decoder_on_header (const struct aws_http_decoded_header * header , void * user_data );
71- static bool s_decoder_on_body (const struct aws_byte_cursor * data , bool finished , void * user_data );
72- static void s_decoder_on_done (void * user_data );
70+ static int s_decoder_on_response (int status_code , void * user_data );
71+ static int s_decoder_on_header (const struct aws_http_decoded_header * header , void * user_data );
72+ static int s_decoder_on_body (const struct aws_byte_cursor * data , bool finished , void * user_data );
73+ static int s_decoder_on_done (void * user_data );
7374
7475static struct aws_http_connection_vtable s_connection_vtable = {
7576 .channel_handler_vtable =
@@ -84,6 +85,7 @@ static struct aws_http_connection_vtable s_connection_vtable = {
8485 },
8586
8687 .new_client_request_stream = s_new_client_request_stream ,
88+ .close = s_connection_close ,
8789};
8890
8991static const struct aws_http_stream_vtable s_stream_vtable = {
@@ -108,6 +110,9 @@ struct h1_connection {
108110 /* Single task used for issuing window updates from off-thread */
109111 struct aws_channel_task window_update_task ;
110112
113+ /* Task used once during shutdown. */
114+ struct aws_channel_task shutdown_delay_task ;
115+
111116 /* Only the event-loop thread may touch this data */
112117 struct {
113118 /* List of streams being worked on. */
@@ -183,7 +188,8 @@ struct h1_stream {
183188};
184189
185190/**
186- * Called when something goes wrong internally which should result in the channel shutting down.
191+ * Internal function for shutting down the connection.
192+ * If connection is already shutting down, this call has no effect.
187193 */
188194static void s_shutdown_connection (struct h1_connection * connection , int error_code ) {
189195 assert (aws_channel_thread_is_callers_thread (connection -> base .channel_slot -> channel ));
@@ -209,8 +215,36 @@ static void s_shutdown_connection(struct h1_connection *connection, int error_co
209215 connection -> thread_data .is_shutting_down = true;
210216 connection -> thread_data .shutdown_error_code = error_code ;
211217
218+ /* Delay the call to aws_channel_shutdown().
219+ * This ensures that a user calling aws_http_connection_close() won't have completion callbacks
220+ * firing before aws_http_connection_close() has even returned. */
221+ aws_channel_schedule_task_now (connection -> base .channel_slot -> channel , & connection -> shutdown_delay_task );
222+ }
223+ }
224+
225+ static void s_shutdown_delay_task (struct aws_channel_task * task , void * arg , enum aws_task_status status ) {
226+ (void )task ;
227+ struct h1_connection * connection = arg ;
228+
229+ if (status == AWS_TASK_STATUS_RUN_READY ) {
212230 /* If channel is already shutting down, this call has no effect */
213- aws_channel_shutdown (connection -> base .channel_slot -> channel , error_code );
231+ aws_channel_shutdown (connection -> base .channel_slot -> channel , connection -> thread_data .shutdown_error_code );
232+ }
233+ }
234+
235+ /**
236+ * Public function for closing connection.
237+ * If connection is already shutting down, this call has no effect.
238+ */
239+ static void s_connection_close (struct aws_http_connection * connection_base ) {
240+ struct h1_connection * connection = AWS_CONTAINER_OF (connection_base , struct h1_connection , base );
241+
242+ if (aws_channel_thread_is_callers_thread (connection_base -> channel_slot -> channel )) {
243+ /* Invoke internal function so connection ceases work immediately */
244+ s_shutdown_connection (connection , AWS_ERROR_SUCCESS );
245+ } else {
246+ /* Not on thread, so tell channel to shut down, which will result in connection shutting down. */
247+ aws_channel_shutdown (connection_base -> channel_slot -> channel , AWS_ERROR_SUCCESS );
214248 }
215249}
216250
@@ -841,7 +875,7 @@ static void s_outgoing_stream_task(struct aws_channel_task *task, void *arg, enu
841875 s_shutdown_connection (connection , aws_last_error ());
842876}
843877
844- static void s_decoder_on_request (
878+ static int s_decoder_on_request (
845879 enum aws_http_method method_enum ,
846880 const struct aws_byte_cursor * method_str ,
847881 const struct aws_byte_cursor * uri ,
@@ -886,17 +920,16 @@ static void s_decoder_on_request(
886920
887921 incoming_stream -> base .incoming_request_method = method_enum ;
888922
889- return ;
890- error :
891-
892- /* TODO: all decoder callbacks should be able to stop decoder, so we don't keep churning in the case of errors.
893- * There's some fishy stuff where callbacks assume current_incoming_stream is a valid ptr, but that's only the case
894- * while things are working */
923+ /* No user callbacks, so we're not checking for shutdown */
924+ return AWS_OP_SUCCESS ;
895925
896- s_shutdown_connection (connection , aws_last_error ());
926+ error :
927+ err = aws_last_error ();
928+ s_shutdown_connection (connection , err );
929+ return aws_raise_error (err );
897930}
898931
899- static void s_decoder_on_response (int status_code , void * user_data ) {
932+ static int s_decoder_on_response (int status_code , void * user_data ) {
900933 struct h1_connection * connection = user_data ;
901934
902935 AWS_LOGF_TRACE (
@@ -907,9 +940,12 @@ static void s_decoder_on_response(int status_code, void *user_data) {
907940 aws_http_status_text (status_code ));
908941
909942 connection -> thread_data .incoming_stream -> base .incoming_response_status = status_code ;
943+
944+ /* No user callbacks, so we're not checking for shutdown */
945+ return AWS_OP_SUCCESS ;
910946}
911947
912- static bool s_decoder_on_header (const struct aws_http_decoded_header * header , void * user_data ) {
948+ static int s_decoder_on_header (const struct aws_http_decoded_header * header , void * user_data ) {
913949 struct h1_connection * connection = user_data ;
914950 struct h1_stream * incoming_stream = connection -> thread_data .incoming_stream ;
915951
@@ -920,8 +956,6 @@ static bool s_decoder_on_header(const struct aws_http_decoded_header *header, vo
920956 AWS_BYTE_CURSOR_PRI (header -> name_data ),
921957 AWS_BYTE_CURSOR_PRI (header -> value_data ));
922958
923- /* TODO: worth buffering up headers and delivering all at once? In clumps? */
924-
925959 /* TODO? how to support trailing headers? distinct cb? invoke same cb again? */
926960
927961 if (incoming_stream -> base .on_incoming_headers ) {
@@ -933,13 +967,18 @@ static bool s_decoder_on_header(const struct aws_http_decoded_header *header, vo
933967 incoming_stream -> base .on_incoming_headers (& incoming_stream -> base , & deliver , 1 , incoming_stream -> base .user_data );
934968 }
935969
936- return true;
970+ /* Stop decoding if user callback shut down the connection. */
971+ if (connection -> thread_data .is_shutting_down ) {
972+ return aws_raise_error (AWS_ERROR_HTTP_CONNECTION_CLOSED );
973+ }
974+
975+ return AWS_OP_SUCCESS ;
937976}
938977
939- static void s_mark_head_done (struct h1_stream * incoming_stream ) {
978+ static int s_mark_head_done (struct h1_stream * incoming_stream ) {
940979 /* Bail out if we've already done this */
941980 if (incoming_stream -> is_incoming_head_done ) {
942- return ;
981+ return AWS_OP_SUCCESS ;
943982 }
944983
945984 incoming_stream -> is_incoming_head_done = true;
@@ -964,16 +1003,26 @@ static void s_mark_head_done(struct h1_stream *incoming_stream) {
9641003 incoming_stream -> base .on_incoming_header_block_done (
9651004 & incoming_stream -> base , has_incoming_body , incoming_stream -> base .user_data );
9661005 }
1006+
1007+ /* Stop decoding if user callback shut down the connection. */
1008+ if (connection -> thread_data .is_shutting_down ) {
1009+ return aws_raise_error (AWS_ERROR_HTTP_CONNECTION_CLOSED );
1010+ }
1011+
1012+ return AWS_OP_SUCCESS ;
9671013}
9681014
969- static bool s_decoder_on_body (const struct aws_byte_cursor * data , bool finished , void * user_data ) {
1015+ static int s_decoder_on_body (const struct aws_byte_cursor * data , bool finished , void * user_data ) {
9701016 (void )finished ;
9711017
9721018 struct h1_connection * connection = user_data ;
9731019 struct h1_stream * incoming_stream = connection -> thread_data .incoming_stream ;
9741020 assert (incoming_stream );
9751021
976- s_mark_head_done (incoming_stream );
1022+ int err = s_mark_head_done (incoming_stream );
1023+ if (err ) {
1024+ return AWS_OP_ERR ;
1025+ }
9771026
9781027 AWS_LOGF_TRACE (
9791028 AWS_LS_HTTP_STREAM , "id=%p: Incoming body: %zu bytes received." , (void * )& incoming_stream -> base , data -> len );
@@ -998,16 +1047,24 @@ static bool s_decoder_on_body(const struct aws_byte_cursor *data, bool finished,
9981047 }
9991048 }
10001049
1001- return true;
1050+ /* Stop decoding if user callback shut down the connection. */
1051+ if (connection -> thread_data .is_shutting_down ) {
1052+ return aws_raise_error (AWS_ERROR_HTTP_CONNECTION_CLOSED );
1053+ }
1054+
1055+ return AWS_OP_SUCCESS ;
10021056}
10031057
1004- static void s_decoder_on_done (void * user_data ) {
1058+ static int s_decoder_on_done (void * user_data ) {
10051059 struct h1_connection * connection = user_data ;
10061060 struct h1_stream * incoming_stream = connection -> thread_data .incoming_stream ;
10071061 assert (incoming_stream );
10081062
10091063 /* Ensure head was marked done */
1010- s_mark_head_done (incoming_stream );
1064+ int err = s_mark_head_done (incoming_stream );
1065+ if (err ) {
1066+ return AWS_OP_ERR ;
1067+ }
10111068
10121069 incoming_stream -> is_incoming_message_done = true;
10131070
@@ -1018,6 +1075,12 @@ static void s_decoder_on_done(void *user_data) {
10181075
10191076 s_update_incoming_stream_ptr (connection );
10201077 }
1078+
1079+ /* Report success even if user's on_complete() callback shuts down on the connection.
1080+ * We don't want it to look like something went wrong while decoding.
1081+ * The decode() function returns after each message completes,
1082+ * and we won't call decode() again if the connection has been shut down */
1083+ return AWS_OP_SUCCESS ;
10211084}
10221085
10231086/* Common new() logic for server & client */
@@ -1040,6 +1103,7 @@ static struct h1_connection *s_connection_new(struct aws_allocator *alloc) {
10401103
10411104 aws_channel_task_init (& connection -> outgoing_stream_task , s_outgoing_stream_task , connection );
10421105 aws_channel_task_init (& connection -> window_update_task , s_update_window_task , connection );
1106+ aws_channel_task_init (& connection -> shutdown_delay_task , s_shutdown_delay_task , connection );
10431107 aws_linked_list_init (& connection -> thread_data .stream_list );
10441108
10451109 int err = aws_mutex_init (& connection -> synced_data .lock );
0 commit comments