diff --git a/src/v/serde/protobuf/BUILD b/src/v/serde/protobuf/BUILD index 6925de78313e..603714621cbd 100644 --- a/src/v/serde/protobuf/BUILD +++ b/src/v/serde/protobuf/BUILD @@ -11,11 +11,13 @@ redpanda_cc_library( include_prefix = "serde/protobuf", visibility = ["//visibility:public"], deps = [ + "//src/v/bytes:hash", "//src/v/bytes:iobuf", "//src/v/bytes:iobuf_parser", "//src/v/container:chunked_hash_map", "//src/v/container:fragmented_vector", "//src/v/utils:vint", "@protobuf", + "@seastar", ], ) diff --git a/src/v/serde/protobuf/parser.cc b/src/v/serde/protobuf/parser.cc index 8dfdf1235d37..90fa99ae93bd 100644 --- a/src/v/serde/protobuf/parser.cc +++ b/src/v/serde/protobuf/parser.cc @@ -14,6 +14,8 @@ #include "bytes/iobuf_parser.h" #include "utils/vint.h" +#include + #include #include @@ -298,14 +300,86 @@ class parser { } } - template - void update_field(int32_t field_number, T&& value) { + /** + * Store a parsed nested message under field_number. Map fields are + * folded into the current message's parsed::map; anything else is + * forwarded to the descriptor-based update_field overload. + * + * Throws std::runtime_error if the descriptor has no such field. + */ + void + update_field(int32_t field_number, std::unique_ptr value) { auto* field = current_->descriptor->FindFieldByNumber(field_number); if (field == nullptr) { throw std::runtime_error( fmt::format("unknown field number: {}", field_number)); + } + if (field->is_map()) { + convert_to_map_and_update(*field, std::move(value)); } else { - update_field(*field, std::forward(value)); + update_field(*field, std::move(value)); + } + } + + /** + * Fold one decoded map entry into the parsed::map stored at + * field.number(), creating the map on first use. A key or value that + * is absent from the entry is recorded as std::monostate. + */ + void convert_to_map_and_update( + const pb::FieldDescriptor& field, + std::unique_ptr value) { + const auto* key_descriptor = field.message_type()->map_key(); + parsed::map::key key = std::monostate{}; + auto it = value->fields.find(key_descriptor->number()); + if (it != value->fields.end()) { + key = ss::visit( + it->second, + [](double val) -> parsed::map::key { return val; }, + [](float val) -> parsed::map::key { return val; }, + [](int32_t val) -> parsed::map::key { return val; }, + [](int64_t val) -> parsed::map::key { return val; }, + [](uint32_t val) -> parsed::map::key { return val; }, + [](uint64_t val) -> parsed::map::key { return val; 
}, + [](bool val) -> parsed::map::key { return val; }, + [](iobuf& val) -> parsed::map::key { return std::move(val); }, + [](const auto&) -> parsed::map::key { + throw std::runtime_error( + "invariant: unable to convert type to map key"); + }); + } + const auto* val_descriptor = field.message_type()->map_value(); + parsed::map::value val = std::monostate{}; + it = value->fields.find(val_descriptor->number()); + if (it != value->fields.end()) { + val = ss::visit( + it->second, + [](double val) -> parsed::map::value { return val; }, + [](float val) -> parsed::map::value { return val; }, + [](int32_t val) -> parsed::map::value { return val; }, + [](int64_t val) -> parsed::map::value { return val; }, + [](uint32_t val) -> parsed::map::value { return val; }, + [](uint64_t val) -> parsed::map::value { return val; }, + [](bool val) -> parsed::map::value { return val; }, + [](iobuf& val) -> parsed::map::value { return std::move(val); }, + [](std::unique_ptr& val) -> parsed::map::value { + return std::move(val); + }, + [](const auto&) -> parsed::map::value { + throw std::runtime_error( + "invariant: unable to convert type to map value"); + }); + } + + it = current_->message->fields.find(field.number()); + if (it == current_->message->fields.end()) { + parsed::map map; + map.entries.emplace(std::move(key), std::move(val)); + current_->message->fields.emplace(field.number(), std::move(map)); + } else { + auto& map = std::get(it->second); + // Duplicate keys: the last entry on the wire wins. + map.entries.insert_or_assign(std::move(key), std::move(val)); } } @@ -337,8 +411,8 @@ class parser { } /** - * Update a repeated field. If T is itself `repeated` then concat the two - * sequences, otherwise append the last value. + * Update a repeated field. If T is itself `repeated` then concat the + * two sequences, otherwise append the last value. * * field_number must be of a repeated type. 
*/ @@ -346,9 +420,7 @@ class parser { void append_repeated_field(int32_t field_number, T&& value) { auto it = current_->message->fields.find(field_number); if (it == current_->message->fields.end()) { - if constexpr ( - std::is_same_v - || std::is_same_v) { + if constexpr (std::is_same_v) { current_->message->fields.emplace( field_number, std::forward(value)); } else { @@ -358,11 +430,6 @@ class parser { current_->message->fields.emplace( field_number, std::move(repeated)); } - } else if constexpr (std::is_same_v) { - auto& map = std::get(it->second); - for (auto& [key, val] : value) { - map.entries.insert_or_assign(std::move(key), std::move(val)); - } } else { auto& repeated = std::get(it->second); if constexpr (std::is_same_v) { @@ -385,8 +452,8 @@ class parser { void stage_message( int32_t field_number, iobuf iobuf, const pb::Descriptor& descriptor) { - // Matches the golang max message depth (the highest of all runtimes as - // far as I understand it). + // Matches the golang max message depth (the highest of all runtimes + // as far as I understand it). constexpr size_t max_nested_message_depth = 10000; if (state_.size() >= max_nested_message_depth) { throw std::runtime_error("max nested message depth reached"); @@ -422,7 +489,8 @@ class parser { chunked_vector> vec; if (amount > current_->parser.bytes_left()) { throw std::runtime_error(fmt::format( - "invalid packed field field: (bytes_needed={}, bytes_left={})", + "invalid packed field field: (bytes_needed={}, " + "bytes_left={})", amount, current_->parser.bytes_left())); } @@ -545,10 +613,11 @@ class parser { } // The state of the current message being parsed. As we encounter new - // messages we will push an entry onto state_ so that we are not stack bound - // with respect to protobuf nested message depth. + // messages we will push an entry onto state_ so that we are not stack + // bound with respect to protobuf nested message depth. 
struct state { - // The field_number of the proto to push back onto it's parent message. + // The field_number of the proto to push back onto its parent + // message. int32_t field_number; // The parser for this current protobuf only iobuf_parser parser; diff --git a/src/v/serde/protobuf/parser.h b/src/v/serde/protobuf/parser.h index d3550d242181..036519437dd7 100644 --- a/src/v/serde/protobuf/parser.h +++ b/src/v/serde/protobuf/parser.h @@ -9,6 +9,9 @@ * by the Apache License, Version 2.0 */ +#pragma once + +#include "bytes/hash.h" #include "bytes/iobuf.h" #include "container/chunked_hash_map.h" #include "container/fragmented_vector.h" @@ -46,10 +49,19 @@ struct repeated { * A dynamic representation of a proto3 map. */ struct map { - using key = std:: - variant; + using key = std::variant< + std::monostate, // Can be unset + double, + float, + int32_t, + int64_t, + uint32_t, + uint64_t, + bool, + iobuf>; using value = std::variant< + std::monostate, // Can be unset double, float, int32_t, // Can be an enum value diff --git a/src/v/serde/protobuf/tests/parser_test.cc b/src/v/serde/protobuf/tests/parser_test.cc index af9333cdc867..5a0acb07001e 100644 --- a/src/v/serde/protobuf/tests/parser_test.cc +++ b/src/v/serde/protobuf/tests/parser_test.cc @@ -13,7 +13,6 @@ #include "bytes/iobuf.h" #include "bytes/iobuf_parser.h" #include "gtest/gtest.h" -#include "random/generators.h" #include "serde/protobuf/parser.h" // TODO: Fix bazelbuild/bazel#4446 #include "src/v/serde/protobuf/tests/test_messages_edition2023.pb.h" @@ -33,6 +32,7 @@ #include #include #include +#include namespace serde::pb { namespace { @@ -101,6 +101,15 @@ class ProtobufParserFixture : public testing::Test { return output; } + /* Builds the message described by txtpb, serializes it, and returns the tree that serde::pb::parse yields for those bytes. */ std::unique_ptr parse_raw(std::string_view txtpb) { + auto expected = construct_message(txtpb); + iobuf out; + auto binpb = expected->SerializeAsString(); + out.append(binpb.data(), binpb.size()); + return ::serde::pb::parse(std::move(out), *expected->GetDescriptor()) + .get(); + 
} + testing::AssertionResult parse(std::string_view txtpb) { auto expected = construct_message(txtpb); iobuf out; @@ -232,8 +241,10 @@ class ProtobufParserFixture : public testing::Test { reflect->SetString(output, field, std::move(str)); }, [=, this](const parsed::map& v) { - auto pbmap = convert_to_protobuf(v, field->message_type()); - reflect->SetAllocatedMessage(output, pbmap.release(), field); + auto entries = convert_to_protobuf(v, field->message_type()); + for (auto& entry : entries) { /* append every entry message to the map field */ + reflect->AddAllocatedMessage(output, field, entry.release()); + } }); } @@ -301,34 +312,58 @@ class ProtobufParserFixture : public testing::Test { }); } - std::unique_ptr + /* Builds one entry message per parsed map entry; a std::monostate key or value is left unset on that entry. */ std::vector> convert_to_protobuf(const parsed::map& map, const pb::Descriptor* desc) { - auto output = std::unique_ptr( - factory_.GetPrototype(desc)->New()); + std::vector> entries; for (const auto& [k, v] : map.entries) { - auto k_field = ss::visit( + auto output = std::unique_ptr( + factory_.GetPrototype(desc)->New()); + auto key_field = desc->map_key(); + ss::visit( k, - [](const iobuf& v) { return parsed::message::field(v.copy()); }, - [](const auto& v) { return parsed::message::field(v); }); - convert_to_protobuf(1, k_field, output.get()); + [](const std::monostate&) {}, + [&, this](const iobuf& v) { + convert_to_protobuf( + key_field->number(), + parsed::message::field(v.copy()), + output.get()); + }, + [&, this](const auto& v) { + convert_to_protobuf( + key_field->number(), + parsed::message::field(v), + output.get()); + }); + auto value_field = desc->map_value(); ss::visit( v, + [](const std::monostate&) {}, [&, this](const iobuf& v) { convert_to_protobuf( - 2, parsed::message::field(v.copy()), output.get()); + value_field->number(), + parsed::message::field(v.copy()), + output.get()); }, - [&](const std::unique_ptr& v) { - auto field = desc->FindFieldByNumber(2); - auto msg = convert_to_protobuf(*v, field->message_type()); + [&, this](const std::unique_ptr& v) { + if (value_field->message_type() 
== nullptr) { + throw std::runtime_error(fmt::format( + "expected message type got: {}", + value_field->DebugString())); + } + auto msg = convert_to_protobuf( + *v, value_field->message_type()); output->GetReflection()->SetAllocatedMessage( - output.get(), msg.release(), field); + output.get(), msg.release(), value_field); }, [&, this](const auto& v) { convert_to_protobuf( - 2, parsed::message::field(v), output.get()); + value_field->number(), + parsed::message::field(v), + output.get()); }); + entries.push_back(std::move(output)); } - return output; + return entries; } pb::DynamicMessageFactory factory_; @@ -490,6 +525,51 @@ optional_sfixed32: -1 )")); } +TEST_F(ProtobufParserFixture, Map) { + auto msg = parse_raw(R"( +# proto-file: three.proto +# proto-message: Map +meta { + key: "foo" + value: "bar" +} +meta { + key: "baz" + value: "qux" +} +meta { + value: "nokey" +} +meta { + key: "novalue" +} +)"); + ASSERT_EQ(msg->fields.size(), 1); + ASSERT_TRUE(msg->fields.contains(1)); + const auto& field = msg->fields[1]; + ASSERT_TRUE(std::holds_alternative(field)); + const auto& map = std::get(field); + EXPECT_EQ(map.entries.size(), 4); + EXPECT_TRUE(parse(R"( +# proto-file: three.proto +# proto-message: Entries +entry { + key: "foo" + value: 42 +} +entry { + key: "baz" + value: 53 +} +entry { + key: "novalue" +} +entry { + value: 99 +} +)")); +} + TEST_F(ProtobufParserFixture, RandomData) { protobuf_test_messages::editions::TestAllTypesEdition2023 kitchen_sink; constexpr size_t size = 512; diff --git a/src/v/serde/protobuf/tests/three.proto b/src/v/serde/protobuf/tests/three.proto index 3e55d67b1d4f..1c3c141e0221 100644 --- a/src/v/serde/protobuf/tests/three.proto +++ b/src/v/serde/protobuf/tests/three.proto @@ -76,3 +76,11 @@ message Version4 { bool data = 3; } } + +message Map { + map meta = 1; +} + +message Entries { + map entry = 1; +}