Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions source/extensions/filters/http/proto_api_scrubber/filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ using proto_processing_lib::proto_scrubber::ScrubberContext;

// Passed as the first argument to formatError() when this filter builds the response-code detail
// string for a local reply.
const char kRcDetailFilterProtoApiScrubber[] = "proto_api_scrubber";
// NOTE(review): presumably the request-path counterpart of
// kRcDetailErrorResponseBufferConversion; its usage is not visible in this chunk.
const char kRcDetailErrorRequestBufferConversion[] = "REQUEST_BUFFER_CONVERSION_FAIL";
// Passed to formatError() by encodeData() when accumulating response messages fails or when a
// response message cannot be converted back to a buffer.
const char kRcDetailErrorResponseBufferConversion[] = "RESPONSE_BUFFER_CONVERSION_FAIL";
// Passed to formatError() by encodeData() when the response scrubber cannot be created.
const char kRcDetailErrorTypeBadRequest[] = "BAD_REQUEST";
// NOTE(review): presumably the error text used when `:path` header validation fails; its usage is
// not visible in this chunk.
const char kPathValidationError[] = "Error in `:path` header validation.";

Expand Down Expand Up @@ -204,6 +205,118 @@ Http::FilterDataStatus ProtoApiScrubberFilter::decodeData(Buffer::Instance& data
return Envoy::Http::FilterDataStatus::Continue;
}

// Inspects the response headers and, for gRPC responses only, creates the response message
// converter that encodeData() relies on. Header iteration always continues.
Http::FilterHeadersStatus
ProtoApiScrubberFilter::encodeHeaders(Envoy::Http::ResponseHeaderMap& headers, bool end_stream) {
  ENVOY_STREAM_LOG(trace, "Called ProtoApiScrubber Filter encodeHeaders", *encoder_callbacks_);

  const bool is_grpc_response = Envoy::Grpc::Common::isGrpcResponseHeaders(headers, end_stream);
  if (is_grpc_response) {
    // Factory handed to the converter; every invocation yields a new CordMessageData.
    auto message_data_factory = std::make_unique<CreateMessageDataFunc>(
        []() { return std::make_unique<Protobuf::field_extraction::CordMessageData>(); });

    // The encoder buffer limit is forwarded as the converter's second constructor argument.
    response_msg_converter_ = std::make_unique<MessageConverter>(
        std::move(message_data_factory), encoder_callbacks_->encoderBufferLimit());
  } else {
    // Without a converter, encodeData() passes the body through untouched.
    ENVOY_STREAM_LOG(
        debug,
        "Response headers is NOT application/grpc content-type. Response is passed through "
        "without message extraction.",
        *encoder_callbacks_);
  }

  return Envoy::Http::FilterHeadersStatus::Continue;
}

