From 276542128ed32fa8168b2d6ae30d0e34b412d63e Mon Sep 17 00:00:00 2001 From: Alyssa Wilk Date: Thu, 4 Nov 2021 08:44:35 -0400 Subject: [PATCH] http: making sure upstream ALPN is consistently accessible. Signed-off-by: Alyssa Wilk --- envoy/ssl/connection.h | 5 ++ envoy/stream_info/stream_info.h | 14 +---- .../common/quic/envoy_quic_client_session.cc | 7 +-- .../common/quic/envoy_quic_server_session.cc | 4 +- .../quic_filter_manager_connection_impl.cc | 11 ++-- .../quic_filter_manager_connection_impl.h | 60 ++++++++++++++++++- source/common/router/router.cc | 11 ++-- source/common/router/upstream_request.cc | 1 - source/common/stream_info/stream_info_impl.h | 5 -- .../tls/connection_info_impl_base.cc | 12 ++++ .../tls/connection_info_impl_base.h | 2 + .../stream_info/stream_info_impl_test.cc | 3 - test/common/stream_info/test_util.h | 3 - test/integration/BUILD | 1 + test/integration/filters/BUILD | 15 +++++ .../filters/stream_info_to_headers_filter.cc | 36 +++++++++++ .../multiplexed_upstream_integration_test.cc | 9 +++ test/mocks/ssl/mocks.h | 1 + test/mocks/stream_info/mocks.h | 2 - 19 files changed, 156 insertions(+), 46 deletions(-) create mode 100644 test/integration/filters/stream_info_to_headers_filter.cc diff --git a/envoy/ssl/connection.h b/envoy/ssl/connection.h index c0bca9c6ab32..501a99b80f28 100644 --- a/envoy/ssl/connection.h +++ b/envoy/ssl/connection.h @@ -137,6 +137,11 @@ class ConnectionInfo { * connection. **/ virtual const std::string& tlsVersion() const PURE; + + /** + * @return std::string the protocol negotiated via ALPN. + **/ + virtual const std::string& alpn() const PURE; }; using ConnectionInfoConstSharedPtr = std::shared_ptr; diff --git a/envoy/stream_info/stream_info.h b/envoy/stream_info/stream_info.h index 6c287cfd83e8..5c4c2b473b61 100644 --- a/envoy/stream_info/stream_info.h +++ b/envoy/stream_info/stream_info.h @@ -322,25 +322,15 @@ class StreamInfo { virtual uint64_t bytesReceived() const PURE; /** - * @return the protocol of the downstream stream. + * @return the protocol of the request. */ virtual absl::optional protocol() const PURE; /** - * @param protocol the downstream stream's protocol. + * @param protocol the request's protocol. */ virtual void protocol(Http::Protocol protocol) PURE; - /** - * @return the protocol of the upstream stream. - */ - virtual absl::optional upstreamProtocol() const PURE; - - /** - * @param protocol the upstream stream's protocol. - */ - virtual void upstreamProtocol(Http::Protocol protocol) PURE; - /** * @return the response code. */ diff --git a/source/common/quic/envoy_quic_client_session.cc b/source/common/quic/envoy_quic_client_session.cc index 4cf14c31db3b..587cecc0719e 100644 --- a/source/common/quic/envoy_quic_client_session.cc +++ b/source/common/quic/envoy_quic_client_session.cc @@ -31,13 +31,12 @@ EnvoyQuicClientSession::EnvoyQuicClientSession( uint32_t send_buffer_limit, EnvoyQuicCryptoClientStreamFactoryInterface& crypto_stream_factory, QuicStatNames& quic_stat_names, Stats::Scope& scope) : QuicFilterManagerConnectionImpl(*connection, connection->connection_id(), dispatcher, - send_buffer_limit), + send_buffer_limit, + std::make_shared(*this)), quic::QuicSpdyClientSession(config, supported_versions, connection.release(), server_id, crypto_config.get(), push_promise_index), crypto_config_(crypto_config), crypto_stream_factory_(crypto_stream_factory), - quic_stat_names_(quic_stat_names), scope_(scope) { - quic_ssl_info_ = std::make_shared(*this); -} + quic_stat_names_(quic_stat_names), scope_(scope) {} EnvoyQuicClientSession::~EnvoyQuicClientSession() { ASSERT(!connection()->connected()); diff --git a/source/common/quic/envoy_quic_server_session.cc b/source/common/quic/envoy_quic_server_session.cc index 869a92a82e4b..2f83f9b3201b 100644 --- a/source/common/quic/envoy_quic_server_session.cc +++ b/source/common/quic/envoy_quic_server_session.cc @@ -22,10 +22,10 @@ EnvoyQuicServerSession::EnvoyQuicServerSession( : quic::QuicServerSessionBase(config, supported_versions, connection.get(), visitor, helper, crypto_config, compressed_certs_cache), QuicFilterManagerConnectionImpl(*connection, connection->connection_id(), dispatcher, - send_buffer_limit), + send_buffer_limit, + std::make_shared(*this)), quic_connection_(std::move(connection)), quic_stat_names_(quic_stat_names), listener_scope_(listener_scope), crypto_server_stream_factory_(crypto_server_stream_factory) { - quic_ssl_info_ = std::make_shared(*this); } EnvoyQuicServerSession::~EnvoyQuicServerSession() { diff --git a/source/common/quic/quic_filter_manager_connection_impl.cc b/source/common/quic/quic_filter_manager_connection_impl.cc index 168ddcd6d584..cb011217222f 100644 --- a/source/common/quic/quic_filter_manager_connection_impl.cc +++ b/source/common/quic/quic_filter_manager_connection_impl.cc @@ -10,15 +10,18 @@ namespace Quic { QuicFilterManagerConnectionImpl::QuicFilterManagerConnectionImpl( QuicNetworkConnection& connection, const quic::QuicConnectionId& connection_id, - Event::Dispatcher& dispatcher, uint32_t send_buffer_limit) + Event::Dispatcher& dispatcher, uint32_t send_buffer_limit, + std::shared_ptr info) // Using this for purpose other than logging is not safe. Because QUIC connection id can be // 18 bytes, so there might be collision when it's hashed to 8 bytes. : Network::ConnectionImplBase(dispatcher, /*id=*/connection_id.Hash()), - network_connection_(&connection), + network_connection_(&connection), quic_ssl_info_(info), filter_manager_( std::make_unique(*this, *connection.connectionSocket())), - stream_info_(dispatcher.timeSource(), - connection.connectionSocket()->connectionInfoProviderSharedPtr()), + info_provider_(std::make_shared( + network_connection_->connectionSocket()->connectionInfoProvider(), + network_connection_->connectionSocket()->connectionInfoProviderSharedPtr(), ssl())), + stream_info_(dispatcher.timeSource(), info_provider_), write_buffer_watermark_simulation_( send_buffer_limit / 2, send_buffer_limit, [this]() { onSendBufferLowWatermark(); }, [this]() { onSendBufferHighWatermark(); }, ENVOY_LOGGER()) { diff --git a/source/common/quic/quic_filter_manager_connection_impl.h b/source/common/quic/quic_filter_manager_connection_impl.h index 7b2138d5e1ee..dfcd7acf0d05 100644 --- a/source/common/quic/quic_filter_manager_connection_impl.h +++ b/source/common/quic/quic_filter_manager_connection_impl.h @@ -31,7 +31,8 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase, public: QuicFilterManagerConnectionImpl(QuicNetworkConnection& connection, const quic::QuicConnectionId& connection_id, - Event::Dispatcher& dispatcher, uint32_t send_buffer_limit); + Event::Dispatcher& dispatcher, uint32_t send_buffer_limit, + std::shared_ptr info); // Network::FilterManager // Overridden to delegate calls to filter_manager_. void addWriteFilter(Network::WriteFilterSharedPtr filter) override; @@ -59,11 +60,63 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase, void readDisable(bool /*disable*/) override { ASSERT(false); } void detectEarlyCloseWhenReadDisabled(bool /*value*/) override { ASSERT(false); } bool readEnabled() const override { return true; } + + // TODO(alyssawilk, danzh), sort out which of these need to be handled + // locally. + class ConnectionInfoProviderShim : public Network::ConnectionInfoSetter { + public: + ConnectionInfoProviderShim(Network::ConnectionInfoSetter& setter, + Network::ConnectionInfoProviderSharedPtr conn_info, + Ssl::ConnectionInfoConstSharedPtr ssl) + : setter_(setter), conn_info_(conn_info), ssl_(ssl) {} + + const Network::Address::InstanceConstSharedPtr& localAddress() const override { + return conn_info_->localAddress(); + } + bool localAddressRestored() const override { return conn_info_->localAddressRestored(); } + const Network::Address::InstanceConstSharedPtr& remoteAddress() const override { + return conn_info_->remoteAddress(); + } + const Network::Address::InstanceConstSharedPtr& directRemoteAddress() const override { + return conn_info_->directRemoteAddress(); + } + absl::string_view requestedServerName() const override { + return conn_info_->requestedServerName(); + } + absl::optional connectionID() const override { return conn_info_->connectionID(); } + void dumpState(std::ostream& os, int indent_level) const override { + conn_info_->dumpState(os, indent_level); + } + Ssl::ConnectionInfoConstSharedPtr sslConnection() const override { return ssl_; } + + void setLocalAddress(const Network::Address::InstanceConstSharedPtr& local_address) override { + setter_.setLocalAddress(local_address); + } + void + restoreLocalAddress(const Network::Address::InstanceConstSharedPtr& local_address) override { + setter_.restoreLocalAddress(local_address); + } + void setRemoteAddress(const Network::Address::InstanceConstSharedPtr& remote_address) override { + setter_.setRemoteAddress(remote_address); + } + void setRequestedServerName(const absl::string_view requested_server_name) override { + setter_.setRequestedServerName(requested_server_name); + } + void setConnectionID(uint64_t id) override { setter_.setConnectionID(id); } + void setSslConnection(const Ssl::ConnectionInfoConstSharedPtr& ssl_connection_info) override { + ssl_ = ssl_connection_info; + } + + Network::ConnectionInfoSetter& setter_; + Network::ConnectionInfoProviderSharedPtr conn_info_; + Ssl::ConnectionInfoConstSharedPtr ssl_; + }; + const Network::ConnectionInfoSetter& connectionInfoProvider() const override { - return network_connection_->connectionSocket()->connectionInfoProvider(); + return *info_provider_; } Network::ConnectionInfoProviderSharedPtr connectionInfoProviderSharedPtr() const override { - return network_connection_->connectionSocket()->connectionInfoProviderSharedPtr(); + return info_provider_; } absl::optional unixSocketPeerCredentials() const override { @@ -177,6 +230,7 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase, // and the rest incoming data bypasses these filters. std::unique_ptr filter_manager_; + std::shared_ptr info_provider_; StreamInfo::StreamInfoImpl stream_info_; std::string transport_failure_reason_; uint32_t bytes_to_send_{0}; diff --git a/source/common/router/router.cc b/source/common/router/router.cc index d9fe690928ef..054b6a6858ec 100644 --- a/source/common/router/router.cc +++ b/source/common/router/router.cc @@ -1406,17 +1406,14 @@ void Filter::onUpstreamHeaders(uint64_t response_code, Http::ResponseHeaderMapPt downstream_response_started_ = true; final_upstream_request_ = &upstream_request; - // In upstream request hedging scenarios properties set in onPoolReady might not be the same - // as the properties set on the final tream, thus we reset fields based on the - // final upstream request. + // In upstream request hedging scenarios the upstream connection ID set in onPoolReady might not + // be the connection ID of the upstream connection that ended up receiving upstream headers. Thus + // reset the upstream connection ID here with the ID of the connection that ultimately was the + // transport for the final upstream request. if (final_upstream_request_->streamInfo().upstreamConnectionId().has_value()) { callbacks_->streamInfo().setUpstreamConnectionId( final_upstream_request_->streamInfo().upstreamConnectionId().value()); } - if (final_upstream_request_->streamInfo().protocol().has_value()) { - callbacks_->streamInfo().upstreamProtocol( - final_upstream_request_->streamInfo().protocol().value()); - } resetOtherUpstreams(upstream_request); if (end_stream) { onUpstreamComplete(upstream_request); diff --git a/source/common/router/upstream_request.cc b/source/common/router/upstream_request.cc index e6dd11faff64..34ccc3f38c5f 100644 --- a/source/common/router/upstream_request.cc +++ b/source/common/router/upstream_request.cc @@ -425,7 +425,6 @@ void UpstreamRequest::onPoolReady( if (protocol) { stream_info_.protocol(protocol.value()); - stream_info_.upstreamProtocol(protocol.value()); } stream_info_.setUpstreamFilterState(std::make_shared( diff --git a/source/common/stream_info/stream_info_impl.h b/source/common/stream_info/stream_info_impl.h index 34323cfb2eed..403ab403154b 100644 --- a/source/common/stream_info/stream_info_impl.h +++ b/source/common/stream_info/stream_info_impl.h @@ -139,10 +139,6 @@ struct StreamInfoImpl : public StreamInfo { void protocol(Http::Protocol protocol) override { protocol_ = protocol; } - absl::optional upstreamProtocol() const override { return upstream_protocol_; } - - void upstreamProtocol(Http::Protocol protocol) override { upstream_protocol_ = protocol; } - absl::optional responseCode() const override { return response_code_; } const absl::optional& responseCodeDetails() const override { @@ -323,7 +319,6 @@ struct StreamInfoImpl : public StreamInfo { absl::optional final_time_; absl::optional protocol_; - absl::optional upstream_protocol_; absl::optional response_code_; absl::optional response_code_details_; absl::optional connection_termination_details_; diff --git a/source/extensions/transport_sockets/tls/connection_info_impl_base.cc b/source/extensions/transport_sockets/tls/connection_info_impl_base.cc index de692e42fff4..3aa9974ff8b9 100644 --- a/source/extensions/transport_sockets/tls/connection_info_impl_base.cc +++ b/source/extensions/transport_sockets/tls/connection_info_impl_base.cc @@ -190,6 +190,18 @@ const std::string& ConnectionInfoImplBase::tlsVersion() const { return cached_tls_version_; } +const std::string& ConnectionInfoImplBase::alpn() const { + if (alpn_.empty()) { + const unsigned char* proto; + unsigned int proto_len; + SSL_get0_alpn_selected(ssl(), &proto, &proto_len); + if (proto != nullptr) { + alpn_ = std::string(reinterpret_cast(proto), proto_len); + } + } + return alpn_; +} + const std::string& ConnectionInfoImplBase::serialNumberPeerCertificate() const { if (!cached_serial_number_peer_certificate_.empty()) { return cached_serial_number_peer_certificate_; diff --git a/source/extensions/transport_sockets/tls/connection_info_impl_base.h b/source/extensions/transport_sockets/tls/connection_info_impl_base.h index b591b4733f10..f5bfa73b0ee1 100644 --- a/source/extensions/transport_sockets/tls/connection_info_impl_base.h +++ b/source/extensions/transport_sockets/tls/connection_info_impl_base.h @@ -38,6 +38,7 @@ class ConnectionInfoImplBase : public Ssl::ConnectionInfo { uint16_t ciphersuiteId() const override; std::string ciphersuiteString() const override; const std::string& tlsVersion() const override; + const std::string& alpn() const override; virtual SSL* ssl() const PURE; @@ -56,6 +57,7 @@ class ConnectionInfoImplBase : public Ssl::ConnectionInfo { mutable std::vector cached_dns_san_local_certificate_; mutable std::string cached_session_id_; mutable std::string cached_tls_version_; + mutable std::string alpn_; }; } // namespace Tls diff --git a/test/common/stream_info/stream_info_impl_test.cc b/test/common/stream_info/stream_info_impl_test.cc index d329599b7176..cfe761f58f3a 100644 --- a/test/common/stream_info/stream_info_impl_test.cc +++ b/test/common/stream_info/stream_info_impl_test.cc @@ -142,9 +142,6 @@ TEST_F(StreamInfoImplTest, MiscSettersAndGetters) { stream_info.protocol(Http::Protocol::Http10); EXPECT_EQ(Http::Protocol::Http10, stream_info.protocol().value()); - stream_info.upstreamProtocol(Http::Protocol::Http11); - EXPECT_EQ(Http::Protocol::Http11, stream_info.upstreamProtocol().value()); - EXPECT_FALSE(stream_info.responseCode()); stream_info.response_code_ = 200; ASSERT_TRUE(stream_info.responseCode()); diff --git a/test/common/stream_info/test_util.h b/test/common/stream_info/test_util.h index 5f8ba0a9a94c..2cebca68cd1a 100644 --- a/test/common/stream_info/test_util.h +++ b/test/common/stream_info/test_util.h @@ -35,8 +35,6 @@ class TestStreamInfo : public StreamInfo::StreamInfo { uint64_t bytesReceived() const override { return 1; } absl::optional protocol() const override { return protocol_; } void protocol(Http::Protocol protocol) override { protocol_ = protocol; } - absl::optional upstreamProtocol() const override { return upstream_protocol_; } - void upstreamProtocol(Http::Protocol protocol) override { upstream_protocol_ = protocol; } absl::optional responseCode() const override { return response_code_; } const absl::optional& responseCodeDetails() const override { return response_code_details_; @@ -253,7 +251,6 @@ class TestStreamInfo : public StreamInfo::StreamInfo { absl::optional end_time_; absl::optional protocol_{Http::Protocol::Http11}; - absl::optional upstream_protocol_{Http::Protocol::Http11}; absl::optional response_code_; absl::optional response_code_details_; absl::optional connection_termination_details_; diff --git a/test/integration/BUILD b/test/integration/BUILD index 7861e4f63a6c..91b850de5bff 100644 --- a/test/integration/BUILD +++ b/test/integration/BUILD @@ -548,6 +548,7 @@ envoy_cc_test( "//source/extensions/filters/http/buffer:config", "//test/integration/filters:encoder_decoder_buffer_filter_lib", "//test/integration/filters:random_pause_filter_lib", + "//test/integration/filters:stream_info_to_headers_filter_lib", "//test/test_common:utility_lib", "@envoy_api//envoy/config/bootstrap/v3:pkg_cc_proto", "@envoy_api//envoy/extensions/filters/network/http_connection_manager/v3:pkg_cc_proto", diff --git a/test/integration/filters/BUILD b/test/integration/filters/BUILD index 95d9735eadbd..c63bdcb1b98a 100644 --- a/test/integration/filters/BUILD +++ b/test/integration/filters/BUILD @@ -634,3 +634,18 @@ envoy_cc_test_library( "//test/extensions/filters/http/common:empty_http_filter_config_lib", ], ) + +envoy_cc_test_library( + name = "stream_info_to_headers_filter_lib", + srcs = [ + "stream_info_to_headers_filter.cc", + ], + deps = [ + ":common_lib", + "//envoy/http:filter_interface", + "//envoy/registry", + "//envoy/server:filter_config_interface", + "//source/extensions/filters/http/common:pass_through_filter_lib", + "//test/extensions/filters/http/common:empty_http_filter_config_lib", + ], +) diff --git a/test/integration/filters/stream_info_to_headers_filter.cc b/test/integration/filters/stream_info_to_headers_filter.cc new file mode 100644 index 000000000000..30aedfe2b621 --- /dev/null +++ b/test/integration/filters/stream_info_to_headers_filter.cc @@ -0,0 +1,36 @@ +#include "envoy/registry/registry.h" +#include "envoy/server/filter_config.h" + +#include "source/extensions/filters/http/common/pass_through_filter.h" + +#include "test/extensions/filters/http/common/empty_http_filter_config.h" +#include "test/integration/filters/common.h" + +#include "gtest/gtest.h" + +namespace Envoy { + +// A filter that sticks stream info into headers for integration testing. +class StreamInfoToHeadersFilter : public Http::PassThroughFilter { +public: + constexpr static char name[] = "stream-info-to-headers-filter"; + + Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap&, bool) override { + return Http::FilterHeadersStatus::Continue; + } + + Http::FilterHeadersStatus encodeHeaders(Http::ResponseHeaderMap& headers, bool) override { + if (decoder_callbacks_->streamInfo().upstreamSslConnection()) { + headers.addCopy(Http::LowerCaseString("alpn"), + decoder_callbacks_->streamInfo().upstreamSslConnection()->alpn()); + } + return Http::FilterHeadersStatus::Continue; + } +}; + +constexpr char StreamInfoToHeadersFilter::name[]; +static Registry::RegisterFactory, + Server::Configuration::NamedHttpFilterConfigFactory> + register_; + +} // namespace Envoy diff --git a/test/integration/multiplexed_upstream_integration_test.cc b/test/integration/multiplexed_upstream_integration_test.cc index 4acf747d777c..92c8f427f535 100644 --- a/test/integration/multiplexed_upstream_integration_test.cc +++ b/test/integration/multiplexed_upstream_integration_test.cc @@ -92,6 +92,11 @@ TEST_P(Http2UpstreamIntegrationTest, TestSchemeAndXFP) { // Ensure Envoy handles streaming requests and responses simultaneously. void Http2UpstreamIntegrationTest::bidirectionalStreaming(uint32_t bytes) { + config_helper_.prependFilter(fmt::format(R"EOF( + name: stream-info-to-headers-filter + typed_config: + "@type": type.googleapis.com/google.protobuf.Empty)EOF")); + initialize(); codec_client_ = makeHttpConnection(lookupPort("http")); @@ -124,6 +129,10 @@ void Http2UpstreamIntegrationTest::bidirectionalStreaming(uint32_t bytes) { upstream_request_->encodeTrailers(Http::TestResponseTrailerMapImpl{{"trailer", "bar"}}); ASSERT_TRUE(response->waitForEndStream()); EXPECT_TRUE(response->complete()); + std::string expected_alpn = upstreamProtocol() == Http::CodecType::HTTP2 ? "h2" : "h3"; + ASSERT_FALSE(response->headers().get(Http::LowerCaseString("alpn")).empty()); + ASSERT_EQ(response->headers().get(Http::LowerCaseString("alpn"))[0]->value().getStringView(), + expected_alpn); } TEST_P(Http2UpstreamIntegrationTest, BidirectionalStreaming) { bidirectionalStreaming(1024); } diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index 43ad275fce27..519558f51898 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -62,6 +62,7 @@ class MockConnectionInfo : public ConnectionInfo { MOCK_METHOD(uint16_t, ciphersuiteId, (), (const)); MOCK_METHOD(std::string, ciphersuiteString, (), (const)); MOCK_METHOD(const std::string&, tlsVersion, (), (const)); + MOCK_METHOD(const std::string&, alpn, (), (const)); }; class MockClientContext : public ClientContext { diff --git a/test/mocks/stream_info/mocks.h b/test/mocks/stream_info/mocks.h index cd9a7b70d113..be4c47a77ead 100644 --- a/test/mocks/stream_info/mocks.h +++ b/test/mocks/stream_info/mocks.h @@ -54,8 +54,6 @@ class MockStreamInfo : public StreamInfo { MOCK_METHOD(const std::string&, getRouteName, (), (const)); MOCK_METHOD(absl::optional, protocol, (), (const)); MOCK_METHOD(void, protocol, (Http::Protocol protocol)); - MOCK_METHOD(absl::optional, upstreamProtocol, (), (const)); - MOCK_METHOD(void, upstreamProtocol, (Http::Protocol protocol)); MOCK_METHOD(absl::optional, responseCode, (), (const)); MOCK_METHOD(const absl::optional&, responseCodeDetails, (), (const)); MOCK_METHOD(const absl::optional&, connectionTerminationDetails, (), (const));