diff --git a/src/inspector_socket.cc b/src/inspector_socket.cc index ad8b502da307c0..191ed5b0e5bd0d 100644 --- a/src/inspector_socket.cc +++ b/src/inspector_socket.cc @@ -8,8 +8,8 @@ #include "openssl/sha.h" // Sha-1 hash +#include #include -#include #define ACCEPT_KEY_LENGTH base64_encoded_size(20) #define BUFFER_GROWTH_CHUNK_SIZE 1024 @@ -63,7 +63,7 @@ class ProtocolHandler { virtual void Write(const std::vector data) = 0; virtual void CancelHandshake() = 0; - std::string GetHost(); + std::string GetHost() const; InspectorSocket* inspector() { return inspector_; @@ -160,6 +160,48 @@ static void generate_accept_string(const std::string& client_key, node::base64_encode(hash, sizeof(hash), *buffer, sizeof(*buffer)); } +static bool IsOneOf(const std::string& host, + const std::vector& hosts) { + for (const std::string& candidate : hosts) { + if (node::StringEqualNoCase(host.data(), candidate.data())) + return true; + } + return false; +} + +static std::string TrimPort(const std::string& host) { + size_t last_colon_pos = host.rfind(":"); + if (last_colon_pos == std::string::npos) + return host; + size_t bracket = host.rfind("]"); + if (bracket == std::string::npos || last_colon_pos > bracket) + return host.substr(0, last_colon_pos); + return host; +} + +static bool IsIPAddress(const std::string& host) { + if (host.length() >= 4 && host.front() == '[' && host.back() == ']') + return true; + int quads = 0; + for (char c : host) { + if (c == '.') + quads++; + else if (!isdigit(c)) + return false; + } + return quads == 3; +} + +// This is a value coming from the interface, it can only be IPv4 or IPv6 +// address string. +static bool IsIPv4Localhost(const std::string& host) { + std::string v6_tunnel_prefix = "::ffff:"; + if (host.substr(0, v6_tunnel_prefix.length()) == v6_tunnel_prefix) + return IsIPv4Localhost(host.substr(v6_tunnel_prefix.length())); + std::string localhost_net = "127."; + return host.substr(0, localhost_net.length()) == localhost_net; +} + // Constants for hybi-10 frame format. typedef int OpCode; @@ -298,7 +340,6 @@ static ws_decode_result decode_frame_hybi17(const std::vector& buffer, return closed ? FRAME_CLOSE : FRAME_OK; } - // WS protocol class WsHandler : public ProtocolHandler { public: @@ -400,17 +441,16 @@ class WsHandler : public ProtocolHandler { // HTTP protocol class HttpEvent { public: - HttpEvent(const std::string& path, bool upgrade, - bool isGET, const std::string& ws_key) : path(path), - upgrade(upgrade), - isGET(isGET), - ws_key(ws_key) { } + HttpEvent(const std::string& path, bool upgrade, bool isGET, + const std::string& ws_key, const std::string& host) + : path(path), upgrade(upgrade), isGET(isGET), ws_key(ws_key), + host(host) { } std::string path; bool upgrade; bool isGET; std::string ws_key; - std::string current_header_; + std::string host; }; class HttpHandler : public ProtocolHandler { @@ -472,18 +512,17 @@ class HttpHandler : public ProtocolHandler { std::vector events; std::swap(events, events_); for (const HttpEvent& event : events) { - bool shouldContinue = event.isGET && !event.upgrade; - if (!event.isGET) { + if (!IsAllowedHost(event.host) || !event.isGET) { CancelHandshake(); + return; } else if (!event.upgrade) { delegate()->OnHttpGet(event.path); } else if (event.ws_key.empty()) { CancelHandshake(); + return; } else { delegate()->OnSocketUpgrade(event.path, event.ws_key); } - if (!shouldContinue) - return; } } @@ -504,16 +543,9 @@ class HttpHandler : public ProtocolHandler { } static int OnHeaderValue(http_parser* parser, const char* at, size_t length) { - static const char SEC_WEBSOCKET_KEY_HEADER[] = "Sec-WebSocket-Key"; HttpHandler* handler = From(parser); handler->parsing_value_ = true; - if (handler->current_header_.size() == - sizeof(SEC_WEBSOCKET_KEY_HEADER) - 1 && - node::StringEqualNoCaseN(handler->current_header_.data(), - SEC_WEBSOCKET_KEY_HEADER, - sizeof(SEC_WEBSOCKET_KEY_HEADER) - 1)) { - handler->ws_key_.append(at, length); - } + handler->headers_[handler->current_header_].append(at, length); return 0; } @@ -540,23 +572,53 @@ class HttpHandler : public ProtocolHandler { static int OnMessageComplete(http_parser* parser) { // Event needs to be fired after the parser is done. HttpHandler* handler = From(parser); - handler->events_.push_back(HttpEvent(handler->path_, parser->upgrade, - parser->method == HTTP_GET, - handler->ws_key_)); + handler->events_.push_back( + HttpEvent(handler->path_, parser->upgrade, parser->method == HTTP_GET, + handler->HeaderValue("Sec-WebSocket-Key"), + handler->HeaderValue("Host"))); handler->path_ = ""; - handler->ws_key_ = ""; handler->parsing_value_ = false; + handler->headers_.clear(); handler->current_header_ = ""; - return 0; } + std::string HeaderValue(const std::string& header) const { + bool header_found = false; + std::string value; + for (const auto& header_value : headers_) { + if (node::StringEqualNoCaseN(header_value.first.data(), header.data(), + header.length())) { + if (header_found) + return ""; + value = header_value.second; + header_found = true; + } + } + return value; + } + + bool IsAllowedHost(const std::string& host_with_port) const { + std::string host = TrimPort(host_with_port); + if (host.empty()) + return false; + if (IsIPAddress(host)) + return true; + std::string socket_host = GetHost(); + if (IsIPv4Localhost(socket_host)) { + return IsOneOf(host, { "localhost" }); + } else if (socket_host == "::1") { + return IsOneOf(host, { "localhost", "localhost6" }); + } + return true; + } + bool parsing_value_; http_parser parser_; http_parser_settings parser_settings; std::vector events_; std::string current_header_; - std::string ws_key_; + std::map headers_; std::string path_; }; @@ -579,7 +641,7 @@ InspectorSocket::Delegate* ProtocolHandler::delegate() { return tcp_->delegate(); } -std::string ProtocolHandler::GetHost() { +std::string ProtocolHandler::GetHost() const { char ip[INET6_ADDRSTRLEN]; sockaddr_storage addr; int len = sizeof(addr); @@ -622,8 +684,6 @@ TcpHolder::Pointer TcpHolder::Accept( if (err == 0) { return { result, DisconnectAndDispose }; } else { - fprintf(stderr, "[%s:%d@%s]\n", __FILE__, __LINE__, __FUNCTION__); - delete result; return { nullptr, nullptr }; } diff --git a/test/cctest/test_inspector_socket.cc b/test/cctest/test_inspector_socket.cc index debbc957379375..ae6e1231c4c3a1 100644 --- a/test/cctest/test_inspector_socket.cc +++ b/test/cctest/test_inspector_socket.cc @@ -205,7 +205,7 @@ struct read_expects { }; static const char HANDSHAKE_REQ[] = "GET /ws/path HTTP/1.1\r\n" - "Host: localhost:9222\r\n" + "Host: localhost:9229\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key: aaa==\r\n" @@ -504,7 +504,7 @@ TEST_F(InspectorSocketTest, ExtraTextBeforeRequest) { TEST_F(InspectorSocketTest, RequestWithoutKey) { const char BROKEN_REQUEST[] = "GET / HTTP/1.1\r\n" - "Host: localhost:9222\r\n" + "Host: localhost:9229\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Version: 13\r\n\r\n"; @@ -619,24 +619,23 @@ TEST_F(InspectorSocketTest, ReportsHttpGet) { delegate->SetDelegate(ReportsHttpGet_handshake); const char GET_REQ[] = "GET /some/path HTTP/1.1\r\n" - "Host: localhost:9222\r\n" + "Host: localhost:9229\r\n" "Sec-WebSocket-Key: aaa==\r\n" "Sec-WebSocket-Version: 13\r\n\r\n"; send_in_chunks(GET_REQ, sizeof(GET_REQ) - 1); expect_nothing_on_client(); - const char WRITE_REQUEST[] = "GET /respond/withtext HTTP/1.1\r\n" - "Host: localhost:9222\r\n\r\n"; + "Host: localhost:9229\r\n\r\n"; send_in_chunks(WRITE_REQUEST, sizeof(WRITE_REQUEST) - 1); expect_on_client(TEST_SUCCESS, sizeof(TEST_SUCCESS) - 1); const char GET_REQS[] = "GET /some/path2 HTTP/1.1\r\n" - "Host: localhost:9222\r\n" + "Host: localhost:9229\r\n" "Sec-WebSocket-Key: aaa==\r\n" "Sec-WebSocket-Version: 13\r\n\r\n" "GET /close HTTP/1.1\r\n" - "Host: localhost:9222\r\n" + "Host: localhost:9229\r\n" "Sec-WebSocket-Key: aaa==\r\n" "Sec-WebSocket-Version: 13\r\n\r\n"; send_in_chunks(GET_REQS, sizeof(GET_REQS) - 1); @@ -696,7 +695,7 @@ static void GetThenHandshake_handshake(enum inspector_handshake_event state, TEST_F(InspectorSocketTest, GetThenHandshake) { delegate->SetDelegate(GetThenHandshake_handshake); const char WRITE_REQUEST[] = "GET /respond/withtext HTTP/1.1\r\n" - "Host: localhost:9222\r\n\r\n"; + "Host: localhost:9229\r\n\r\n"; send_in_chunks(WRITE_REQUEST, sizeof(WRITE_REQUEST) - 1); expect_on_client(TEST_SUCCESS, sizeof(TEST_SUCCESS) - 1); @@ -826,4 +825,36 @@ TEST_F(InspectorSocketTest, NoCloseResponseFromClient) { delegate->WaitForDispose(); } +static bool delegate_called = false; + +void shouldnt_be_called(enum inspector_handshake_event state, + const std::string& path, bool* cont) { + delegate_called = true; +} + +void expect_failure_no_delegate(const std::string& request) { + delegate->SetDelegate(shouldnt_be_called); + delegate_called = false; + send_in_chunks(request.c_str(), request.length()); + expect_handshake_failure(); + SPIN_WHILE(delegate != nullptr); + ASSERT_FALSE(delegate_called); +} + +TEST_F(InspectorSocketTest, HostCheckedForGET) { + const char GET_REQUEST[] = "GET /respond/withtext HTTP/1.1\r\n" + "Host: notlocalhost:9229\r\n\r\n"; + expect_failure_no_delegate(GET_REQUEST); +} + +TEST_F(InspectorSocketTest, HostCheckedForUPGRADE) { + const char UPGRADE_REQUEST[] = "GET /ws/path HTTP/1.1\r\n" + "Host: nonlocalhost:9229\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: aaa==\r\n" + "Sec-WebSocket-Version: 13\r\n\r\n"; + expect_failure_no_delegate(UPGRADE_REQUEST); +} + } // anonymous namespace