// Scrubs the gRPC response messages carried in `data` before they continue down the encoder
// filter chain.
//
// `data` is the response body chunk for this call; the buffers of the processed messages are
// moved back into it. `end_stream` indicates whether this is the last chunk of the response.
//
// Returns:
//  - Continue when no response converter is set (encodeHeaders() only creates one for gRPC
//    responses), or once every accumulated message has been processed.
//  - StopIterationNoBuffer when no complete message has been accumulated yet, or after a local
//    reply has been sent because of an error.
Http::FilterDataStatus ProtoApiScrubberFilter::encodeData(Buffer::Instance& data, bool end_stream) {
  ENVOY_STREAM_LOG(debug, "Called ProtoApiScrubber::encodeData: data size={} end_stream={}",
                   *encoder_callbacks_, data.length(), end_stream);

  // The converter is only created for gRPC responses; anything else is passed through.
  if (!response_msg_converter_) {
    return Envoy::Http::FilterDataStatus::Continue;
  }

  // Move the data to internal gRPC buffer messages representation.
  auto messages = response_msg_converter_->accumulateMessages(data, end_stream);
  if (const absl::Status& status = messages.status(); !status.ok()) {
    rejectResponse(status.raw_code(), status.message(),
                   formatError(kRcDetailFilterProtoApiScrubber,
                               absl::StatusCodeToString(status.code()),
                               kRcDetailErrorResponseBufferConversion));
    return Envoy::Http::FilterDataStatus::StopIterationNoBuffer;
  }

  // Nothing to scrub until at least one complete message has been accumulated.
  if (messages->empty()) {
    ENVOY_STREAM_LOG(debug, "not a complete msg", *encoder_callbacks_);
    return Http::FilterDataStatus::StopIterationNoBuffer;
  }

  // Scrub each message individually, one by one.
  ENVOY_STREAM_LOG(trace, "Accumulated {} messages. Starting scrubbing on each of them one by one.",
                   *encoder_callbacks_, messages->size());

  // Only create the response scrubber if it's not already created.
  if (!response_scrubber_) {
    absl::StatusOr<std::unique_ptr<ProtoScrubber>> response_scrubber_or_status =
        createResponseProtoScrubber();
    if (!response_scrubber_or_status.ok()) {
      const absl::Status& status = response_scrubber_or_status.status();
      // This is the response path, so the log must refer to the response payload.
      ENVOY_STREAM_LOG(error, "Unable to scrub response payload. Error details: {}",
                       *encoder_callbacks_, status.ToString());
      rejectResponse(status.raw_code(), status.message(),
                     formatError(kRcDetailFilterProtoApiScrubber,
                                 absl::StatusCodeToString(status.code()),
                                 kRcDetailErrorTypeBadRequest));
      return Envoy::Http::FilterDataStatus::StopIterationNoBuffer;
    }

    // Move the created scrubber into the member variable.
    response_scrubber_ = std::move(response_scrubber_or_status).value();
  }

  for (size_t msg_idx = 0; msg_idx < messages->size(); ++msg_idx) {
    std::unique_ptr<StreamMessage> stream_message = std::move(messages->at(msg_idx));
    if (stream_message->message() == nullptr) {
      // Expect end_stream=true when the MessageConverter signals a stream end.
      ASSERT(end_stream);
      // Expect stream_message->isFinalMessage()=true when the MessageConverter signals a stream
      // end.
      ASSERT(stream_message->isFinalMessage());
      // Expect stream_message to be the last element in the vector when the MessageConverter
      // signals a stream end.
      ASSERT(msg_idx == messages->size() - 1);
      // Skip the empty message.
      continue;
    }

    // A scrubbing failure is not fatal: it is logged and the message is still converted back and
    // forwarded below.
    const auto scrub_status = response_scrubber_->Scrub(stream_message->message());
    if (!scrub_status.ok()) {
      ENVOY_STREAM_LOG(warn,
                       "Response scrubbing failed with error: {}. The response will not be "
                       "modified.",
                       *encoder_callbacks_, scrub_status.ToString());
    }

    auto buf_convert_status =
        response_msg_converter_->convertBackToBuffer(std::move(stream_message));
    if (!buf_convert_status.ok()) {
      const absl::Status& status = buf_convert_status.status();
      ENVOY_STREAM_LOG(error, "Failed to convert scrubbed message back to envoy buffer: {}",
                       *encoder_callbacks_, status.ToString());

      // Send a local reply if response conversion failed.
      rejectResponse(status.raw_code(), status.message(),
                     formatError(kRcDetailFilterProtoApiScrubber,
                                 absl::StatusCodeToString(status.code()),
                                 kRcDetailErrorResponseBufferConversion));
      return Envoy::Http::FilterDataStatus::StopIterationNoBuffer;
    }

    // Append the converted message to the data that continues down the filter chain.
    data.move(*buf_convert_status.value());
  }

  ENVOY_STREAM_LOG(trace, "Response scrubbing completed successfully.", *encoder_callbacks_);
  return Envoy::Http::FilterDataStatus::Continue;
}

absl::StatusOr<std::unique_ptr<ProtoScrubber>>
ProtoApiScrubberFilter::createRequestProtoScrubber() {
absl::StatusOr<const Protobuf::Type*> request_type_or_status =
Expand All @@ -221,6 +334,22 @@ ProtoApiScrubberFilter::createRequestProtoScrubber() {
ScrubberContext::kRequestScrubbing, false);
}

