Skip to content

Commit

Permalink
tls: move ssl connection info into SocketAddressProvider (envoyproxy#…
Browse files Browse the repository at this point in the history
…17334)

Part of envoyproxy#17168

Signed-off-by: He Jie Xu <hejie.xu@intel.com>
  • Loading branch information
soulxu authored and Le Yao committed Sep 30, 2021
1 parent 7b7cde3 commit 18e7ce1
Show file tree
Hide file tree
Showing 37 changed files with 191 additions and 214 deletions.
1 change: 1 addition & 0 deletions envoy/network/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ envoy_cc_library(
deps = [
":address_interface",
":io_handle_interface",
"//envoy/ssl:connection_interface",
"@envoy_api//envoy/config/core/v3:pkg_cc_proto",
],
)
Expand Down
12 changes: 12 additions & 0 deletions envoy/network/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "envoy/config/core/v3/base.pb.h"
#include "envoy/network/address.h"
#include "envoy/network/io_handle.h"
#include "envoy/ssl/connection.h"

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
Expand Down Expand Up @@ -92,6 +93,12 @@ class SocketAddressProvider {
* @param indent_level the level of indentation.
*/
virtual void dumpState(std::ostream& os, int indent_level) const PURE;

/**
* @return the downstream SSL connection. This will be nullptr if the downstream
* connection does not use SSL.
*/
virtual Ssl::ConnectionInfoConstSharedPtr sslConnection() const PURE;
};

