Skip to content

Commit

Permalink
Merge pull request redpanda-data#23469 from rockwotj/proto-parse
Browse files Browse the repository at this point in the history
  • Loading branch information
rockwotj authored Sep 25, 2024
2 parents cb2f56e + bc4463a commit 96a365f
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 38 deletions.
2 changes: 2 additions & 0 deletions src/v/serde/protobuf/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ redpanda_cc_library(
include_prefix = "serde/protobuf",
visibility = ["//visibility:public"],
deps = [
"//src/v/bytes:hash",
"//src/v/bytes:iobuf",
"//src/v/bytes:iobuf_parser",
"//src/v/container:chunked_hash_map",
"//src/v/container:fragmented_vector",
"//src/v/utils:vint",
"@protobuf",
"@seastar",
],
)
94 changes: 75 additions & 19 deletions src/v/serde/protobuf/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "bytes/iobuf_parser.h"
#include "utils/vint.h"

#include <seastar/util/variant_utils.hh>

#include <google/protobuf/descriptor.h>

#include <algorithm>
Expand Down Expand Up @@ -298,14 +300,73 @@ class parser {
}
}

template<typename T>
void update_field(int32_t field_number, T&& value) {
void
update_field(int32_t field_number, std::unique_ptr<parsed::message> value) {
auto* field = current_->descriptor->FindFieldByNumber(field_number);
if (field == nullptr) {
throw std::runtime_error(
fmt::format("unknown field number: {}", field_number));
}
if (field->is_map()) {
convert_to_map_and_update(*field, std::move(value));
} else {
update_field(*field, std::forward<T>(value));
update_field(*field, std::move(value));
}
}

/**
 * Convert an already-parsed map entry message into a key/value pair and
 * store it in the map kept under `field` on the current message.
 *
 * The key and value are looked up in `value` by the field numbers of the
 * entry descriptor's map_key() and map_value(). Either one may be missing
 * from the entry, in which case std::monostate is stored in its place.
 *
 * If the key is already present in the map, the new value replaces the old
 * one: protobuf specifies that the last entry seen wins when serialized
 * data contains duplicate map keys.
 *
 * field must be a map field (the caller checks field->is_map()).
 *
 * Throws std::runtime_error if the entry holds a type that cannot be used
 * as a map key or a map value.
 */
void convert_to_map_and_update(
  const pb::FieldDescriptor& field,
  std::unique_ptr<parsed::message> value) {
    const auto* entry_descriptor = field.message_type();

    // Extract the key, leaving it unset if the entry did not carry one.
    const auto* key_descriptor = entry_descriptor->map_key();
    parsed::map::key key = std::monostate{};
    auto key_it = value->fields.find(key_descriptor->number());
    if (key_it != value->fields.end()) {
        key = ss::visit(
          key_it->second,
          [](double k) -> parsed::map::key { return k; },
          [](float k) -> parsed::map::key { return k; },
          [](int32_t k) -> parsed::map::key { return k; },
          [](int64_t k) -> parsed::map::key { return k; },
          [](uint32_t k) -> parsed::map::key { return k; },
          [](uint64_t k) -> parsed::map::key { return k; },
          [](bool k) -> parsed::map::key { return k; },
          // The entry message is discarded afterwards, so steal the buffer.
          [](iobuf& k) -> parsed::map::key { return std::move(k); },
          [](const auto&) -> parsed::map::key {
              throw std::runtime_error(
                "invariant: unable to convert type to map key");
          });
    }

    // Extract the value, leaving it unset if the entry did not carry one.
    const auto* val_descriptor = entry_descriptor->map_value();
    parsed::map::value val = std::monostate{};
    auto val_it = value->fields.find(val_descriptor->number());
    if (val_it != value->fields.end()) {
        val = ss::visit(
          val_it->second,
          [](double v) -> parsed::map::value { return v; },
          [](float v) -> parsed::map::value { return v; },
          [](int32_t v) -> parsed::map::value { return v; },
          [](int64_t v) -> parsed::map::value { return v; },
          [](uint32_t v) -> parsed::map::value { return v; },
          [](uint64_t v) -> parsed::map::value { return v; },
          [](bool v) -> parsed::map::value { return v; },
          [](iobuf& v) -> parsed::map::value { return std::move(v); },
          [](std::unique_ptr<parsed::message>& v) -> parsed::map::value {
              return std::move(v);
          },
          [](const auto&) -> parsed::map::value {
              throw std::runtime_error(
                "invariant: unable to convert type to map value");
          });
    }

    auto map_it = current_->message->fields.find(field.number());
    if (map_it == current_->message->fields.end()) {
        // First entry seen for this field: create the map.
        parsed::map map;
        map.entries.emplace(std::move(key), std::move(val));
        current_->message->fields.emplace(field.number(), std::move(map));
        return;
    }
    // Overwrite on duplicate keys so that the last entry on the wire wins;
    // a plain emplace would silently keep the first value instead.
    auto& map = std::get<parsed::map>(map_it->second);
    map.entries.insert_or_assign(std::move(key), std::move(val));
}

Expand Down Expand Up @@ -337,18 +398,16 @@ class parser {
}

/**
* Update a repeated field. If T is itself `repeated` then concat the two
* sequences, otherwise append the last value.
* Update a repeated field. If T is itself `repeated` then concat the
* two sequences, otherwise append the last value.
*
* field_number must be of a repeated type.
*/
template<typename T>
void append_repeated_field(int32_t field_number, T&& value) {
auto it = current_->message->fields.find(field_number);
if (it == current_->message->fields.end()) {
if constexpr (
std::is_same_v<T, parsed::repeated>
|| std::is_same_v<T, parsed::map>) {
if constexpr (std::is_same_v<T, parsed::repeated>) {
current_->message->fields.emplace(
field_number, std::forward<T>(value));
} else {
Expand All @@ -358,11 +417,6 @@ class parser {
current_->message->fields.emplace(
field_number, std::move(repeated));
}
} else if constexpr (std::is_same_v<T, parsed::map>) {
auto& map = std::get<parsed::map>(it->second);
for (auto& [key, val] : value) {
map.entries.insert_or_assign(std::move(key), std::move(val));
}
} else {
auto& repeated = std::get<parsed::repeated>(it->second);
if constexpr (std::is_same_v<T, parsed::repeated>) {
Expand All @@ -385,8 +439,8 @@ class parser {

void stage_message(
int32_t field_number, iobuf iobuf, const pb::Descriptor& descriptor) {
// Matches the golang max message depth (the highest of all runtimes as
// far as I understand it).
// Matches the golang max message depth (the highest of all runtimes
// as far as I understand it).
constexpr size_t max_nested_message_depth = 10000;
if (state_.size() >= max_nested_message_depth) {
throw std::runtime_error("max nested message depth reached");
Expand Down Expand Up @@ -422,7 +476,8 @@ class parser {
chunked_vector<std::invoke_result_t<decltype(reader)>> vec;
if (amount > current_->parser.bytes_left()) {
throw std::runtime_error(fmt::format(
"invalid packed field field: (bytes_needed={}, bytes_left={})",
"invalid packed field field: (bytes_needed={}, "
"bytes_left={})",
amount,
current_->parser.bytes_left()));
}
Expand Down Expand Up @@ -545,10 +600,11 @@ class parser {
}

// The state of the current message being parsed. As we encounter new
// messages we will push an entry onto state_ so that we are not stack bound
// with respect to protobuf nested message depth.
// messages we will push an entry onto state_ so that we are not stack
// bound with respect to protobuf nested message depth.
struct state {
// The field_number of the proto to push back onto its parent message.
// The field_number of the proto to push back onto its parent
// message.
int32_t field_number;
// The parser for this current protobuf only
iobuf_parser parser;
Expand Down
16 changes: 14 additions & 2 deletions src/v/serde/protobuf/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
* by the Apache License, Version 2.0
*/

#pragma once

#include "bytes/hash.h"
#include "bytes/iobuf.h"
#include "container/chunked_hash_map.h"
#include "container/fragmented_vector.h"
Expand Down Expand Up @@ -46,10 +49,19 @@ struct repeated {
* A dynamic representation of a proto3 map.
*/
struct map {
using key = std::
variant<double, float, int32_t, int64_t, uint32_t, uint64_t, bool, iobuf>;
using key = std::variant<
std::monostate, // Can be unset
double,
float,
int32_t,
int64_t,
uint32_t,
uint64_t,
bool,
iobuf>;

using value = std::variant<
std::monostate, // Can be unset
double,
float,
int32_t, // Can be an enum value
Expand Down
114 changes: 97 additions & 17 deletions src/v/serde/protobuf/tests/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include "bytes/iobuf.h"
#include "bytes/iobuf_parser.h"
#include "gtest/gtest.h"
#include "random/generators.h"
#include "serde/protobuf/parser.h"
// TODO: Fix bazelbuild/bazel#4446
#include "src/v/serde/protobuf/tests/test_messages_edition2023.pb.h"
Expand All @@ -33,6 +32,7 @@
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <variant>

namespace serde::pb {
namespace {
Expand Down Expand Up @@ -101,6 +101,15 @@ class ProtobufParserFixture : public testing::Test {
return output;
}

std::unique_ptr<parsed::message> parse_raw(std::string_view txtpb) {
auto expected = construct_message(txtpb);
iobuf out;
auto binpb = expected->SerializeAsString();
out.append(binpb.data(), binpb.size());
return ::serde::pb::parse(std::move(out), *expected->GetDescriptor())
.get();
}

testing::AssertionResult parse(std::string_view txtpb) {
auto expected = construct_message(txtpb);
iobuf out;
Expand Down Expand Up @@ -232,8 +241,10 @@ class ProtobufParserFixture : public testing::Test {
reflect->SetString(output, field, std::move(str));
},
[=, this](const parsed::map& v) {
auto pbmap = convert_to_protobuf(v, field->message_type());
reflect->SetAllocatedMessage(output, pbmap.release(), field);
auto entries = convert_to_protobuf(v, field->message_type());
for (auto& entry : entries) {
reflect->AddAllocatedMessage(output, field, entry.release());
}
});
}

Expand Down Expand Up @@ -301,34 +312,58 @@ class ProtobufParserFixture : public testing::Test {
});
}

std::unique_ptr<pb::Message>
std::vector<std::unique_ptr<pb::Message>>
convert_to_protobuf(const parsed::map& map, const pb::Descriptor* desc) {
auto output = std::unique_ptr<pb::Message>(
factory_.GetPrototype(desc)->New());
std::vector<std::unique_ptr<pb::Message>> entries;
for (const auto& [k, v] : map.entries) {
auto k_field = ss::visit(
auto output = std::unique_ptr<pb::Message>(
factory_.GetPrototype(desc)->New());
auto key_field = desc->map_key();
ss::visit(
k,
[](const iobuf& v) { return parsed::message::field(v.copy()); },
[](const auto& v) { return parsed::message::field(v); });
convert_to_protobuf(1, k_field, output.get());
[](const std::monostate&) {},
[&, this](const iobuf& v) {
convert_to_protobuf(
key_field->number(),
parsed::message::field(v.copy()),
output.get());
},
[&, this](const auto& v) {
convert_to_protobuf(
key_field->number(),
parsed::message::field(v),
output.get());
});
auto value_field = desc->map_value();
ss::visit(
v,
[](const std::monostate&) {},
[&, this](const iobuf& v) {
convert_to_protobuf(
2, parsed::message::field(v.copy()), output.get());
value_field->number(),
parsed::message::field(v.copy()),
output.get());
},
[&](const std::unique_ptr<parsed::message>& v) {
auto field = desc->FindFieldByNumber(2);
auto msg = convert_to_protobuf(*v, field->message_type());
[&, this](const std::unique_ptr<parsed::message>& v) {
if (value_field->message_type() == nullptr) {
throw std::runtime_error(fmt::format(
"expected message type got: {}",
value_field->DebugString()));
}
auto msg = convert_to_protobuf(
*v, value_field->message_type());
output->GetReflection()->SetAllocatedMessage(
output.get(), msg.release(), field);
output.get(), msg.release(), value_field);
},
[&, this](const auto& v) {
convert_to_protobuf(
2, parsed::message::field(v), output.get());
value_field->number(),
parsed::message::field(v),
output.get());
});
entries.push_back(std::move(output));
}
return output;
return entries;
}

pb::DynamicMessageFactory factory_;
Expand Down Expand Up @@ -490,6 +525,51 @@ optional_sfixed32: -1
)"));
}

// Checks that entries of a map field are collected into a single
// parsed::map, including entries that are missing their key or their value.
TEST_F(ProtobufParserFixture, Map) {
// Four `meta` entries: two complete pairs, one without a key and one
// without a value.
auto msg = parse_raw(R"(
# proto-file: three.proto
# proto-message: Map
meta {
key: "foo"
value: "bar"
}
meta {
key: "baz"
value: "qux"
}
meta {
value: "nokey"
}
meta {
key: "novalue"
}
)");
// Every entry belongs to field number 1, so exactly one field is expected.
ASSERT_EQ(msg->fields.size(), 1);
// Asserted before operator[] is used below, so a missing field fails here.
ASSERT_TRUE(msg->fields.contains(1));
const auto& field = msg->fields[1];
ASSERT_TRUE(std::holds_alternative<parsed::map>(field));
const auto& map = std::get<parsed::map>(field);
// The entries with a missing key or value must be kept as entries too.
EXPECT_EQ(map.entries.size(), 4);
// Same kind of input against the Entries message, whose values are int32.
// NOTE(review): parse() is defined outside this view; presumably it parses
// the serialized message and compares the result with the expected
// message - confirm.
EXPECT_TRUE(parse(R"(
# proto-file: three.proto
# proto-message: Entries
entry {
key: "foo"
value: 42
}
entry {
key: "baz"
value: 53
}
entry {
key: "novalue"
}
entry {
value: 99
}
)"));
}

TEST_F(ProtobufParserFixture, RandomData) {
protobuf_test_messages::editions::TestAllTypesEdition2023 kitchen_sink;
constexpr size_t size = 512;
Expand Down
8 changes: 8 additions & 0 deletions src/v/serde/protobuf/tests/three.proto
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,11 @@ message Version4 {
bool data = 3;
}
}

// Test message with a single string-to-string map field.
message Map {
map<string, string> meta = 1;
}

// Test message with a single map field from string keys to int32 values.
message Entries {
map<string, int32> entry = 1;
}

0 comments on commit 96a365f

Please sign in to comment.