// Builds the `ProtoScrubber` used for response scrubbing of the current method.
// Also (re)creates `response_match_tree_field_checker_`, whose raw pointer is handed to the
// returned scrubber. Returns the error status of the response type lookup if that lookup fails.
absl::StatusOr<std::unique_ptr<ProtoScrubber>>
ProtoApiScrubberFilter::createResponseProtoScrubber() {
  absl::StatusOr<const Protobuf::Type*> resolved_type =
      filter_config_.getResponseType(method_name_);
  if (!resolved_type.ok()) {
    return resolved_type.status();
  }

  // The checker is stored on the filter; the scrubber only receives a non-owning pointer to it.
  response_match_tree_field_checker_ = std::make_unique<FieldChecker>(
      ScrubberContext::kResponseScrubbing, &encoder_callbacks_->streamInfo(), method_name_,
      &filter_config_);

  std::vector<const FieldCheckerInterface*> field_checkers{
      response_match_tree_field_checker_.get()};
  return std::make_unique<ProtoScrubber>(*resolved_type, filter_config_.getTypeFinder(),
                                         std::move(field_checkers),
                                         ScrubberContext::kResponseScrubbing, false);
}

void ProtoApiScrubberFilter::rejectRequest(Envoy::Grpc::Status::GrpcStatus grpc_status,
absl::string_view error_msg,
absl::string_view rc_detail) {
Expand All @@ -231,6 +360,16 @@ void ProtoApiScrubberFilter::rejectRequest(Envoy::Grpc::Status::GrpcStatus grpc_
grpc_status, rc_detail);
}

// Logs the rejection and sends a local reply through the encoder callbacks.
// `grpc_status` is forwarded as the gRPC status of the reply and also determines its HTTP code,
// `error_msg` is used as the reply body, and `rc_detail` as the response-code detail.
void ProtoApiScrubberFilter::rejectResponse(Envoy::Grpc::Status::GrpcStatus grpc_status,
                                            absl::string_view error_msg,
                                            absl::string_view rc_detail) {
  ENVOY_STREAM_LOG(debug, "Rejecting response grpcStatus={}, message={}", *encoder_callbacks_,
                   grpc_status, error_msg);

  // Derive the HTTP status code of the local reply from the gRPC status.
  const auto http_code = static_cast<Envoy::Http::Code>(Utility::grpcToHttpStatus(grpc_status));
  encoder_callbacks_->sendLocalReply(http_code, error_msg, nullptr, grpc_status, rc_detail);
}

} // namespace ProtoApiScrubber
} // namespace HttpFilters
} // namespace Extensions
Expand Down
28 changes: 28 additions & 0 deletions source/extensions/filters/http/proto_api_scrubber/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,36 @@ class ProtoApiScrubberFilter : public Http::PassThroughFilter,

Http::FilterDataStatus decodeData(Buffer::Instance& data, bool end_stream) override;

Http::FilterHeadersStatus encodeHeaders(Envoy::Http::ResponseHeaderMap& headers,
bool end_stream) override;

Http::FilterDataStatus encodeData(Buffer::Instance& data, bool end_stream) override;

private:
// Rejects requests and sends local reply back to the client.
void rejectRequest(Envoy::Grpc::Status::GrpcStatus grpc_status, absl::string_view error_msg,
absl::string_view rc_detail);

// Rejects response and sends local reply back to the client.
void rejectResponse(Envoy::Grpc::Status::GrpcStatus grpc_status, absl::string_view error_msg,
absl::string_view rc_detail);

bool is_valid_grpc_request_ = false;

// Request message converter which converts Envoy Buffer data to StreamMessage (for scrubbing) and
// vice-versa.
GrpcFieldExtraction::MessageConverterPtr request_msg_converter_{nullptr};

// Response message converter which converts Envoy Buffer data to StreamMessage (for scrubbing)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