class SocketAddressSetter : public SocketAddressProvider {
Expand Down Expand Up @@ -131,6 +138,11 @@ class SocketAddressSetter : public SocketAddressProvider {
* @param id Connection ID of the downstream connection.
**/
virtual void setConnectionID(uint64_t id) PURE;

/**
* @param connection_info sets the downstream ssl connection.
*/
virtual void setSslConnection(const Ssl::ConnectionInfoConstSharedPtr& ssl_connection_info) PURE;
};

using SocketAddressSetterSharedPtr = std::shared_ptr<SocketAddressSetter>;
Expand Down
12 changes: 0 additions & 12 deletions envoy/stream_info/stream_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,18 +468,6 @@ class StreamInfo {
*/
virtual const Network::SocketAddressProvider& downstreamAddressProvider() const PURE;

/**
* @param connection_info sets the downstream ssl connection.
*/
virtual void
setDownstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& ssl_connection_info) PURE;

/**
* @return the downstream SSL connection. This will be nullptr if the downstream
* connection does not use SSL.
*/
virtual Ssl::ConnectionInfoConstSharedPtr downstreamSslConnection() const PURE;

/**
* @param connection_info sets the upstream ssl connection.
*/
Expand Down
14 changes: 8 additions & 6 deletions source/common/formatter/substitution_formatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -655,11 +655,11 @@ class StreamInfoSslConnectionInfoFieldExtractor : public StreamInfoFormatter::Fi
StreamInfoSslConnectionInfoFieldExtractor(FieldExtractor f) : field_extractor_(f) {}

absl::optional<std::string> extract(const StreamInfo::StreamInfo& stream_info) const override {
if (stream_info.downstreamSslConnection() == nullptr) {
if (stream_info.downstreamAddressProvider().sslConnection() == nullptr) {
return absl::nullopt;
}

const auto value = field_extractor_(*stream_info.downstreamSslConnection());
const auto value = field_extractor_(*stream_info.downstreamAddressProvider().sslConnection());
if (value && value->empty()) {
return absl::nullopt;
}
Expand All @@ -668,11 +668,11 @@ class StreamInfoSslConnectionInfoFieldExtractor : public StreamInfoFormatter::Fi
}

ProtobufWkt::Value extractValue(const StreamInfo::StreamInfo& stream_info) const override {
if (stream_info.downstreamSslConnection() == nullptr) {
if (stream_info.downstreamAddressProvider().sslConnection() == nullptr) {
return unspecifiedValue();
}

const auto value = field_extractor_(*stream_info.downstreamSslConnection());
const auto value = field_extractor_(*stream_info.downstreamAddressProvider().sslConnection());
if (value && value->empty()) {
return unspecifiedValue();
}
Expand Down Expand Up @@ -1335,7 +1335,8 @@ DownstreamPeerCertVStartFormatter::DownstreamPeerCertVStartFormatter(const std::
parseFormat(token, sizeof("DOWNSTREAM_PEER_CERT_V_START(") - 1),
std::make_unique<SystemTimeFormatter::TimeFieldExtractor>(
[](const StreamInfo::StreamInfo& stream_info) -> absl::optional<SystemTime> {
const auto connection_info = stream_info.downstreamSslConnection();
const auto connection_info =
stream_info.downstreamAddressProvider().sslConnection();
return connection_info != nullptr ? connection_info->validFromPeerCertificate()
: absl::optional<SystemTime>();
})) {}
Expand All @@ -1347,7 +1348,8 @@ DownstreamPeerCertVEndFormatter::DownstreamPeerCertVEndFormatter(const std::stri
parseFormat(token, sizeof("DOWNSTREAM_PEER_CERT_V_END(") - 1),
std::make_unique<SystemTimeFormatter::TimeFieldExtractor>(
[](const StreamInfo::StreamInfo& stream_info) -> absl::optional<SystemTime> {
const auto connection_info = stream_info.downstreamSslConnection();
const auto connection_info =
stream_info.downstreamAddressProvider().sslConnection();
return connection_info != nullptr ? connection_info->expirationPeerCertificate()
: absl::optional<SystemTime>();
})) {}
Expand Down
1 change: 0 additions & 1 deletion source/common/http/codec_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ RequestEncoder& CodecClient::newStream(ResponseDecoder& response_decoder) {
void CodecClient::onEvent(Network::ConnectionEvent event) {
if (event == Network::ConnectionEvent::Connected) {
ENVOY_CONN_LOG(debug, "connected", *connection_);
connection_->streamInfo().setDownstreamSslConnection(connection_->ssl());
connected_ = true;
}

Expand Down
3 changes: 0 additions & 3 deletions source/common/http/conn_manager_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -664,9 +664,6 @@ ConnectionManagerImpl::ActiveStream::ActiveStream(ConnectionManagerImpl& connect
connection_manager_.stats_.named_.downstream_rq_http1_total_.inc();
}

filter_manager_.streamInfo().setDownstreamSslConnection(
connection_manager_.read_callbacks_->connection().ssl());

if (connection_manager_.config_.streamIdleTimeout().count()) {
idle_timeout_ms_ = connection_manager_.config_.streamIdleTimeout();
stream_idle_timer_ =
Expand Down
6 changes: 6 additions & 0 deletions source/common/http/filter_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,12 @@ class OverridableRemoteSocketAddressSetterStreamInfo : public StreamInfo::Stream
absl::optional<uint64_t> connectionID() const override {
return StreamInfoImpl::downstreamAddressProvider().connectionID();
}
Ssl::ConnectionInfoConstSharedPtr sslConnection() const override {
return StreamInfoImpl::downstreamAddressProvider().sslConnection();
}
Ssl::ConnectionInfoConstSharedPtr upstreamSslConnection() const override {
return StreamInfoImpl::upstreamSslConnection();
}
void dumpState(std::ostream& os, int indent_level) const override {
StreamInfoImpl::dumpState(os, indent_level);

Expand Down
1 change: 1 addition & 0 deletions source/common/network/connection_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ ConnectionImpl::ConnectionImpl(Event::Dispatcher& dispatcher, ConnectionSocketPt
// TODO(soulxu): generate the connection id inside the addressProvider directly,
// then we don't need a setter or any of the optional stuff.
socket_->addressProvider().setConnectionID(id());
socket_->addressProvider().setSslConnection(transport_socket_->ssl());
}

ConnectionImpl::~ConnectionImpl() {
Expand Down
5 changes: 5 additions & 0 deletions source/common/network/socket_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class SocketAddressSetterImpl : public SocketAddressSetter {
}
absl::optional<uint64_t> connectionID() const override { return connection_id_; }
void setConnectionID(uint64_t id) override { connection_id_ = id; }
Ssl::ConnectionInfoConstSharedPtr sslConnection() const override { return ssl_info_; }
void setSslConnection(const Ssl::ConnectionInfoConstSharedPtr& ssl_connection_info) override {
ssl_info_ = ssl_connection_info;
}

private:
Address::InstanceConstSharedPtr local_address_;
Expand All @@ -59,6 +63,7 @@ class SocketAddressSetterImpl : public SocketAddressSetter {
Address::InstanceConstSharedPtr direct_remote_address_;
std::string server_name_;
absl::optional<uint64_t> connection_id_;
Ssl::ConnectionInfoConstSharedPtr ssl_info_;
};

class SocketImpl : public virtual Socket {
Expand Down
10 changes: 6 additions & 4 deletions source/common/router/config_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -502,14 +502,16 @@ bool RouteEntryImplBase::evaluateTlsContextMatch(const StreamInfo::StreamInfo& s
const TlsContextMatchCriteria& criteria = *tlsContextMatchCriteria();

if (criteria.presented().has_value()) {
const bool peer_presented = stream_info.downstreamSslConnection() &&
stream_info.downstreamSslConnection()->peerCertificatePresented();
const bool peer_presented =
stream_info.downstreamAddressProvider().sslConnection() &&
stream_info.downstreamAddressProvider().sslConnection()->peerCertificatePresented();
matches &= criteria.presented().value() == peer_presented;
}

if (criteria.validated().has_value()) {
const bool peer_validated = stream_info.downstreamSslConnection() &&
stream_info.downstreamSslConnection()->peerCertificateValidated();
const bool peer_validated =
stream_info.downstreamAddressProvider().sslConnection() &&
stream_info.downstreamAddressProvider().sslConnection()->peerCertificateValidated();
matches &= criteria.validated().value() == peer_validated;
}

Expand Down
4 changes: 2 additions & 2 deletions source/common/router/header_formatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,11 @@ parseRequestHeader(absl::string_view param) {
StreamInfoHeaderFormatter::FieldExtractor sslConnectionInfoStringHeaderExtractor(
std::function<std::string(const Ssl::ConnectionInfo& connection_info)> string_extractor) {
return [string_extractor](const StreamInfo::StreamInfo& stream_info) {
if (stream_info.downstreamSslConnection() == nullptr) {
if (stream_info.downstreamAddressProvider().sslConnection() == nullptr) {
return std::string();
}

return string_extractor(*stream_info.downstreamSslConnection());
return string_extractor(*stream_info.downstreamAddressProvider().sslConnection());
};
}

Expand Down
6 changes: 3 additions & 3 deletions source/common/router/router.cc
Original file line number Diff line number Diff line change
Expand Up @@ -608,9 +608,9 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers,

route_entry_->finalizeRequestHeaders(headers, callbacks_->streamInfo(),
!config_.suppress_envoy_headers_);
FilterUtility::setUpstreamScheme(headers,
callbacks_->streamInfo().downstreamSslConnection() != nullptr,
host->transportSocketFactory().implementsSecureTransport());
FilterUtility::setUpstreamScheme(
headers, callbacks_->streamInfo().downstreamAddressProvider().sslConnection() != nullptr,
host->transportSocketFactory().implementsSecureTransport());

// Ensure an http transport scheme is selected before continuing with decoding.
ASSERT(headers.Scheme());
Expand Down
5 changes: 3 additions & 2 deletions source/common/router/upstream_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,9 @@ void UpstreamRequest::onPoolReady(
stream_info_.setUpstreamLocalAddress(upstream_local_address);
parent_.callbacks()->streamInfo().setUpstreamLocalAddress(upstream_local_address);

stream_info_.setUpstreamSslConnection(info.downstreamSslConnection());
parent_.callbacks()->streamInfo().setUpstreamSslConnection(info.downstreamSslConnection());
stream_info_.setUpstreamSslConnection(info.downstreamAddressProvider().sslConnection());
parent_.callbacks()->streamInfo().setUpstreamSslConnection(
info.downstreamAddressProvider().sslConnection());

if (parent_.downstreamEndStream()) {
setupPerTryTimeout();
Expand Down
10 changes: 0 additions & 10 deletions source/common/stream_info/stream_info_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,6 @@ struct StreamInfoImpl : public StreamInfo {
return *downstream_address_provider_;
}

void
setDownstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& connection_info) override {
downstream_ssl_info_ = connection_info;
}

Ssl::ConnectionInfoConstSharedPtr downstreamSslConnection() const override {
return downstream_ssl_info_;
}

void setUpstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& connection_info) override {
upstream_ssl_info_ = connection_info;
}
Expand Down Expand Up @@ -324,7 +315,6 @@ struct StreamInfoImpl : public StreamInfo {
uint64_t bytes_sent_{};
Network::Address::InstanceConstSharedPtr upstream_local_address_;
const Network::SocketAddressProviderSharedPtr downstream_address_provider_;
Ssl::ConnectionInfoConstSharedPtr downstream_ssl_info_;
Ssl::ConnectionInfoConstSharedPtr upstream_ssl_info_;
std::string requested_server_name_;
const Http::RequestHeaderMap* request_headers_{};
Expand Down
1 change: 0 additions & 1 deletion source/common/tcp/conn_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ void ActiveTcpClient::onEvent(Network::ConnectionEvent event) {
// This is also necessary for prefetch to be used with such protocols.
if (event == Network::ConnectionEvent::Connected) {
connection_->readDisable(true);
connection_->streamInfo().setDownstreamSslConnection(connection_->ssl());
}
Envoy::ConnectionPool::ActiveClient::onEvent(event);
if (callbacks_) {
Expand Down
1 change: 0 additions & 1 deletion source/common/tcp/original_conn_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ void OriginalConnPoolImpl::onConnectionEvent(ActiveConn& conn, Network::Connecti
// whether the connection is in the ready list (connected) or the pending list (failed to
// connect).
if (event == Network::ConnectionEvent::Connected) {
conn.conn_->streamInfo().setDownstreamSslConnection(conn.conn_->ssl());
conn_connect_ms_->complete();
processIdleConnection(conn, true, false);
}
Expand Down
11 changes: 6 additions & 5 deletions source/common/tcp_proxy/upstream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,10 @@ void TcpConnPool::onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr&& conn_data
Network::Connection& connection = conn_data->connection();

auto upstream = std::make_unique<TcpUpstream>(std::move(conn_data), upstream_callbacks_);
callbacks_->onGenericPoolReady(&connection.streamInfo(), std::move(upstream), host,
latched_data->connection().addressProvider().localAddress(),
latched_data->connection().streamInfo().downstreamSslConnection());
callbacks_->onGenericPoolReady(
&connection.streamInfo(), std::move(upstream), host,
latched_data->connection().addressProvider().localAddress(),
latched_data->connection().streamInfo().downstreamAddressProvider().sslConnection());
}

HttpConnPool::HttpConnPool(Upstream::ThreadLocalCluster& thread_local_cluster,
Expand Down Expand Up @@ -233,8 +234,8 @@ void HttpConnPool::onPoolReady(Http::RequestEncoder& request_encoder,
upstream_handle_ = nullptr;
upstream_->setRequestEncoder(request_encoder,
host->transportSocketFactory().implementsSecureTransport());
upstream_->setConnPoolCallbacks(
std::make_unique<HttpConnPool::Callbacks>(*this, host, info.downstreamSslConnection()));
upstream_->setConnPoolCallbacks(std::make_unique<HttpConnPool::Callbacks>(
*this, host, info.downstreamAddressProvider().sslConnection()));
}

void HttpConnPool::onGenericPoolReady(Upstream::HostDescriptionConstSharedPtr& host,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ void Utility::extractCommonAccessLogProperties(
*stream_info.downstreamAddressProvider().localAddress(),
*common_access_log.mutable_downstream_local_address());
}
if (stream_info.downstreamSslConnection() != nullptr) {
if (stream_info.downstreamAddressProvider().sslConnection() != nullptr) {
auto* tls_properties = common_access_log.mutable_tls_properties();
const Ssl::ConnectionInfoConstSharedPtr downstream_ssl_connection =
stream_info.downstreamSslConnection();
stream_info.downstreamAddressProvider().sslConnection();

tls_properties->set_tls_sni_hostname(
std::string(stream_info.downstreamAddressProvider().requestedServerName()));
Expand Down
7 changes: 4 additions & 3 deletions source/extensions/filters/common/expr/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,9 @@ absl::optional<CelValue> ConnectionWrapper::operator[](CelValue key) const {
}
auto value = key.StringOrDie().value();
if (value == MTLS) {
return CelValue::CreateBool(info_.downstreamSslConnection() != nullptr &&
info_.downstreamSslConnection()->peerCertificatePresented());
return CelValue::CreateBool(
info_.downstreamAddressProvider().sslConnection() != nullptr &&
info_.downstreamAddressProvider().sslConnection()->peerCertificatePresented());
} else if (value == RequestedServerName) {
return CelValue::CreateStringView(info_.downstreamAddressProvider().requestedServerName());
} else if (value == ID) {
Expand All @@ -198,7 +199,7 @@ absl::optional<CelValue> ConnectionWrapper::operator[](CelValue key) const {
return {};
}

auto ssl_info = info_.downstreamSslConnection();
auto ssl_info = info_.downstreamAddressProvider().sslConnection();
if (ssl_info != nullptr) {
return extractSslInfo(*ssl_info, value);
}
Expand Down
2 changes: 1 addition & 1 deletion source/extensions/filters/http/lua/wrappers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ int StreamInfoWrapper::luaDynamicMetadata(lua_State* state) {
}

int StreamInfoWrapper::luaDownstreamSslConnection(lua_State* state) {
const auto& ssl = stream_info_.downstreamSslConnection();
const auto& ssl = stream_info_.downstreamAddressProvider().sslConnection();
if (ssl != nullptr) {
if (downstream_ssl_connection_.get() != nullptr) {
downstream_ssl_connection_.pushStack();
Expand Down
1 change: 0 additions & 1 deletion source/server/active_stream_listener_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ void ActiveStreamListenerBase::newConnection(Network::ConnectionSocketPtr&& sock
}
stream_info->setFilterChainName(filter_chain->name());
auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr);
stream_info->setDownstreamSslConnection(transport_socket->ssl());
auto server_conn_ptr = dispatcher().createServerConnection(
std::move(socket), std::move(transport_socket), *stream_info);
if (const auto timeout = filter_chain->transportSocketConnectTimeout();
Expand Down
Loading

0 comments on commit 18e7ce1

Please sign in to comment.