diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index 12c99bad3..3d85a75f6 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -480,12 +480,16 @@ parse_header(State=#state{opts=Opts, frag_state=FragState, extensions=Extensions websocket_close(State, HandlerState, {error, badframe}) end. -parse_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, extensions=Extensions}, +parse_payload(State=#state{opts=Opts, frag_state=FragState, utf8_state=Incomplete, extensions=Extensions}, HandlerState, ParseState=#ps_payload{ type=Type, len=Len, mask_key=MaskKey, rsv=Rsv, unmasked=Unmasked, unmasked_len=UnmaskedLen}, Data) -> + MaxFrameSize = case maps:get(max_frame_size, Opts, infinity) of + infinity -> infinity; + MaxFrameSize0 -> MaxFrameSize0 - UnmaskedLen + end, case cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen, - Type, Len, FragState, Extensions, Rsv) of + Type, Len, FragState, Extensions#{max_inflate_size => MaxFrameSize}, Rsv) of {ok, CloseCode, Payload, Utf8State, Rest} -> dispatch_frame(State#state{utf8_state=Utf8State}, HandlerState, ParseState#ps_payload{unmasked= <>, diff --git a/test/ws_SUITE.erl b/test/ws_SUITE.erl index 3b7433939..9c1f88082 100644 --- a/test/ws_SUITE.erl +++ b/test/ws_SUITE.erl @@ -203,6 +203,25 @@ do_ws_version(Socket) -> {error, closed} = gen_tcp:recv(Socket, 0, 6000), ok. +ws_deflate_max_frame_size_close(Config) -> + doc("Server closes connection when decompressed frame size exceeds max_frame_size option"), + %% max_frame_size is set to 8 bytes in ws_max_frame_size. + {ok, Socket, Headers} = do_handshake("/ws_max_frame_size", + "Sec-WebSocket-Extensions: permessage-deflate\r\n", Config), + {_, "permessage-deflate"} = lists:keyfind("sec-websocket-extensions", 1, Headers), + Mask = 16#11223344, + Z = zlib:open(), + zlib:deflateInit(Z, best_compression, deflated, -15, 8, default), + CompressedData0 = iolist_to_binary(zlib:deflate(Z, <<0:800>>, sync)), + CompressedData = binary:part(CompressedData0, 0, byte_size(CompressedData0) - 4), + MaskedData = do_mask(CompressedData, Mask, <<>>), + Len = byte_size(MaskedData), + true = Len < 8, + ok = gen_tcp:send(Socket, << 1:1, 1:1, 0:2, 1:4, 1:1, Len:7, Mask:32, MaskedData/binary >>), + {ok, << 1:1, 0:3, 8:4, 0:1, 2:7, 1009:16 >>} = gen_tcp:recv(Socket, 0, 6000), + {error, closed} = gen_tcp:recv(Socket, 0, 6000), + ok. + ws_deflate_opts_client_context_takeover(Config) -> doc("Handler is configured with client context takeover enabled."), {ok, _, Headers1} = do_handshake("/ws_deflate_opts?client_context_takeover", diff --git a/test/ws_SUITE_data/ws_max_frame_size.erl b/test/ws_SUITE_data/ws_max_frame_size.erl index 3d8149717..76df0b066 100644 --- a/test/ws_SUITE_data/ws_max_frame_size.erl +++ b/test/ws_SUITE_data/ws_max_frame_size.erl @@ -5,7 +5,7 @@ -export([websocket_info/2]). init(Req, State) -> - {cowboy_websocket, Req, State, #{max_frame_size => 8}}. + {cowboy_websocket, Req, State, #{max_frame_size => 8, compress => true}}. websocket_handle({text, Data}, State) -> {[{text, Data}], State};