High-level comment:
Looking at the request- and response-related fields, it seems that the same group of objects is needed for each of them. Using a generic struct with the relevant fields, initialized differently for each path, may reduce the copy-paste between the two.
(Not an action item for this PR, just a high-level observation to take into account.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that makes sense. I'll add that to the backlog for now and pick it up once the currently lined-up items for the filter are complete.

// and vice-versa.
GrpcFieldExtraction::MessageConverterPtr response_msg_converter_{nullptr};

// Creates and returns an instance of `ProtoScrubber` which can be used for request scrubbing.
absl::StatusOr<std::unique_ptr<ProtoScrubber>> createRequestProtoScrubber();

// Creates and returns an instance of `ProtoScrubber` which can be used for response scrubbing.
absl::StatusOr<std::unique_ptr<ProtoScrubber>> createResponseProtoScrubber();

const ProtoApiScrubberFilterConfig& filter_config_;

// Stores the full gRPC method name e.g., `/package.service/method`.
Expand All @@ -69,6 +85,18 @@ class ProtoApiScrubberFilter : public Http::PassThroughFilter,
// once per request, preserving state across multiple data frames (e.g., for
// gRPC streaming or large payloads).
std::unique_ptr<ProtoScrubber> request_scrubber_;

// The field checker which uses match tree configured in the filter config to determine whether a
// field should be preserved or removed from the response protobuf payloads.
// NOTE: This must outlive `response_scrubber_`, which holds a non-owning reference to this
// instance.
std::unique_ptr<FieldCheckerInterface> response_match_tree_field_checker_;

// The scrubber instance for the response path.
// It is lazily initialized in encodeData() to ensure it is instantiated exactly
// once per request, preserving state across multiple data frames (e.g., for
// gRPC streaming or large payloads).
std::unique_ptr<ProtoScrubber> response_scrubber_;
};

class FilterFactory : public Common::FactoryBase<ProtoApiScrubberConfig> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ void ProtoApiScrubberFilterConfig::initializeTypeUtils() {
});
}

absl::StatusOr<const Protobuf::Type*>
ProtoApiScrubberFilterConfig::getRequestType(const std::string& method_name) const {
absl::StatusOr<const MethodDescriptor*>
ProtoApiScrubberFilterConfig::getMethodDescriptor(const std::string& method_name) const {
// Convert the gRPC method name from `/package.service/method` format to `package.service.method`,
// as the method `FindMethodByName` expects the method name to be in the latter format.
std::string dot_separated_method_name =
Expand All @@ -259,12 +259,32 @@ ProtoApiScrubberFilterConfig::getRequestType(const std::string& method_name) con
dot_separated_method_name));
}

std::string request_type_url =
absl::StrCat(Envoy::Grpc::Common::typeUrlPrefix(), "/", method->input_type()->full_name());
return method;
}

// Resolves the request (input) message type of the gRPC method `method_name`, given in the
// `/package.service/method` format.
//
// Returns the error from getMethodDescriptor() if the method lookup fails, an Internal error if
// the type finder cannot resolve the request type URL, and the resolved type otherwise. A
// successful result is therefore never a null pointer.
absl::StatusOr<const Protobuf::Type*>
ProtoApiScrubberFilterConfig::getRequestType(const std::string& method_name) const {
  absl::StatusOr<const MethodDescriptor*> method_or_status = getMethodDescriptor(method_name);
  RETURN_IF_NOT_OK(method_or_status.status());

  const std::string request_type_url =
      absl::StrCat(Envoy::Grpc::Common::typeUrlPrefix(), "/",
                   method_or_status.value()->input_type()->full_name());
  const Protobuf::Type* request_type = (*type_finder_)(request_type_url);
  if (request_type == nullptr) {
    // Surface an unresolved type as an error instead of handing callers an OK status that wraps
    // a null pointer.
    return absl::InternalError(absl::StrCat("Unable to resolve request type `", request_type_url,
                                            "` of method `", method_name, "`."));
  }
  return request_type;
}

// Resolves the response (output) message type of the gRPC method `method_name`, given in the
// `/package.service/method` format.
//
// Returns the error from getMethodDescriptor() if the method lookup fails, an Internal error if
// the type finder cannot resolve the response type URL, and the resolved type otherwise. A
// successful result is therefore never a null pointer.
absl::StatusOr<const Protobuf::Type*>
ProtoApiScrubberFilterConfig::getResponseType(const std::string& method_name) const {
  absl::StatusOr<const MethodDescriptor*> method_or_status = getMethodDescriptor(method_name);
  RETURN_IF_NOT_OK(method_or_status.status());

  const std::string response_type_url =
      absl::StrCat(Envoy::Grpc::Common::typeUrlPrefix(), "/",
                   method_or_status.value()->output_type()->full_name());
  const Protobuf::Type* response_type = (*type_finder_)(response_type_url);
  if (response_type == nullptr) {
    // Surface an unresolved type as an error instead of handing callers an OK status that wraps
    // a null pointer.
    return absl::InternalError(absl::StrCat("Unable to resolve response type `", response_type_url,
                                            "` of method `", method_name, "`."));
  }
  return response_type;
}

