From 3df01498989309db720432aaa67899468c427e5e Mon Sep 17 00:00:00 2001 From: Sergei Politov Date: Mon, 26 Jul 2021 18:09:27 +0300 Subject: [PATCH] [#9370] Implement network traffic compression Summary: This diff implements support for network traffic compression. There are 2 flags to configure it: enable_stream_compression - whether we enable compression at all. stream_compression_algo - algorithm index that should be used for compression: 0 - no compression 1 - gzip It should be safe to enable compression and set the algorithm to 0. But since this feature is pretty new, we fully disable compression by default. Introduced a StreamRefiner for refined streams, ie; encryption / compression. The following compression related work should be done in follow-up diffs: 1) Add tests for encryption+compression. 2) Add more compression algorithms. 3) Change `StreamRefiner` interface to avoid the extra copy of decompressed data. Test Plan: ybd --gtest_filter CompressedStreamTest.* ybd --gtest_filter TestRpcCompression.* Reviewers: bogdan Reviewed By: bogdan Subscribers: sanketh, ybase Differential Revision: https://phabricator.dev.yugabyte.com/D12328 --- ent/src/yb/server/secure.cc | 27 +- src/yb/integration-tests/CMakeLists.txt | 1 + .../compressed_stream-test.cc | 123 ++++ src/yb/rpc/CMakeLists.txt | 2 + src/yb/rpc/compressed_stream.cc | 333 +++++++++++ src/yb/rpc/compressed_stream.h | 31 + src/yb/rpc/outbound_data.h | 37 ++ src/yb/rpc/refined_stream.cc | 301 ++++++++++ src/yb/rpc/refined_stream.h | 129 ++++ src/yb/rpc/rpc-test-base.h | 2 +- src/yb/rpc/rpc-test.cc | 147 ++++- src/yb/rpc/rpc_fwd.h | 1 + src/yb/rpc/secure_stream.cc | 551 +++++------------- src/yb/rpc/secure_stream.h | 5 - src/yb/rpc/stream.h | 11 +- src/yb/rpc/tcp_stream.h | 4 +- src/yb/rpc/yb_rpc.cc | 4 +- 17 files changed, 1264 insertions(+), 445 deletions(-) create mode 100644 src/yb/integration-tests/compressed_stream-test.cc create mode 100644 src/yb/rpc/compressed_stream.cc create mode 100644 src/yb/rpc/compressed_stream.h create mode 100644 src/yb/rpc/refined_stream.cc create mode 100644 src/yb/rpc/refined_stream.h diff --git a/ent/src/yb/server/secure.cc b/ent/src/yb/server/secure.cc index 95001b8f57ab..2c6c73b3c45a 100644 --- a/ent/src/yb/server/secure.cc +++ b/ent/src/yb/server/secure.cc @@ -15,6 +15,7 @@ #include "yb/fs/fs_manager.h" +#include "yb/rpc/compressed_stream.h" #include "yb/rpc/messenger.h" #include "yb/rpc/secure_stream.h" #include "yb/rpc/tcp_stream.h" @@ -53,6 +54,8 @@ DEFINE_string(key_file_pattern, "node.$0.key", "Pattern used for key file"); DEFINE_string(cert_file_pattern, "node.$0.crt", "Pattern used for certificate file"); +DEFINE_bool(enable_stream_compression, false, "Whether it is allowed to use stream compression."); + namespace yb { namespace server { namespace { @@ -87,13 +90,28 @@ Result> SetupSecureContext( return SetupSecureContext(std::string(), root_dir, name, type, builder); } +void ApplyCompressedStream( + rpc::MessengerBuilder* builder, const rpc::StreamFactoryPtr lower_layer_factory) { + if (!FLAGS_enable_stream_compression) { + return; + } + builder->SetListenProtocol(rpc::CompressedStreamProtocol()); + auto parent_mem_tracker = builder->last_used_parent_mem_tracker(); + auto buffer_tracker = MemTracker::FindOrCreateTracker( + -1, "Compressed Read Buffer", parent_mem_tracker); + builder->AddStreamFactory( + rpc::CompressedStreamProtocol(), + rpc::CompressedStreamFactory(std::move(lower_layer_factory), buffer_tracker)); +} + Result> SetupSecureContext( const std::string& cert_dir, const std::string& root_dir, const std::string& name, SecureContextType type, rpc::MessengerBuilder* builder) { auto use = type == SecureContextType::kInternal ? FLAGS_use_node_to_node_encryption : FLAGS_use_client_to_server_encryption; if (!use) { - return std::unique_ptr(); + ApplyCompressedStream(builder, rpc::TcpStream::Factory()); + return nullptr; } std::string dir; @@ -157,10 +175,11 @@ void ApplySecureContext(const rpc::SecureContext* context, rpc::MessengerBuilder auto buffer_tracker = MemTracker::FindOrCreateTracker( -1, "Encrypted Read Buffer", parent_mem_tracker); + auto secure_stream_factory = rpc::SecureStreamFactory( + rpc::TcpStream::Factory(), buffer_tracker, context); builder->SetListenProtocol(rpc::SecureStreamProtocol()); - builder->AddStreamFactory( - rpc::SecureStreamProtocol(), - rpc::SecureStreamFactory(rpc::TcpStream::Factory(), buffer_tracker, context)); + builder->AddStreamFactory(rpc::SecureStreamProtocol(), secure_stream_factory); + ApplyCompressedStream(builder, secure_stream_factory); } } // namespace server diff --git a/src/yb/integration-tests/CMakeLists.txt b/src/yb/integration-tests/CMakeLists.txt index 38a67b8baec7..d86ea48f0ff0 100644 --- a/src/yb/integration-tests/CMakeLists.txt +++ b/src/yb/integration-tests/CMakeLists.txt @@ -97,6 +97,7 @@ ADD_YB_TEST(tablet-split-itest) # Not sure if we really need RUN_SERIAL here as this might not be a resource-intensive test. ADD_YB_TEST(compaction-test) +ADD_YB_TEST(compressed_stream-test) ADD_YB_TEST(logging-test) ADD_YB_TEST(master_replication-itest) ADD_YB_TEST(master_sysnamespace-itest) diff --git a/src/yb/integration-tests/compressed_stream-test.cc b/src/yb/integration-tests/compressed_stream-test.cc new file mode 100644 index 000000000000..4685127a171c --- /dev/null +++ b/src/yb/integration-tests/compressed_stream-test.cc @@ -0,0 +1,123 @@ +// Copyright (c) YugaByte, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. +// + +#include "yb/client/ql-dml-test-base.h" +#include "yb/client/session.h" +#include "yb/client/table_handle.h" + +#include "yb/common/ql_value.h" + +#include "yb/rpc/compressed_stream.h" +#include "yb/rpc/messenger.h" +#include "yb/rpc/tcp_stream.h" + +#include "yb/server/secure.h" + +#include "yb/util/size_literals.h" +#include "yb/util/env_util.h" + +#include "yb/yql/cql/ql/util/errcodes.h" +#include "yb/yql/cql/ql/util/statement_result.h" + +using namespace std::literals; + +DECLARE_int32(stream_compression_algo); +DECLARE_bool(enable_stream_compression); + +namespace yb { + +class CompressedStreamTest : public client::KeyValueTableTest { + public: + void SetUp() override { + FLAGS_enable_stream_compression = true; + FLAGS_stream_compression_algo = 1; + KeyValueTableTest::SetUp(); + } + + CHECKED_STATUS CreateClient() override { + auto host = "127.0.0.52"; + client_ = VERIFY_RESULT(DoCreateClient(host, host)); + return Status::OK(); + } + + Result> DoCreateClient( + const std::string& name, const std::string& host) { + rpc::MessengerBuilder messenger_builder("test_client"); + messenger_builder.SetListenProtocol(rpc::CompressedStreamProtocol()); + messenger_builder.AddStreamFactory( + rpc::CompressedStreamProtocol(), + CompressedStreamFactory(rpc::TcpStream::Factory(), MemTracker::GetRootTracker())); + auto messenger = VERIFY_RESULT(messenger_builder.Build()); + messenger->TEST_SetOutboundIpBase(VERIFY_RESULT(HostToAddress(host))); + return cluster_->CreateClient(std::move(messenger)); + } + + void TestSimpleOps(); +}; + +void CompressedStreamTest::TestSimpleOps() { + CreateTable(client::Transactional::kFalse); + + const int32_t kKey = 1; + const int32_t kValue = 2; + + { + auto session = NewSession(); + auto op = ASSERT_RESULT(WriteRow(session, kKey, kValue)); + ASSERT_EQ(op->response().status(), QLResponsePB::YQL_STATUS_OK); + } + + { + auto value = ASSERT_RESULT(SelectRow(NewSession(), kKey)); + ASSERT_EQ(kValue, value); + } +} + +TEST_F(CompressedStreamTest, Simple) { + TestSimpleOps(); +} + +TEST_F(CompressedStreamTest, BigWrite) { + client::YBSchemaBuilder builder; + builder.AddColumn(kKeyColumn)->Type(INT32)->HashPrimaryKey()->NotNull(); + builder.AddColumn(kValueColumn)->Type(STRING); + + ASSERT_OK(table_.Create(client::kTableName, 1, client_.get(), &builder)); + + const int32_t kKey = 1; + const std::string kValue(64_KB, 'X'); + + auto session = NewSession(); + { + const auto op = table_.NewWriteOp(QLWriteRequestPB::QL_STMT_INSERT); + auto* const req = op->mutable_request(); + QLAddInt32HashValue(req, kKey); + table_.AddStringColumnValue(req, kValueColumn, kValue); + ASSERT_OK(session->ApplyAndFlush(op)); + ASSERT_OK(CheckOp(op.get())); + } + + { + const auto op = table_.NewReadOp(); + auto* const req = op->mutable_request(); + QLAddInt32HashValue(req, kKey); + table_.AddColumns({kValueColumn}, req); + ASSERT_OK(session->ApplyAndFlush(op)); + ASSERT_OK(CheckOp(op.get())); + auto rowblock = yb::ql::RowsResult(op.get()).GetRowBlock(); + ASSERT_EQ(rowblock->row_count(), 1); + ASSERT_EQ(kValue, rowblock->row(0).column(0).string_value()); + } +} + +} // namespace yb diff --git a/src/yb/rpc/CMakeLists.txt b/src/yb/rpc/CMakeLists.txt index 9e62ba03f703..1d33d03bf7be 100644 --- a/src/yb/rpc/CMakeLists.txt +++ b/src/yb/rpc/CMakeLists.txt @@ -59,6 +59,7 @@ set(YRPC_SRCS acceptor.cc binary_call_parser.cc circular_read_buffer.cc + compressed_stream.cc connection.cc connection_context.cc growable_buffer.cc @@ -72,6 +73,7 @@ set(YRPC_SRCS poller.cc proxy.cc reactor.cc + refined_stream.cc remote_method.cc rpc.cc rpc_context.cc diff --git a/src/yb/rpc/compressed_stream.cc b/src/yb/rpc/compressed_stream.cc new file mode 100644 index 000000000000..c49fc81a608e --- /dev/null +++ b/src/yb/rpc/compressed_stream.cc @@ -0,0 +1,333 @@ +// Copyright (c) YugaByte, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. +// + +#include "yb/rpc/compressed_stream.h" + +#include + +#include "yb/rpc/circular_read_buffer.h" +#include "yb/rpc/outbound_data.h" +#include "yb/rpc/refined_stream.h" + +#include "yb/util/logging.h" +#include "yb/util/size_literals.h" + +using namespace std::literals; + +DEFINE_int32(stream_compression_algo, 0, "Algorithm used for stream compression. " + "0 - no compression, 1 - gzip."); + +namespace yb { +namespace rpc { + +namespace { + +class Compressor { + public: + virtual std::string ToString() const = 0; + + // Initialize compressor, required since we don't use exceptions to return error from ctor. + virtual CHECKED_STATUS Init() = 0; + + // Compress specified vector of input buffers into single output buffer. + virtual CHECKED_STATUS Compress( + const boost::container::small_vector_base& input, RefCntBuffer* output) = 0; + + // Decompress specified input slice to specified output buffer. + virtual Result Decompress(const Slice& input, StreamReadBuffer* output) = 0; + + // Connection header associated with this compressor. + virtual OutboundDataPtr ConnectionHeader() = 0; + + virtual ~Compressor() = default; +}; + +template +OutboundDataPtr GetConnectionHeader() { + // Compressed stream header has signature YBx, where x - compressor identifier. + static auto result = std::make_shared( + "YB"s + Compressor::kId, Compressor::kId + "ConnectionHeader"s); + return result; +} + +class GZipCompressor : public Compressor { + public: + static const char kId = 'G'; + static const int kIndex = 1; + + GZipCompressor() { + } + + ~GZipCompressor() { + if (deflate_inited_) { + int res = deflateEnd(&deflate_stream_); + LOG_IF(WARNING, res != Z_OK) << "Failed to destroy deflate stream: " << res; + } + if (inflate_inited_) { + int res = inflateEnd(&inflate_stream_); + LOG_IF(WARNING, res != Z_OK) << "Failed to destroy inflate stream: " << res; + } + } + + OutboundDataPtr ConnectionHeader() override { + return GetConnectionHeader(); + } + + CHECKED_STATUS Init() override { + memset(&deflate_stream_, 0, sizeof(deflate_stream_)); + int res = deflateInit(&deflate_stream_, /* level= */ Z_DEFAULT_COMPRESSION); + if (res != Z_OK) { + return STATUS_FORMAT(RuntimeError, "Cannot init deflate stream: $0", res); + } + deflate_inited_ = true; + + memset(&inflate_stream_, 0, sizeof(inflate_stream_)); + res = inflateInit(&inflate_stream_); + if (res != Z_OK) { + return STATUS_FORMAT(RuntimeError, "Cannot init inflate stream: $0", res); + } + inflate_inited_ = true; + + return Status::OK(); + } + + std::string ToString() const override { + return "GZip"; + } + + CHECKED_STATUS Compress( + const boost::container::small_vector_base& input, + RefCntBuffer* output) override { + size_t total_len = 0; + for (const auto& buf : input) { + total_len += buf.size(); + } + *output = RefCntBuffer(deflateBound(&deflate_stream_, total_len)); + deflate_stream_.avail_out = static_cast(output->size()); + deflate_stream_.next_out = output->udata(); + + for (auto it = input.begin(); it != input.end();) { + const auto& buf = *it++; + deflate_stream_.next_in = const_cast(buf.udata()); + deflate_stream_.avail_in = static_cast(buf.size()); + + for (;;) { + auto res = deflate(&deflate_stream_, it == input.end() ? Z_PARTIAL_FLUSH : Z_NO_FLUSH); + if (res == Z_STREAM_END) { + if (deflate_stream_.avail_in != 0) { + return STATUS_FORMAT( + RuntimeError, "Stream end when input data still available: $0", + deflate_stream_.avail_in); + } + break; + } + if (res != Z_OK) { + return STATUS_FORMAT(RuntimeError, "Compression failed: $0", res); + } + if (deflate_stream_.avail_in == 0) { + break; + } + } + } + + output->Shrink(deflate_stream_.next_out - output->udata()); + + return Status::OK(); + } + + Result Decompress(const Slice& input, StreamReadBuffer* output) override { + auto io_vecs = VERIFY_RESULT(output->PrepareAppend()); + size_t total_out = 0; + inflate_stream_.avail_out = 0; + + inflate_stream_.next_in = const_cast(pointer_cast(input.data())); + inflate_stream_.avail_in = input.size(); + + auto next_io_vec_it = io_vecs.begin(); + while (inflate_stream_.avail_in != 0) { + if (inflate_stream_.avail_out == 0) { + if (next_io_vec_it == io_vecs.end()) { + // We don't have space in output buffer. + // So return with what we have, and expect that caller would free some space and + // call decompress again. + break; + } + inflate_stream_.avail_out = next_io_vec_it->iov_len; + inflate_stream_.next_out = static_cast(next_io_vec_it->iov_base); + ++next_io_vec_it; + } + auto old_avail_out = inflate_stream_.avail_out; + int res = inflate(&inflate_stream_, Z_NO_FLUSH); + if (res != Z_OK) { + return STATUS_FORMAT(RuntimeError, "Decompression failed: $0", res); + } + total_out += old_avail_out - inflate_stream_.avail_out; + } + + output->DataAppended(total_out); + + return input.size() - inflate_stream_.avail_in; + } + + private: + z_stream deflate_stream_; + z_stream inflate_stream_; + bool deflate_inited_ = false; + bool inflate_inited_ = false; +}; + +std::unique_ptr CreateCompressor(char sign) { + switch (sign) { + case GZipCompressor::kId: + return std::make_unique(); + default: + return nullptr; + } +} + +std::unique_ptr CreateOutboundCompressor() { + auto algo = FLAGS_stream_compression_algo; + if (!algo) { + return nullptr; + } + switch (algo) { + case GZipCompressor::kIndex: + return std::make_unique(); + default: + YB_LOG_EVERY_N_SECS(DFATAL, 5) << "Unknown compression algorithm: " << algo; + return nullptr; + } +} + +class CompressedRefiner : public StreamRefiner { + public: + explicit CompressedRefiner(size_t receive_buffer_size, const MemTrackerPtr& buffer_tracker) + : read_buffer_(receive_buffer_size, buffer_tracker) { + } + + private: + void Start(RefinedStream* stream) override { + stream_ = stream; + } + + Result ProcessHeader(const IoVecs& data) override { + constexpr int kHeaderLen = 3; + + if (data[0].iov_len < kHeaderLen) { + // Did not receive enough bytes to make a decision. + // So just wait more bytes. + return 0; + } + + const uint8_t* bytes = static_cast(data[0].iov_base); + if (bytes[0] == 'Y' && bytes[1] == 'B') { + compressor_ = CreateCompressor(bytes[2]); + if (compressor_) { + RETURN_NOT_OK(compressor_->Init()); + RETURN_NOT_OK(stream_->StartHandshake()); + return kHeaderLen; + } + } + + // Don't use compression on this stream. + RETURN_NOT_OK(stream_->Established(RefinedStreamState::kDisabled)); + return 0; + } + + CHECKED_STATUS Send(OutboundDataPtr data) override { + boost::container::small_vector input; + data->Serialize(&input); + RefCntBuffer buffer; + RETURN_NOT_OK(compressor_->Compress(input, &buffer)); + VLOG_WITH_PREFIX(4) << __func__ << ", " << buffer.AsSlice().ToDebugString(); + auto compressed_data = std::make_shared( + std::move(buffer), std::move(data)); + return stream_->SendToLower(std::move(compressed_data)); + } + + CHECKED_STATUS Handshake() override { + if (stream_->local_side() == LocalSide::kClient) { + compressor_ = CreateOutboundCompressor(); + if (!compressor_) { + return stream_->Established(RefinedStreamState::kDisabled); + } + RETURN_NOT_OK(compressor_->Init()); + RETURN_NOT_OK(stream_->SendToLower(compressor_->ConnectionHeader())); + } + + return stream_->Established(RefinedStreamState::kEnabled); + } + + Result Read(void* buf, size_t num) override { + VLOG_WITH_PREFIX(4) << __func__ << ", " << num; + + auto io_vecs = read_buffer_.AppendedVecs(); + char* wpos = static_cast(buf); + char* wend = wpos + num; + for (const auto& io_vec : io_vecs) { + auto left = wend - wpos; + if (io_vec.iov_len >= left) { + memcpy(wpos, io_vec.iov_base, left); + wpos += left; + break; + } + memcpy(wpos, io_vec.iov_base, io_vec.iov_len); + wpos += io_vec.iov_len; + } + auto result = wpos - static_cast(buf); + read_buffer_.Consume(result, Slice()); + return result; + } + + Result Receive(const Slice& slice) override { + VLOG_WITH_PREFIX(4) << __func__ << ", " << slice.ToDebugString(); + + return compressor_->Decompress(slice, &read_buffer_); + } + + const Protocol* GetProtocol() override { + return CompressedStreamProtocol(); + } + + std::string ToString() const override { + return compressor_ ? compressor_->ToString() : "PLAIN"; + } + + const std::string& LogPrefix() const { + return stream_->LogPrefix(); + } + + RefinedStream* stream_ = nullptr; + std::unique_ptr compressor_ = nullptr; + CircularReadBuffer read_buffer_; +}; + +} // namespace + +const Protocol* CompressedStreamProtocol() { + static Protocol result("tcpc"); + return &result; +} + +StreamFactoryPtr CompressedStreamFactory( + StreamFactoryPtr lower_layer_factory, const MemTrackerPtr& buffer_tracker) { + return std::make_shared( + std::move(lower_layer_factory), buffer_tracker, + [](size_t receive_buffer_size, const MemTrackerPtr& buffer_tracker, + const StreamCreateData& data) { + return std::make_unique(receive_buffer_size, buffer_tracker); + }); +} + +} // namespace rpc +} // namespace yb diff --git a/src/yb/rpc/compressed_stream.h b/src/yb/rpc/compressed_stream.h new file mode 100644 index 000000000000..c1fee7f48301 --- /dev/null +++ b/src/yb/rpc/compressed_stream.h @@ -0,0 +1,31 @@ +// Copyright (c) YugaByte, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. +// + +#ifndef YB_RPC_COMPRESSED_STREAM_H +#define YB_RPC_COMPRESSED_STREAM_H + +#include "yb/rpc/stream.h" + +#include "yb/util/mem_tracker.h" + +namespace yb { +namespace rpc { + +const Protocol* CompressedStreamProtocol(); +StreamFactoryPtr CompressedStreamFactory( + StreamFactoryPtr lower_layer_factory, const MemTrackerPtr& buffer_tracker); + +} // namespace rpc +} // namespace yb + +#endif // YB_RPC_COMPRESSED_STREAM_H diff --git a/src/yb/rpc/outbound_data.h b/src/yb/rpc/outbound_data.h index 4f1076d14cf4..4c76bd569c16 100644 --- a/src/yb/rpc/outbound_data.h +++ b/src/yb/rpc/outbound_data.h @@ -18,6 +18,7 @@ #include +#include "yb/util/format.h" #include "yb/util/memory/memory_usage.h" #include "yb/util/ref_cnt_buffer.h" @@ -88,6 +89,42 @@ class StringOutboundData : public OutboundData { std::string name_; }; +// OutboundData wrapper, that is used for altered streams, where we modify the data that should be +// sent. Examples could be that we encrypt or compress it. +// This wrapper would contain modified data and reference to original data, that will be used +// for notifications. +class SingleBufferOutboundData : public OutboundData { + public: + SingleBufferOutboundData(RefCntBuffer buffer, OutboundDataPtr lower_data) + : buffer_(std::move(buffer)), lower_data_(std::move(lower_data)) {} + + void Transferred(const Status& status, Connection* conn) override { + if (lower_data_) { + lower_data_->Transferred(status, conn); + } + } + + bool DumpPB(const DumpRunningRpcsRequestPB& req, RpcCallInProgressPB* resp) override { + return false; + } + + void Serialize(boost::container::small_vector_base* output) override { + output->push_back(std::move(buffer_)); + } + + std::string ToString() const override { + return Format("SingleBuffer[$0]", lower_data_); + } + + size_t ObjectSize() const override { return sizeof(*this); } + + size_t DynamicMemoryUsage() const override { return DynamicMemoryUsageOf(buffer_, lower_data_); } + + private: + RefCntBuffer buffer_; + OutboundDataPtr lower_data_; +}; + } // namespace rpc } // namespace yb diff --git a/src/yb/rpc/refined_stream.cc b/src/yb/rpc/refined_stream.cc new file mode 100644 index 000000000000..44301b11739b --- /dev/null +++ b/src/yb/rpc/refined_stream.cc @@ -0,0 +1,301 @@ +// Copyright (c) YugaByte, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. +// + +#include "yb/rpc/refined_stream.h" + +#include "yb/rpc/rpc_util.h" + +#include "yb/util/logging.h" +#include "yb/util/size_literals.h" + +namespace yb { +namespace rpc { + +RefinedStream::RefinedStream( + std::unique_ptr lower_stream, std::unique_ptr refiner, + size_t receive_buffer_size, const MemTrackerPtr& buffer_tracker) + : lower_stream_(std::move(lower_stream)), refiner_(std::move(refiner)), + read_buffer_(receive_buffer_size, buffer_tracker) { +} + +size_t RefinedStream::GetPendingWriteBytes() { + return lower_stream_->GetPendingWriteBytes(); +} + +void RefinedStream::Close() { + lower_stream_->Close(); +} + +Status RefinedStream::TryWrite() { + return lower_stream_->TryWrite(); +} + +void RefinedStream::ParseReceived() { + lower_stream_->ParseReceived(); +} + +bool RefinedStream::Idle(std::string* reason) { + return lower_stream_->Idle(reason); +} + +void RefinedStream::DumpPB(const DumpRunningRpcsRequestPB& req, RpcConnectionPB* resp) { + lower_stream_->DumpPB(req, resp); +} + +const Endpoint& RefinedStream::Remote() const { + return lower_stream_->Remote(); +} + +const Endpoint& RefinedStream::Local() const { + return lower_stream_->Local(); +} + +Status RefinedStream::Start(bool connect, ev::loop_ref* loop, StreamContext* context) { + local_side_ = connect ? LocalSide::kClient : LocalSide::kServer; + context_ = context; + refiner_->Start(this); + return lower_stream_->Start(connect, loop, this); +} + +void RefinedStream::Shutdown(const Status& status) { + VLOG_WITH_PREFIX(1) << "Shutdown with status: " << status; + + for (auto& data : pending_data_) { + if (data) { + context().Transferred(data, status); + } + } + + pending_data_.clear(); + lower_stream_->Shutdown(status); +} + +Result RefinedStream::Send(OutboundDataPtr data) { + switch (state_) { + case RefinedStreamState::kInitial: + case RefinedStreamState::kHandshake: + pending_data_.push_back(std::move(data)); + return std::numeric_limits::max(); + case RefinedStreamState::kEnabled: + RETURN_NOT_OK(refiner_->Send(std::move(data))); + return std::numeric_limits::max(); + case RefinedStreamState::kDisabled: + return lower_stream_->Send(std::move(data)); + } + + FATAL_INVALID_ENUM_VALUE(RefinedStreamState, state_); +} + +void RefinedStream::UpdateLastActivity() { + context_->UpdateLastActivity(); +} + +void RefinedStream::UpdateLastRead() { + context_->UpdateLastRead(); +} + +void RefinedStream::UpdateLastWrite() { + context_->UpdateLastWrite(); +} + +void RefinedStream::Transferred(const OutboundDataPtr& data, const Status& status) { + context_->Transferred(data, status); +} + +void RefinedStream::Destroy(const Status& status) { + context_->Destroy(status); +} + +std::string RefinedStream::ToString() const { + return Format("$0[$1] $2 $3", + *refiner_, local_side_ == LocalSide::kClient ? "C" : "S", state_, *lower_stream_); +} + +void RefinedStream::Cancelled(size_t handle) { + LOG_WITH_PREFIX(DFATAL) << "Cancel is not supported for proxy stream: " << handle; +} + +bool RefinedStream::IsConnected() { + return state_ == RefinedStreamState::kEnabled || state_ == RefinedStreamState::kDisabled; +} + +const Protocol* RefinedStream::GetProtocol() { + return refiner_->GetProtocol(); +} + +StreamReadBuffer& RefinedStream::ReadBuffer() { + return read_buffer_; +} + +Result RefinedStream::ProcessReceived( + const IoVecs& data, ReadBufferFull read_buffer_full) { + switch (state_) { + case RefinedStreamState::kInitial: { + IoVecs data_copy = data; + auto consumed = VERIFY_RESULT(refiner_->ProcessHeader(data_copy)); + if (state_ == RefinedStreamState::kInitial) { + // Received data was not enough to check stream header. + RSTATUS_DCHECK_EQ(consumed, 0, InternalError, + "Consumed data while keeping stream in initial state"); + return ProcessDataResult{0, Slice()}; + } + data_copy[0].iov_len -= consumed; + data_copy[0].iov_base = static_cast(data_copy[0].iov_base) + consumed; + auto result = VERIFY_RESULT(ProcessReceived(data_copy, read_buffer_full)); + result.consumed += consumed; + return result; + } + + case RefinedStreamState::kDisabled: + return context_->ProcessReceived(data, read_buffer_full); + + case RefinedStreamState::kHandshake: FALLTHROUGH_INTENDED; + case RefinedStreamState::kEnabled: { + size_t result = 0; + for (const auto& iov : data) { + Slice slice(static_cast(iov.iov_base), iov.iov_len); + for (;;) { + auto len = VERIFY_RESULT(refiner_->Receive(slice)); + result += len; + if (len == slice.size()) { + break; + } + slice.remove_prefix(len); + RETURN_NOT_OK(HandshakeOrRead()); + } + } + RETURN_NOT_OK(HandshakeOrRead()); + return ProcessDataResult{ result, Slice() }; + } + } + + return STATUS_FORMAT(IllegalState, "Unexpected state: $0", to_underlying(state_)); +} + +void RefinedStream::Connected() { + if (local_side_ != LocalSide::kClient) { + return; + } + + auto status = StartHandshake(); + if (status.ok()) { + status = refiner_->Handshake(); + } + if (!status.ok()) { + context_->Destroy(status); + } +} + +Status RefinedStream::Established(RefinedStreamState state) { + state_ = state; + ResetLogPrefix(); + context().Connected(); + for (auto& data : pending_data_) { + RETURN_NOT_OK(Send(std::move(data))); + } + pending_data_.clear(); + return Status::OK(); +} + +Status RefinedStream::SendToLower(OutboundDataPtr data) { + return ResultToStatus(lower_stream_->Send(std::move(data))); +} + +Status RefinedStream::StartHandshake() { + state_ = RefinedStreamState::kHandshake; + ResetLogPrefix(); + return Status::OK(); +} + +Status RefinedStream::HandshakeOrRead() { + if (PREDICT_FALSE(state_ != RefinedStreamState::kEnabled)) { + auto handshake_status = refiner_->Handshake(); + LOG_IF_WITH_PREFIX(INFO, !handshake_status.ok()) << "Handshake failed: " << handshake_status; + RETURN_NOT_OK(handshake_status); + } + + if (state_ == RefinedStreamState::kEnabled) { + return Read(); + } + + return Status::OK(); +} + +Status RefinedStream::Read() { + auto& refined_read_buffer = context_->ReadBuffer(); + bool done = false; + while (!done) { + if (lower_stream_bytes_to_skip_ > 0) { + auto global_skip_buffer = GetGlobalSkipBuffer(); + do { + auto len = VERIFY_RESULT(refiner_->Read( + global_skip_buffer.mutable_data(), + std::min(global_skip_buffer.size(), lower_stream_bytes_to_skip_))); + if (len == 0) { + done = true; + break; + } + VLOG_WITH_PREFIX(4) << "Skip lower: " << len; + lower_stream_bytes_to_skip_ -= len; + } while (lower_stream_bytes_to_skip_ > 0); + } + auto out = VERIFY_RESULT(refined_read_buffer.PrepareAppend()); + size_t appended = 0; + for (auto iov = out.begin(); iov != out.end();) { + auto len = VERIFY_RESULT(refiner_->Read(iov->iov_base, iov->iov_len)); + if (len == 0) { + done = true; + break; + } + VLOG_WITH_PREFIX(4) << "Read lower: " << len; + appended += len; + iov->iov_base = static_cast(iov->iov_base) + len; + iov->iov_len -= len; + if (iov->iov_len <= 0) { + ++iov; + } + } + refined_read_buffer.DataAppended(appended); + if (refined_read_buffer.ReadyToRead()) { + auto temp = VERIFY_RESULT(context_->ProcessReceived( + refined_read_buffer.AppendedVecs(), ReadBufferFull(refined_read_buffer.Full()))); + refined_read_buffer.Consume(temp.consumed, temp.buffer); + DCHECK_EQ(lower_stream_bytes_to_skip_, 0); + lower_stream_bytes_to_skip_ = temp.bytes_to_skip; + } + } + + return Status::OK(); +} + +RefinedStreamFactory::RefinedStreamFactory( + StreamFactoryPtr lower_layer_factory, const MemTrackerPtr& buffer_tracker, + RefinerFactory refiner_factory) + : lower_layer_factory_(std::move(lower_layer_factory)), buffer_tracker_(buffer_tracker), + refiner_factory_(std::move(refiner_factory)) { +} + +std::unique_ptr RefinedStreamFactory::Create(const StreamCreateData& data) { + auto receive_buffer_size = data.socket->GetReceiveBufferSize(); + if (!receive_buffer_size.ok()) { + LOG(WARNING) << "Compressed stream failure: " << receive_buffer_size.status(); + receive_buffer_size = 256_KB; + } + auto lower_stream = lower_layer_factory_->Create(data); + return std::make_unique( + std::move(lower_stream), refiner_factory_(*receive_buffer_size, buffer_tracker_, data), + *receive_buffer_size, buffer_tracker_); +} + +} // namespace rpc +} // namespace yb diff --git a/src/yb/rpc/refined_stream.h b/src/yb/rpc/refined_stream.h new file mode 100644 index 000000000000..649810f5895a --- /dev/null +++ b/src/yb/rpc/refined_stream.h @@ -0,0 +1,129 @@ +// Copyright (c) YugaByte, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. +// + +#ifndef YB_RPC_REFINED_STREAM_H +#define YB_RPC_REFINED_STREAM_H + +#include "yb/rpc/circular_read_buffer.h" +#include "yb/rpc/stream.h" + +#include "yb/util/mem_tracker.h" + +namespace yb { +namespace rpc { + +YB_DEFINE_ENUM(RefinedStreamState, (kInitial)(kHandshake)(kEnabled)(kDisabled)); +YB_DEFINE_ENUM(LocalSide, (kClient)(kServer)); + +// StreamRefiner is used by RefinedStream to perform actual stream data modification. +class StreamRefiner { + public: + virtual void Start(RefinedStream* stream) = 0; + virtual Result ProcessHeader(const IoVecs& data) = 0; + virtual CHECKED_STATUS Send(OutboundDataPtr data) = 0; + virtual CHECKED_STATUS Handshake() = 0; + virtual Result Read(void* buf, size_t num) = 0; + virtual Result Receive(const Slice& slice) = 0; + virtual const Protocol* GetProtocol() = 0; + + virtual std::string ToString() const = 0; + + virtual ~StreamRefiner() = default; +}; + +// Stream that alters the data sent and received by lower layer streams. +// For instance it could be used to compress or encrypt the data. +// +// RefinedStream keeps the code common to all such streams, +// while StreamRefiner provides actual data modification. +class RefinedStream : public Stream, public StreamContext { + public: + RefinedStream(std::unique_ptr lower_stream, std::unique_ptr refiner, + size_t receive_buffer_size, const MemTrackerPtr& buffer_tracker); + + size_t GetPendingWriteBytes() override; + void Close() override; + CHECKED_STATUS TryWrite() override; + void ParseReceived() override; + bool Idle(std::string* reason) override; + void DumpPB(const DumpRunningRpcsRequestPB& req, RpcConnectionPB* resp) override; + const Endpoint& Remote() const override; + const Endpoint& Local() const override; + CHECKED_STATUS Start(bool connect, ev::loop_ref* loop, StreamContext* context) override; + void Shutdown(const Status& status) override; + Result Send(OutboundDataPtr data) override; + void Cancelled(size_t handle) override; + bool IsConnected() override; + const Protocol* GetProtocol() override; + StreamReadBuffer& ReadBuffer() override; + std::string ToString() const override; + + // Implementation StreamContext + Result ProcessReceived( + const IoVecs& data, ReadBufferFull read_buffer_full) override; + void Connected() override; + + void UpdateLastActivity() override; + void UpdateLastRead() override; + void UpdateLastWrite() override; + void Transferred(const OutboundDataPtr& data, const Status& status) override; + void Destroy(const Status& status) override; + + CHECKED_STATUS Established(RefinedStreamState state); + CHECKED_STATUS SendToLower(OutboundDataPtr data); + CHECKED_STATUS StartHandshake(); + + StreamContext& context() const { + return *context_; + } + + LocalSide local_side() const { + return local_side_; + } + + private: + CHECKED_STATUS HandshakeOrRead(); + CHECKED_STATUS Read(); + + std::unique_ptr lower_stream_; + std::unique_ptr refiner_; + RefinedStreamState state_ = RefinedStreamState::kInitial; + StreamContext* context_ = nullptr; + std::vector pending_data_; + size_t lower_stream_bytes_to_skip_ = 0; + LocalSide local_side_ = LocalSide::kServer; + CircularReadBuffer read_buffer_; +}; + +class RefinedStreamFactory : public StreamFactory { + public: + using RefinerFactory = std::function( + size_t receive_buffer_size, const MemTrackerPtr& buffer_tracker, + const StreamCreateData& data)>; + + RefinedStreamFactory( + StreamFactoryPtr lower_layer_factory, const MemTrackerPtr& buffer_tracker, + RefinerFactory refiner_factory); + + private: + std::unique_ptr Create(const StreamCreateData& data) override; + + StreamFactoryPtr lower_layer_factory_; + MemTrackerPtr buffer_tracker_; + RefinerFactory refiner_factory_; +}; + +} // namespace rpc +} // namespace yb + +#endif // YB_RPC_REFINED_STREAM_H diff --git a/src/yb/rpc/rpc-test-base.h b/src/yb/rpc/rpc-test-base.h index ca805c78dba3..431e86b1e1cb 100644 --- a/src/yb/rpc/rpc-test-base.h +++ b/src/yb/rpc/rpc-test-base.h @@ -177,7 +177,7 @@ class RpcTestBase : public YBTest { RpcTestBase(); void TearDown() override; - protected: + std::unique_ptr CreateMessenger( const string &name, const MessengerOptions& options = kDefaultClientMessengerOptions); diff --git a/src/yb/rpc/rpc-test.cc b/src/yb/rpc/rpc-test.cc index 2c78da1a5c55..8b45494458ee 100644 --- a/src/yb/rpc/rpc-test.cc +++ b/src/yb/rpc/rpc-test.cc @@ -50,6 +50,7 @@ #include "yb/gutil/strings/human_readable.h" #include "yb/gutil/strings/join.h" +#include "yb/rpc/compressed_stream.h" #include "yb/rpc/secure_stream.h" #include "yb/rpc/serialization.h" #include "yb/rpc/tcp_stream.h" @@ -76,6 +77,7 @@ DECLARE_int32(rpc_throttle_threshold_bytes); DECLARE_bool(TEST_pause_calculator_echo_request); DECLARE_bool(binary_call_parser_reject_on_mem_tracker_hard_limit); DECLARE_string(vmodule); +DECLARE_int32(stream_compression_algo); using namespace std::chrono_literals; using std::string; @@ -83,6 +85,9 @@ using std::shared_ptr; using std::unordered_map; namespace yb { + +using rpc_test::CalculatorServiceProxy; + namespace rpc { class TestRpc : public RpcTestBase { @@ -1054,6 +1059,22 @@ TEST_F(TestRpc, CantAllocateReadBuffer) { TestCantAllocateReadBuffer(client_messenger.get(), server_addr); } +template +void RunTest(RpcTestBase* test, const MessengerFactory& messenger_factory, const F& f) { + auto client_messenger = rpc::CreateAutoShutdownMessengerHolder( + messenger_factory("Client", kDefaultClientMessengerOptions)); + auto proxy_cache = std::make_unique(client_messenger.get()); + + TestServerOptions options; + HostPort server_hostport; + test->StartTestServerWithGeneratedCode( + messenger_factory("TestServer", kDefaultServerMessengerOptions), &server_hostport, + options); + + CalculatorServiceProxy p(proxy_cache.get(), server_hostport, client_messenger->DefaultProtocol()); + f(&p); +} + class TestRpcSecure : public RpcTestBase { public: void SetUp() override { @@ -1075,30 +1096,86 @@ class TestRpcSecure : public RpcTestBase { } std::unique_ptr secure_context_; -}; - -TEST_F(TestRpcSecure, TLS) { - auto client_messenger = rpc::CreateAutoShutdownMessengerHolder(CreateSecureMessenger("Client")); - auto proxy_cache = std::make_unique(client_messenger.get()); - - TestServerOptions options; - HostPort server_hostport; - StartTestServerWithGeneratedCode( - CreateSecureMessenger("TestServer", kDefaultServerMessengerOptions), &server_hostport, - options); - rpc_test::CalculatorServiceProxy p(proxy_cache.get(), server_hostport, SecureStreamProtocol()); + template + void RunSecureTest(const F& f) { + RunTest(this, [this](const std::string& name, const MessengerOptions& options) { + return CreateSecureMessenger(name, options); + }, f); + } +}; +void TestSimple(CalculatorServiceProxy* proxy) { RpcController controller; - controller.set_timeout(5s); + controller.set_timeout(5s * kTimeMultiplier); rpc_test::AddRequestPB req; req.set_x(10); req.set_y(20); rpc_test::AddResponsePB resp; - ASSERT_OK(p.Add(req, &resp, &controller)); + ASSERT_OK(proxy->Add(req, &resp, &controller)); ASSERT_EQ(30, resp.result()); } +TEST_F(TestRpcSecure, TLS) { + RunSecureTest(&TestSimple); +} + +void TestBigOp(CalculatorServiceProxy* proxy) { + RpcController controller; + controller.set_timeout(5s * kTimeMultiplier); + rpc_test::EchoRequestPB req; + req.set_data(RandomHumanReadableString(4_MB)); + rpc_test::EchoResponsePB resp; + ASSERT_OK(proxy->Echo(req, &resp, &controller)); + ASSERT_EQ(req.data(), resp.data()); +} + +TEST_F(TestRpcSecure, BigOp) { + RunSecureTest(&TestBigOp); +} + +void TestManyOps(CalculatorServiceProxy* proxy) { + for (int i = 0; i != RegularBuildVsSanitizers(1000, 100); ++i) { + RpcController controller; + controller.set_timeout(5s * kTimeMultiplier); + rpc_test::EchoRequestPB req; + req.set_data(RandomHumanReadableString(4_KB)); + rpc_test::EchoResponsePB resp; + ASSERT_OK(proxy->Echo(req, &resp, &controller)); + ASSERT_EQ(req.data(), resp.data()); + } +} + +TEST_F(TestRpcSecure, ManyOps) { + RunSecureTest(&TestManyOps); +} + +void TestConcurrentOps(CalculatorServiceProxy* proxy) { + struct Op { + RpcController controller; + rpc_test::EchoRequestPB req; + rpc_test::EchoResponsePB resp; + }; + std::vector ops(RegularBuildVsSanitizers(1000, 100)); + CountDownLatch latch(ops.size()); + for (auto& op : ops) { + op.controller.set_timeout(5s * kTimeMultiplier); + op.req.set_data(RandomHumanReadableString(4_KB)); + proxy->EchoAsync(op.req, &op.resp, &op.controller, [&latch]() { + latch.CountDown(); + }); + } + latch.Wait(); + for (const auto& op : ops) { + ASSERT_OK(op.controller.status()); + ASSERT_EQ(op.req.data(), op.resp.data()); + } +} + +TEST_F(TestRpcSecure, ConcurrentOps) { + RunSecureTest(&TestConcurrentOps); +} + TEST_F(TestRpcSecure, CantAllocateReadBuffer) { // Set up server. TestServerOptions options = SetupServerForTestCantAllocateReadBuffer(); @@ -1112,5 +1189,47 @@ TEST_F(TestRpcSecure, CantAllocateReadBuffer) { TestCantAllocateReadBuffer(client_messenger.get(), server_addr); } +class TestRpcCompression : public RpcTestBase { + public: + void SetUp() override { + FLAGS_stream_compression_algo = 1; + RpcTestBase::SetUp(); + } + + protected: + std::unique_ptr CreateCompressedMessenger( + const std::string& name, const MessengerOptions& options = kDefaultClientMessengerOptions) { + auto builder = CreateMessengerBuilder(name, options); + builder.SetListenProtocol(CompressedStreamProtocol()); + builder.AddStreamFactory( + CompressedStreamProtocol(), + CompressedStreamFactory(TcpStream::Factory(), MemTracker::GetRootTracker())); + return EXPECT_RESULT(builder.Build()); + } + + template + void RunCompressionTest(const F& f) { + RunTest(this, [this](const std::string& name, const MessengerOptions& options) { + return CreateCompressedMessenger(name, options); + }, f); + } +}; + +TEST_F(TestRpcCompression, GZip) { + RunCompressionTest(&TestSimple); +} + +TEST_F(TestRpcCompression, BigOp) { + RunCompressionTest(&TestBigOp); +} + +TEST_F(TestRpcCompression, ManyOps) { + RunCompressionTest(&TestManyOps); +} + +TEST_F(TestRpcCompression, ConcurrentOps) { + RunCompressionTest(&TestConcurrentOps); +} + } // namespace rpc } // namespace yb diff --git a/src/yb/rpc/rpc_fwd.h b/src/yb/rpc/rpc_fwd.h index 92983b3cff61..5933dd763ca0 100644 --- a/src/yb/rpc/rpc_fwd.h +++ b/src/yb/rpc/rpc_fwd.h @@ -47,6 +47,7 @@ class RpcService; class Rpcs; class Poller; class Protocol; +class RefinedStream; class Scheduler; class SecureContext; class ServicePoolImpl; diff --git a/src/yb/rpc/secure_stream.cc b/src/yb/rpc/secure_stream.cc index 09d130cf83c4..4fc0a5b884e9 100644 --- a/src/yb/rpc/secure_stream.cc +++ b/src/yb/rpc/secure_stream.cc @@ -23,6 +23,7 @@ #include "yb/rpc/outbound_call.h" #include "yb/rpc/outbound_data.h" +#include "yb/rpc/refined_stream.h" #include "yb/rpc/rpc_util.h" #include "yb/util/enums.h" @@ -49,11 +50,20 @@ DEFINE_string(cipher_list, "", DEFINE_string(ciphersuites, "", "Define the available TLSv1.3 ciphersuites."); +#define YB_RPC_SSL_TYPE_DEFINE(name) \ + void BOOST_PP_CAT(name, Free)::operator()(name* value) const { \ + BOOST_PP_CAT(name, _free)(value); \ + } \ + namespace yb { namespace rpc { namespace { +YB_RPC_SSL_TYPE_DECLARE(BIO); +YB_RPC_SSL_TYPE_DEFINE(BIO) +YB_RPC_SSL_TYPE_DECLARE(SSL); + const unsigned char kContextId[] = { 'Y', 'u', 'g', 'a', 'B', 'y', 't', 'e' }; std::string SSLErrorMessage(uint64_t error) { @@ -61,49 +71,12 @@ std::string SSLErrorMessage(uint64_t error) { return message ? message : "no error"; } -class SecureOutboundData : public OutboundData { - public: - SecureOutboundData(RefCntBuffer buffer, OutboundDataPtr lower_data) - : buffer_(std::move(buffer)), lower_data_(std::move(lower_data)) {} - - void Transferred(const Status& status, Connection* conn) override { - if (lower_data_) { - lower_data_->Transferred(status, conn); - } - } - - bool DumpPB(const DumpRunningRpcsRequestPB& req, RpcCallInProgressPB* resp) override { - return false; - } - - void Serialize(boost::container::small_vector_base* output) override { - output->push_back(std::move(buffer_)); - } - - std::string ToString() const override { - return Format("Secure[$0]", lower_data_); - } - - size_t ObjectSize() const override { return sizeof(*this); } - - size_t DynamicMemoryUsage() const override { return DynamicMemoryUsageOf(buffer_, lower_data_); } - - private: - RefCntBuffer buffer_; - OutboundDataPtr lower_data_; -}; - -#define YB_RPC_SSL_TYPE_DEFINE(name) \ - void BOOST_PP_CAT(name, Free)::operator()(name* value) const { \ - BOOST_PP_CAT(name, _free)(value); \ - } \ - #define YB_RPC_SSL_TYPE(name) YB_RPC_SSL_TYPE_DECLARE(name) YB_RPC_SSL_TYPE_DEFINE(name) #define SSL_STATUS(type, format) STATUS_FORMAT(type, format, SSLErrorMessage(ERR_get_error())) -Result BIOFromSlice(const Slice& data) { - detail::BIOPtr bio(BIO_new_mem_buf(data.data(), data.size())); +Result BIOFromSlice(const Slice& data) { + BIOPtr bio(BIO_new_mem_buf(data.data(), data.size())); if (!bio) { return SSL_STATUS(IOError, "Create BIO failed: $0"); } @@ -276,7 +249,6 @@ int64_t ProtocolsOption() { namespace detail { -YB_RPC_SSL_TYPE_DEFINE(BIO) YB_RPC_SSL_TYPE_DEFINE(EVP_PKEY) YB_RPC_SSL_TYPE_DEFINE(SSL) YB_RPC_SSL_TYPE_DEFINE(SSL_CTX) @@ -389,121 +361,62 @@ Status SecureContext::UseCertificate(const Slice& data) { namespace { -YB_DEFINE_ENUM(LocalSide, (kClient)(kServer)); - -class SecureStream : public Stream, public StreamContext { +class SecureRefiner : public StreamRefiner { public: - SecureStream(const SecureContext& context, std::unique_ptr lower_stream, - size_t receive_buffer_size, const MemTrackerPtr& buffer_tracker, - const StreamCreateData& data) - : secure_context_(context), lower_stream_(std::move(lower_stream)), - remote_hostname_(data.remote_hostname), - encrypted_read_buffer_(receive_buffer_size, buffer_tracker) { - } - - SecureStream(const SecureStream&) = delete; - void operator=(const SecureStream&) = delete; - - size_t GetPendingWriteBytes() override { - return lower_stream_->GetPendingWriteBytes(); + SecureRefiner(const SecureContext& context, const StreamCreateData& data) + : secure_context_(context), remote_hostname_(data.remote_hostname) { } private: - CHECKED_STATUS Start(bool connect, ev::loop_ref* loop, StreamContext* context) override; - void Close() override; - void Shutdown(const Status& status) override; - Result Send(OutboundDataPtr data) override; - CHECKED_STATUS TryWrite() override; - void ParseReceived() override; - - void Cancelled(size_t handle) override { - LOG_WITH_PREFIX(DFATAL) << "Cancel is not supported for secure stream: " << handle; + void Start(RefinedStream* stream) override { + stream_ = stream; } - bool Idle(std::string* reason_not_idle) override; - bool IsConnected() override; - void DumpPB(const DumpRunningRpcsRequestPB& req, RpcConnectionPB* resp) override; + CHECKED_STATUS Handshake() override; + CHECKED_STATUS Init(); + + CHECKED_STATUS Send(OutboundDataPtr data) override; + Result ProcessHeader(const IoVecs& data) override; + Result Read(void* buf, size_t num) override; + Result Receive(const Slice& slice) override; - const Endpoint& Remote() override; - const Endpoint& Local() override; + std::string ToString() const override { + return "SECURE"; + } const Protocol* GetProtocol() override { return SecureStreamProtocol(); } - // Implementation StreamContext - void UpdateLastActivity() override; - void UpdateLastRead() override; - void UpdateLastWrite() override; - void Transferred(const OutboundDataPtr& data, const Status& status) override; - void Destroy(const Status& status) override; - Result ProcessReceived( - const IoVecs& data, ReadBufferFull read_buffer_full) override; - void Connected() override; - - StreamReadBuffer& ReadBuffer() override { - return encrypted_read_buffer_; - } - - CHECKED_STATUS Handshake(); - - CHECKED_STATUS Init(); - CHECKED_STATUS Established(SecureState state); static int VerifyCallback(int preverified, X509_STORE_CTX* store_context); CHECKED_STATUS Verify(bool preverified, X509_STORE_CTX* store_context); bool MatchEndpoint(X509* cert, GENERAL_NAMES* gens); bool MatchUid(X509* cert, GENERAL_NAMES* gens); bool MatchUidEntry(const Slice& value, const char* name); - CHECKED_STATUS SendEncrypted(OutboundDataPtr data); Result WriteEncrypted(OutboundDataPtr data); - CHECKED_STATUS ReadDecrypted(); - CHECKED_STATUS HandshakeOrRead(); - Result SslRead(void* buf, int num); - std::string ToString() override; + CHECKED_STATUS Established(RefinedStreamState state) { + VLOG_WITH_PREFIX(4) << "Established with state: " << state << ", used cipher: " + << SSL_get_cipher_name(ssl_.get()); + + return stream_->Established(state); + } + + const std::string& LogPrefix() const { + return stream_->LogPrefix(); + } const SecureContext& secure_context_; - std::unique_ptr lower_stream_; const std::string remote_hostname_; - StreamContext* context_; - size_t decrypted_bytes_to_skip_ = 0; - SecureState state_ = SecureState::kInitial; - LocalSide local_side_ = LocalSide::kServer; - bool connected_ = false; - std::vector pending_data_; + RefinedStream* stream_ = nullptr; std::vector certificate_entries_; - CircularReadBuffer encrypted_read_buffer_; - - detail::BIOPtr bio_; + BIOPtr bio_; detail::SSLPtr ssl_; Status verification_status_; }; -Status SecureStream::Start(bool connect, ev::loop_ref* loop, StreamContext* context) { - context_ = context; - local_side_ = connect ? LocalSide::kClient : LocalSide::kServer; - return lower_stream_->Start(connect, loop, this); -} - -void SecureStream::Close() { - lower_stream_->Close(); -} - -void SecureStream::Shutdown(const Status& status) { - VLOG_WITH_PREFIX(1) << "SecureStream::Shutdown with status: " << status; - - for (auto& data : pending_data_) { - if (data) { - context_->Transferred(data, status); - } - } - pending_data_.clear(); - - lower_stream_->Shutdown(status); -} - -Status SecureStream::SendEncrypted(OutboundDataPtr data) { +Status SecureRefiner::Send(OutboundDataPtr data) { boost::container::small_vector queue; data->Serialize(&queue); for (const auto& buf : queue) { @@ -532,23 +445,7 @@ Status SecureStream::SendEncrypted(OutboundDataPtr data) { return ResultToStatus(WriteEncrypted(std::move(data))); } -Result SecureStream::Send(OutboundDataPtr data) { - switch (state_) { - case SecureState::kInitial: - case SecureState::kHandshake: - pending_data_.push_back(std::move(data)); - return std::numeric_limits::max(); - case SecureState::kEnabled: - RETURN_NOT_OK(SendEncrypted(std::move(data))); - return std::numeric_limits::max(); - case SecureState::kDisabled: - return lower_stream_->Send(std::move(data)); - } - - FATAL_INVALID_ENUM_VALUE(SecureState, state_); -} - -Result SecureStream::WriteEncrypted(OutboundDataPtr data) { +Result SecureRefiner::WriteEncrypted(OutboundDataPtr data) { auto pending = BIO_ctrl_pending(bio_.get()); if (pending == 0) { return data ? STATUS(NetworkError, "No pending data during write") : Result(false); @@ -557,130 +454,38 @@ Result SecureStream::WriteEncrypted(OutboundDataPtr data) { auto len = BIO_read(bio_.get(), buf.data(), buf.size()); LOG_IF_WITH_PREFIX(DFATAL, len != buf.size()) << "BIO_read was not full: " << buf.size() << ", read: " << len; - VLOG_WITH_PREFIX(4) << "Write encrypted: " << len << ", " << yb::ToString(data); - RETURN_NOT_OK(lower_stream_->Send(std::make_shared(buf, std::move(data)))); + VLOG_WITH_PREFIX(4) << "Write encrypted: " << len << ", " << AsString(data); + RETURN_NOT_OK(stream_->SendToLower(std::make_shared( + buf, std::move(data)))); return true; } -Status SecureStream::TryWrite() { - return lower_stream_->TryWrite(); -} - -void SecureStream::ParseReceived() { - lower_stream_->ParseReceived(); -} - -bool SecureStream::Idle(std::string* reason) { - return lower_stream_->Idle(reason); -} - -bool SecureStream::IsConnected() { - return connected_; -} - -void SecureStream::DumpPB(const DumpRunningRpcsRequestPB& req, RpcConnectionPB* resp) { - lower_stream_->DumpPB(req, resp); -} - -const Endpoint& SecureStream::Remote() { - return lower_stream_->Remote(); -} - -const Endpoint& SecureStream::Local() { - return lower_stream_->Local(); -} - -std::string SecureStream::ToString() { - return Format("SECURE[$0] $1 $2", local_side_ == LocalSide::kClient ? "C" : "S", state_, - lower_stream_->ToString()); -} - -void SecureStream::UpdateLastActivity() { - context_->UpdateLastActivity(); -} - -void SecureStream::UpdateLastRead() { - context_->UpdateLastRead(); -} - -void SecureStream::UpdateLastWrite() { - context_->UpdateLastWrite(); -} - -void SecureStream::Transferred(const OutboundDataPtr& data, const Status& status) { - context_->Transferred(data, status); -} - -void SecureStream::Destroy(const Status& status) { - context_->Destroy(status); -} - -Result SecureStream::ProcessReceived( - const IoVecs& data, ReadBufferFull read_buffer_full) { - switch (state_) { - case SecureState::kInitial: { - if (data[0].iov_len < 2) { - return ProcessDataResult{0, Slice()}; - } - const uint8_t* bytes = static_cast(data[0].iov_base); - if (bytes[0] == 0x16 && bytes[1] == 0x03) { // TLS handshake header - state_ = SecureState::kHandshake; - ResetLogPrefix(); - RETURN_NOT_OK(Init()); - } else if (FLAGS_allow_insecure_connections) { - RETURN_NOT_OK(Established(SecureState::kDisabled)); - } else { - return STATUS_FORMAT(NetworkError, "Insecure connection header: $0", - Slice(bytes, 2).ToDebugHexString()); - } - return ProcessReceived(data, read_buffer_full); - } - - case SecureState::kDisabled: - return context_->ProcessReceived(data, read_buffer_full); - - case SecureState::kHandshake: FALLTHROUGH_INTENDED; - case SecureState::kEnabled: { - size_t result = 0; - for (const auto& iov : data) { - Slice slice(static_cast(iov.iov_base), iov.iov_len); - for (;;) { - auto len = BIO_write(bio_.get(), slice.data(), slice.size()); - result += len; - if (len == slice.size()) { - break; - } - slice.remove_prefix(len); - RETURN_NOT_OK(HandshakeOrRead()); - } - } - RETURN_NOT_OK(HandshakeOrRead()); - return ProcessDataResult{ result, Slice() }; - } +Result SecureRefiner::ProcessHeader(const IoVecs& data) { + if (data[0].iov_len < 2) { + return 0; } - return STATUS_FORMAT(IllegalState, "Unexpected state: $0", to_underlying(state_)); -} - -Status SecureStream::HandshakeOrRead() { - if (state_ == SecureState::kEnabled) { - return ReadDecrypted(); - } else { - auto handshake_status = Handshake(); - LOG_IF_WITH_PREFIX(INFO, !handshake_status.ok()) << "Handshake failed: " << handshake_status; - RETURN_NOT_OK(handshake_status); + const uint8_t* bytes = static_cast(data[0].iov_base); + if (bytes[0] == 0x16 && bytes[1] == 0x03) { // TLS handshake header + RETURN_NOT_OK(Init()); + RETURN_NOT_OK(stream_->StartHandshake()); + return 0; } - if (state_ == SecureState::kEnabled) { - return ReadDecrypted(); + + if (!FLAGS_allow_insecure_connections) { + return STATUS_FORMAT(NetworkError, "Insecure connection header: $0", + Slice(bytes, 2).ToDebugHexString()); } - return Status::OK(); + + RETURN_NOT_OK(Established(RefinedStreamState::kDisabled)); + return 0; } // Tries to do SSL_read up to num bytes from buf. Possible results: // > 0 - number of bytes actually read. // = 0 - in case of SSL_ERROR_WANT_READ. // Status with network error - in case of other errors. -Result SecureStream::SslRead(void* buf, int num) { +Result SecureRefiner::Read(void* buf, size_t num) { auto len = SSL_read(ssl_.get(), buf, num); if (len <= 0) { auto error = SSL_get_error(ssl_.get(), len); @@ -696,75 +501,21 @@ Result SecureStream::SslRead(void* buf, int num) { return len; } -Status SecureStream::ReadDecrypted() { - // TODO handle IsBusy - auto& decrypted_read_buffer = context_->ReadBuffer(); - bool done = false; - while (!done) { - if (decrypted_bytes_to_skip_ > 0) { - auto global_skip_buffer = GetGlobalSkipBuffer(); - do { - auto len = VERIFY_RESULT(SslRead( - global_skip_buffer.mutable_data(), - std::min(global_skip_buffer.size(), decrypted_bytes_to_skip_))); - if (len == 0) { - done = true; - break; - } - VLOG_WITH_PREFIX(4) << "Skip decrypted: " << len; - decrypted_bytes_to_skip_ -= len; - } while (decrypted_bytes_to_skip_ > 0); - } - auto out = VERIFY_RESULT(decrypted_read_buffer.PrepareAppend()); - size_t appended = 0; - for (auto iov = out.begin(); iov != out.end();) { - auto len = VERIFY_RESULT(SslRead(iov->iov_base, iov->iov_len)); - if (len == 0) { - done = true; - break; - } - VLOG_WITH_PREFIX(4) << "Read decrypted: " << len; - appended += len; - iov->iov_base = static_cast(iov->iov_base) + len; - iov->iov_len -= len; - if (iov->iov_len <= 0) { - ++iov; - } - } - decrypted_read_buffer.DataAppended(appended); - if (decrypted_read_buffer.ReadyToRead()) { - auto temp = VERIFY_RESULT(context_->ProcessReceived( - decrypted_read_buffer.AppendedVecs(), ReadBufferFull(decrypted_read_buffer.Full()))); - decrypted_read_buffer.Consume(temp.consumed, temp.buffer); - DCHECK_EQ(decrypted_bytes_to_skip_, 0); - decrypted_bytes_to_skip_ = temp.bytes_to_skip; - } - } - - return Status::OK(); +Result SecureRefiner::Receive(const Slice& slice) { + return BIO_write(bio_.get(), slice.data(), slice.size()); } -void SecureStream::Connected() { - if (local_side_ == LocalSide::kClient) { - auto status = Init(); - if (status.ok()) { - status = Handshake(); - } - if (!status.ok()) { - context_->Destroy(status); - } - } -} +Status SecureRefiner::Handshake() { + RETURN_NOT_OK(Init()); -Status SecureStream::Handshake() { for (;;) { - if (state_ == SecureState::kEnabled) { + if (stream_->IsConnected()) { return Status::OK(); } auto pending_before = BIO_ctrl_pending(bio_.get()); ERR_clear_error(); - int result = local_side_ == LocalSide::kClient + int result = stream_->local_side() == LocalSide::kClient ? SSL_connect(ssl_.get()) : SSL_accept(ssl_.get()); int ssl_error = SSL_get_error(ssl_.get(), result); int sys_error = static_cast(ERR_get_error()); @@ -779,7 +530,7 @@ Status SecureStream::Handshake() { message_suffix = Format(", certificate entries: $0", certificate_entries_); } return STATUS_FORMAT(NetworkError, "Handshake failed: $0, address: $1, hostname: $2$3", - message, Remote().address(), remote_hostname_, message_suffix); + message, stream_->Remote().address(), remote_hostname_, message_suffix); } if (ssl_error == SSL_ERROR_WANT_WRITE || pending_after > pending_before) { @@ -787,12 +538,12 @@ Status SecureStream::Handshake() { RefCntBuffer buffer(pending_after); int len = BIO_read(bio_.get(), buffer.data(), buffer.size()); DCHECK_EQ(len, pending_after); - auto data = std::make_shared(buffer, nullptr); - RETURN_NOT_OK(lower_stream_->Send(data)); + RETURN_NOT_OK(stream_->SendToLower( + std::make_shared(buffer, nullptr))); // If SSL_connect/SSL_accept returned positive result it means that TLS connection // was succesfully established. We just have to send last portion of data. if (result > 0) { - RETURN_NOT_OK(Established(SecureState::kEnabled)); + RETURN_NOT_OK(Established(RefinedStreamState::kEnabled)); } } else if (ssl_error == SSL_ERROR_WANT_READ) { // SSL expects that we would read from underlying transport. @@ -800,80 +551,73 @@ Status SecureStream::Handshake() { } else if (SSL_get_shutdown(ssl_.get()) & SSL_RECEIVED_SHUTDOWN) { return STATUS(Aborted, "Handshake aborted"); } else { - return Established(SecureState::kEnabled); + return Established(RefinedStreamState::kEnabled); } } } -Status SecureStream::Init() { - if (!ssl_) { - ssl_ = secure_context_.Create(); - SSL_set_mode(ssl_.get(), SSL_MODE_ENABLE_PARTIAL_WRITE); - SSL_set_mode(ssl_.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); - SSL_set_mode(ssl_.get(), SSL_MODE_RELEASE_BUFFERS); - SSL_set_app_data(ssl_.get(), this); - - if (local_side_ == LocalSide::kServer || secure_context_.use_client_certificate()) { - auto res = SSL_use_PrivateKey(ssl_.get(), secure_context_.private_key()); - if (res != 1) { - return SSL_STATUS(InvalidArgument, "Failed to use private key: $0"); - } - res = SSL_use_certificate(ssl_.get(), secure_context_.certificate()); - if (res != 1) { - return SSL_STATUS(InvalidArgument, "Failed to use certificate: $0"); - } - } +Status SecureRefiner::Init() { + if (ssl_) { + return Status::OK(); + } - BIO* int_bio = nullptr; - BIO* temp_bio = nullptr; - BIO_new_bio_pair(&int_bio, 0, &temp_bio, 0); - SSL_set_bio(ssl_.get(), int_bio, int_bio); - bio_.reset(temp_bio); + ssl_ = secure_context_.Create(); + SSL_set_mode(ssl_.get(), SSL_MODE_ENABLE_PARTIAL_WRITE); + SSL_set_mode(ssl_.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + SSL_set_mode(ssl_.get(), SSL_MODE_RELEASE_BUFFERS); + SSL_set_app_data(ssl_.get(), this); - int verify_mode = SSL_VERIFY_PEER; - if (secure_context_.require_client_certificate()) { - verify_mode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE; + if (stream_->local_side() == LocalSide::kServer || secure_context_.use_client_certificate()) { + auto res = SSL_use_PrivateKey(ssl_.get(), secure_context_.private_key()); + if (res != 1) { + return SSL_STATUS(InvalidArgument, "Failed to use private key: $0"); + } + res = SSL_use_certificate(ssl_.get(), secure_context_.certificate()); + if (res != 1) { + return SSL_STATUS(InvalidArgument, "Failed to use certificate: $0"); } - SSL_set_verify(ssl_.get(), verify_mode, &VerifyCallback); } + BIO* int_bio = nullptr; + BIO* temp_bio = nullptr; + BIO_new_bio_pair(&int_bio, 0, &temp_bio, 0); + SSL_set_bio(ssl_.get(), int_bio, int_bio); + bio_.reset(temp_bio); + + int verify_mode = SSL_VERIFY_PEER; + if (secure_context_.require_client_certificate()) { + verify_mode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE; + } + SSL_set_verify(ssl_.get(), verify_mode, &VerifyCallback); + return Status::OK(); } -Status SecureStream::Established(SecureState state) { - VLOG_WITH_PREFIX(4) << "Established with state: " << state << ", used cipher: " - << SSL_get_cipher_name(ssl_.get()); +int SecureRefiner::VerifyCallback(int preverified, X509_STORE_CTX* store_context) { + if (!store_context) { + return preverified; + } - state_ = state; - ResetLogPrefix(); - connected_ = true; - context_->Connected(); - for (auto& data : pending_data_) { - RETURN_NOT_OK(Send(std::move(data))); + auto ssl = static_cast(X509_STORE_CTX_get_ex_data( + store_context, SSL_get_ex_data_X509_STORE_CTX_idx())); + if (!ssl) { + return preverified; } - pending_data_.clear(); - return Status::OK(); -} -int SecureStream::VerifyCallback(int preverified, X509_STORE_CTX* store_context) { - if (store_context) { - auto ssl = static_cast(X509_STORE_CTX_get_ex_data( - store_context, SSL_get_ex_data_X509_STORE_CTX_idx())); - if (ssl) { - auto stream = static_cast(SSL_get_app_data(ssl)); - if (stream) { - auto status = stream->Verify(preverified != 0, store_context); - if (!status.ok()) { - VLOG(4) << stream->LogPrefix() << status; - stream->verification_status_ = status; - return 0; - } - return 1; - } - } + auto refiner = static_cast(SSL_get_app_data(ssl)); + + if (!refiner) { + return preverified; + } + + auto status = refiner->Verify(preverified != 0, store_context); + if (status.ok()) { + return 1; } - return preverified; + VLOG(4) << refiner->LogPrefix() << status; + refiner->verification_status_ = status; + return 0; } namespace { @@ -937,8 +681,8 @@ Slice GetCommonName(X509* cert) { } // namespace -bool SecureStream::MatchEndpoint(X509* cert, GENERAL_NAMES* gens) { - auto address = Remote().address(); +bool SecureRefiner::MatchEndpoint(X509* cert, GENERAL_NAMES* gens) { + auto address = stream_->Remote().address(); for (int i = 0; i < sk_GENERAL_NAME_num(gens); ++i) { GENERAL_NAME* gen = sk_GENERAL_NAME_value(gens, i); @@ -989,19 +733,20 @@ bool SecureStream::MatchEndpoint(X509* cert, GENERAL_NAMES* gens) { Slice common_name = GetCommonName(cert); if (!common_name.empty()) { VLOG_WITH_PREFIX(4) << "Common name: " << common_name.ToBuffer() << " vs " - << Remote().address() << "/" << remote_hostname_; - if (common_name == Remote().address().to_string() || + << stream_->Remote().address() << "/" << remote_hostname_; + if (common_name == stream_->Remote().address().to_string() || MatchPattern(common_name, remote_hostname_)) { return true; } } - VLOG_WITH_PREFIX(4) << "Nothing suitable for " << Remote().address() << "/" << remote_hostname_; + VLOG_WITH_PREFIX(4) << "Nothing suitable for " << stream_->Remote().address() << "/" + << remote_hostname_; return false; } -bool SecureStream::MatchUidEntry(const Slice& value, const char* name) { +bool SecureRefiner::MatchUidEntry(const Slice& value, const char* name) { if (value == secure_context_.required_uid()) { VLOG_WITH_PREFIX(4) << "Accepted " << name << ": " << value.ToBuffer(); return true; @@ -1028,7 +773,7 @@ bool IsStringType(int type) { return false; } -bool SecureStream::MatchUid(X509* cert, GENERAL_NAMES* gens) { +bool SecureRefiner::MatchUid(X509* cert, GENERAL_NAMES* gens) { if (MatchUidEntry(GetCommonName(cert), "common name")) { return true; } @@ -1066,7 +811,7 @@ bool SecureStream::MatchUid(X509* cert, GENERAL_NAMES* gens) { } // Verify according to RFC 2818. -Status SecureStream::Verify(bool preverified, X509_STORE_CTX* store_context) { +Status SecureRefiner::Verify(bool preverified, X509_STORE_CTX* store_context) { // Don't bother looking at certificates that have failed pre-verification. if (!preverified) { auto err = X509_STORE_CTX_get_error(store_context); @@ -1101,8 +846,8 @@ Status SecureStream::Verify(bool preverified, X509_STORE_CTX* store_context) { VLOG_WITH_PREFIX(4) << "Skip UID verification"; } - bool verify_endpoint = local_side_ == LocalSide::kClient ? FLAGS_verify_server_endpoint - : FLAGS_verify_client_endpoint; + bool verify_endpoint = stream_->local_side() == LocalSide::kClient ? FLAGS_verify_server_endpoint + : FLAGS_verify_client_endpoint; if (verify_endpoint) { if (!MatchEndpoint(cert, gens)) { return STATUS(NetworkError, "Endpoint does not match"); @@ -1124,34 +869,12 @@ const Protocol* SecureStreamProtocol() { StreamFactoryPtr SecureStreamFactory( StreamFactoryPtr lower_layer_factory, const MemTrackerPtr& buffer_tracker, const SecureContext* context) { - class SecureStreamFactory : public StreamFactory { - public: - SecureStreamFactory( - StreamFactoryPtr lower_layer_factory, const MemTrackerPtr& buffer_tracker, - const SecureContext* context) - : lower_layer_factory_(std::move(lower_layer_factory)), buffer_tracker_(buffer_tracker), - context_(context) { - } - - private: - std::unique_ptr Create(const StreamCreateData& data) override { - auto receive_buffer_size = data.socket->GetReceiveBufferSize(); - if (!receive_buffer_size.ok()) { - LOG(WARNING) << "Secure stream failure: " << receive_buffer_size.status(); - receive_buffer_size = 256_KB; - } - auto lower_stream = lower_layer_factory_->Create(data); - return std::make_unique( - *context_, std::move(lower_stream), *receive_buffer_size, buffer_tracker_, data); - } - - StreamFactoryPtr lower_layer_factory_; - MemTrackerPtr buffer_tracker_; - const SecureContext* context_; - }; - - return std::make_shared( - std::move(lower_layer_factory), buffer_tracker, context); + return std::make_shared( + std::move(lower_layer_factory), buffer_tracker, + [context](size_t receive_buffer_size, const MemTrackerPtr& buffer_tracker, + const StreamCreateData& data) { + return std::make_unique(*context, data); + }); } } // namespace rpc diff --git a/src/yb/rpc/secure_stream.h b/src/yb/rpc/secure_stream.h index 710392bf9f77..6f5dd5946b79 100644 --- a/src/yb/rpc/secure_stream.h +++ b/src/yb/rpc/secure_stream.h @@ -20,18 +20,14 @@ #include "yb/util/enums.h" -typedef struct bio_st BIO; typedef struct evp_pkey_st EVP_PKEY; typedef struct ssl_st SSL; typedef struct ssl_ctx_st SSL_CTX; typedef struct x509_st X509; -typedef struct x509_store_ctx_st X509_STORE_CTX; namespace yb { namespace rpc { -YB_DEFINE_ENUM(SecureState, (kInitial)(kHandshake)(kEnabled)(kDisabled)); - #define YB_RPC_SSL_TYPE_DECLARE(name) \ struct BOOST_PP_CAT(name, Free) { \ void operator()(name* value) const; \ @@ -42,7 +38,6 @@ YB_DEFINE_ENUM(SecureState, (kInitial)(kHandshake)(kEnabled)(kDisabled)); namespace detail { -YB_RPC_SSL_TYPE_DECLARE(BIO); YB_RPC_SSL_TYPE_DECLARE(EVP_PKEY); YB_RPC_SSL_TYPE_DECLARE(SSL); YB_RPC_SSL_TYPE_DECLARE(SSL_CTX); diff --git a/src/yb/rpc/stream.h b/src/yb/rpc/stream.h index 63fad49f7099..b3a373b88c8e 100644 --- a/src/yb/rpc/stream.h +++ b/src/yb/rpc/stream.h @@ -101,6 +101,11 @@ class StreamContext { class Stream { public: + Stream() = default; + + Stream(const Stream&) = delete; + void operator=(const Stream&) = delete; + virtual CHECKED_STATUS Start(bool connect, ev::loop_ref* loop, StreamContext* context) = 0; virtual void Close() = 0; virtual void Shutdown(const Status& status) = 0; @@ -120,12 +125,12 @@ class Stream { virtual void DumpPB(const DumpRunningRpcsRequestPB& req, RpcConnectionPB* resp) = 0; // The address of the remote end of the connection. - virtual const Endpoint& Remote() = 0; + virtual const Endpoint& Remote() const = 0; // The address of the local end of the connection. - virtual const Endpoint& Local() = 0; + virtual const Endpoint& Local() const = 0; - virtual std::string ToString() { + virtual std::string ToString() const { return Format("{ local: $0 remote: $1 }", Local(), Remote()); } diff --git a/src/yb/rpc/tcp_stream.h b/src/yb/rpc/tcp_stream.h index 22d296d0409d..ba4177071573 100644 --- a/src/yb/rpc/tcp_stream.h +++ b/src/yb/rpc/tcp_stream.h @@ -80,8 +80,8 @@ class TcpStream : public Stream { bool IsConnected() override { return connected_; } void DumpPB(const DumpRunningRpcsRequestPB& req, RpcConnectionPB* resp) override; - const Endpoint& Remote() override { return remote_; } - const Endpoint& Local() override { return local_; } + const Endpoint& Remote() const override { return remote_; } + const Endpoint& Local() const override { return local_; } const Protocol* GetProtocol() override { return StaticProtocol(); diff --git a/src/yb/rpc/yb_rpc.cc b/src/yb/rpc/yb_rpc.cc index b8dafdd27d43..d50c4e616b27 100644 --- a/src/yb/rpc/yb_rpc.cc +++ b/src/yb/rpc/yb_rpc.cc @@ -69,8 +69,8 @@ const char kConnectionHeaderBytes[] = "YB\1"; const size_t kConnectionHeaderSize = sizeof(kConnectionHeaderBytes) - 1; OutboundDataPtr ConnectionHeaderInstance() { - static OutboundDataPtr result( - new StringOutboundData(kConnectionHeaderBytes, kConnectionHeaderSize, "ConnectionHeader")); + static OutboundDataPtr result = std::make_shared( + kConnectionHeaderBytes, kConnectionHeaderSize, "ConnectionHeader"); return result; }