REGISTER_FACTORY(RemoveFilterActionFactory,
Matcher::ActionFactory<ProtoApiScrubberRemoveFieldAction>);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using envoy::extensions::filters::http::proto_api_scrubber::v3::RestrictionConfi
using google::grpc::transcoding::TypeHelper;
using Http::HttpMatchingData;
using Protobuf::Map;
using Protobuf::MethodDescriptor;
using xds::type::matcher::v3::HttpAttributesCelMatchInput;
using ProtoApiScrubberRemoveFieldAction =
envoy::extensions::filters::http::proto_api_scrubber::v3::RemoveFieldAction;
Expand Down Expand Up @@ -82,6 +83,9 @@ class ProtoApiScrubberFilterConfig : public Logger::Loggable<Logger::Id::filter>
// Returns the request type of the method.
absl::StatusOr<const Protobuf::Type*> getRequestType(const std::string& method_name) const;

// Returns the response type of the method.
absl::StatusOr<const Protobuf::Type*> getResponseType(const std::string& method_name) const;

FilteringMode filteringMode() const { return filtering_mode_; }

private:
Expand Down Expand Up @@ -124,6 +128,10 @@ class ProtoApiScrubberFilterConfig : public Logger::Loggable<Logger::Id::filter>
const Map<std::string, RestrictionConfig>& restrictions,
Server::Configuration::FactoryContext& context);

// Returns the method descriptor by looking up `method_name` in the `descriptor_pool_`.
// If the method doesn't exist in the `descriptor_pool_`, it returns an absl::InvalidArgument error.
absl::StatusOr<const MethodDescriptor*> getMethodDescriptor(const std::string& method_name) const;

FilteringMode filtering_mode_;

std::unique_ptr<const Envoy::Protobuf::DescriptorPool> descriptor_pool_;
Expand Down
1 change: 1 addition & 0 deletions test/extensions/filters/http/proto_api_scrubber/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ envoy_cc_test(
"//test/proto:apikeys_proto_cc_proto",
"//test/proto:bookstore_proto_cc_proto",
"//test/test_common:environment_lib",
"//test/test_common:logging_lib",
"//test/test_common:utility_lib",
"@com_google_absl//absl/strings",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,43 @@ TEST_F(ProtoApiScrubberFilterConfigTest, GetRequestType) {
}
}

// Verifies getResponseType() for a method present in the descriptor pool and for one that is not.
TEST_F(ProtoApiScrubberFilterConfigTest, GetResponseType) {
  // Build the filter config that both scenarios below query.
  absl::StatusOr<std::shared_ptr<const ProtoApiScrubberFilterConfig>> created_config =
      ProtoApiScrubberFilterConfig::create(proto_config_, factory_context_);
  ASSERT_EQ(created_config.status().code(), absl::StatusCode::kOk);
  filter_config_ = std::move(created_config.value());

  {
    // Scenario: the method is known. Method names use the `/Package.Service/Method` format.
    const std::string known_method = "/apikeys.ApiKeys/CreateApiKey";

    absl::StatusOr<const Protobuf::Type*> response_type =
        filter_config_->getResponseType(known_method);

    ASSERT_EQ(response_type.status().code(), absl::StatusCode::kOk);
    ASSERT_NE(response_type.value(), nullptr);

    // The resolved type must be the method's response (output) type.
    EXPECT_EQ(response_type.value()->name(), "apikeys.ApiKey");
  }

  {
    // Scenario: the method is missing from the descriptor, so the lookup must fail.
    const std::string unknown_method = "/apikeys.ApiKeys/NonExistentMethod";

    absl::StatusOr<const Protobuf::Type*> response_type =
        filter_config_->getResponseType(unknown_method);

    EXPECT_EQ(response_type.status().code(), absl::StatusCode::kInvalidArgument);
    EXPECT_THAT(
        response_type.status().message(),
        testing::HasSubstr(
            "Unable to find method `apikeys.ApiKeys.NonExistentMethod` in the descriptor pool"));
  }
}

TEST_F(ProtoApiScrubberFilterConfigTest, GetTypeFinder) {
absl::StatusOr<std::shared_ptr<const ProtoApiScrubberFilterConfig>> config_or_status =
ProtoApiScrubberFilterConfig::create(proto_config_, factory_context_);
Expand Down
Loading