Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support cleaning up spare examples correctly in read_span_flatbuffer() #4684

Merged
merged 12 commits into from
Feb 15, 2024
Merged
3 changes: 3 additions & 0 deletions vowpalwabbit/core/include/vw/core/error_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ ERROR_CODE_DEFINITION(
ERROR_CODE_DEFINITION(
13, fb_parser_size_mismatch_ft_names_ft_values, "Size of feature names and feature values do not match. ")
ERROR_CODE_DEFINITION(14, unknown_label_type, "Label type in Flatbuffer not understood. ")
ERROR_CODE_DEFINITION(15, fb_parser_span_misaligned, "Input Flatbuffer span is not aligned to an 8-byte boundary. ")
ERROR_CODE_DEFINITION(16, fb_parser_span_length_mismatch, "Input Flatbuffer span does not match flatbuffer size prefix. ")


// TODO: This is temporary until we switch to the new error handling mechanism.
ERROR_CODE_DEFINITION(10000, vw_exception, "vw_exception: ")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#pragma once

#include "vw/core/api_status.h"
#include "vw/core/example.h"
#include "vw/core/multi_ex.h"
#include "vw/core/shared_data.h"
Expand All @@ -14,15 +13,22 @@
namespace VW
{

class api_status;
namespace experimental
{
class api_status;
}

using example_sink_f = std::function<void(VW::multi_ex&& spare_examples)>;

namespace parsers
{
namespace flatbuffer
{
int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples);
bool read_span_flatbuffer(
VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples);


int read_span_flatbuffer(
VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples, example_sink_f example_sink = nullptr, VW::experimental::api_status* status = nullptr);
lokitoth marked this conversation as resolved.
Show resolved Hide resolved

class parser
{
Expand Down
48 changes: 32 additions & 16 deletions vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

#include "vw/fb_parser/parse_example_flatbuffer.h"

#include "vw/core/api_status.h"
#include "vw/core/action_score.h"
#include "vw/core/best_constant.h"
#include "vw/core/cb.h"
#include "vw/core/constant.h"
#include "vw/core/error_constants.h"
#include "vw/core/global_data.h"
#include "vw/core/parser.h"
#include "vw/core/scope_exit.h"
#include "vw/core/vw.h"

#include <cfloat>
Expand Down Expand Up @@ -43,9 +45,12 @@ int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl
return static_cast<int>(status.get_error_code() == VW::experimental::error_code::success);
}

bool read_span_flatbuffer(
VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples)
int read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory,
VW::multi_ex& examples, example_sink_f example_sink, VW::experimental::api_status* status)
{
int a = 0;
lokitoth marked this conversation as resolved.
Show resolved Hide resolved
a++;

// we expect context to contain a size_prefixed flatbuffer (technically a binary string)
// which means:
//
Expand All @@ -59,16 +64,15 @@ bool read_span_flatbuffer(
// thus context.size() = sizeof(length) + length
io_buf unused;

// TODO: How do we report errors out of here? (This is a general API problem with the parsers)
size_t address = reinterpret_cast<size_t>(span);
if (address % 8 != 0)
{
std::stringstream sstream;
sstream << "fb_parser error: flatbuffer data not aligned to 8 bytes" << std::endl;
sstream << " span => @" << std::hex << address << std::dec << " % " << 8 << " = " << address % 8
<< " (vs desired = " << 0 << ")";
THROW(sstream.str());
return false;

RETURN_ERROR_LS(status, fb_parser_span_misaligned) << sstream.str();
lokitoth marked this conversation as resolved.
Show resolved Hide resolved
}

flatbuffers::uoffset_t flatbuffer_object_size =
Expand All @@ -79,17 +83,32 @@ bool read_span_flatbuffer(
sstream << "fb_parser error: flatbuffer size prefix does not match actual size" << std::endl;
sstream << " span => @" << std::hex << address << std::dec << " size_prefix = " << flatbuffer_object_size
<< " length = " << length;
THROW(sstream.str());
return false;

RETURN_ERROR_LS(status, fb_parser_span_length_mismatch) << sstream.str();
}

VW::multi_ex temp_ex;
temp_ex.push_back(&example_factory());
auto scope_guard = VW::scope_exit([&temp_ex, &all, &example_sink]()
{
if (example_sink == nullptr) { VW::finish_example(*all, temp_ex); }
else { example_sink(std::move(temp_ex)); }
lokitoth marked this conversation as resolved.
Show resolved Hide resolved
});

// There is a bit of unhappiness with the interface of the read_XYZ_<format>() functions, because they often
// expect the input multi_ex to have a single "empty" example there. This contributes, in part, to the large
// proliferation of entry points into the JSON parser(s). We want to avoid exposing that insofar as possible,
// so we will check whether we already received a perfectly good example and use that, or create a new one if
// needed.
if (examples.size() > 0)
lokitoth marked this conversation as resolved.
Show resolved Hide resolved
lokitoth marked this conversation as resolved.
Show resolved Hide resolved
{
temp_ex.push_back(examples[0]);
examples.pop_back();
}
else { temp_ex.push_back(&example_factory()); }
rajan-chari marked this conversation as resolved.
Show resolved Hide resolved

bool has_more = true;
VW::experimental::api_status status;
do {
switch (all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, &status))
switch (int result = all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, status))
{
case VW::experimental::error_code::success:
has_more = true;
Expand All @@ -98,10 +117,7 @@ bool read_span_flatbuffer(
has_more = false;
break;
default:
std::stringstream sstream;
sstream << "Error parsing examples: " << std::endl;
THROW(sstream.str());
return false;
RETURN_IF_FAIL(result);
}

has_more &= !temp_ex[0]->is_newline;
lokitoth marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -113,8 +129,7 @@ bool read_span_flatbuffer(
}
} while (has_more);

VW::finish_example(*all, temp_ex);
return true;
return VW::experimental::error_code::success;
}

const VW::parsers::flatbuffer::ExampleRoot* parser::data() { return _data; }
Expand Down Expand Up @@ -541,6 +556,7 @@ int parser::parse_flat_label(
break;
}
case Label_NONE:
case Label_no_label:
break;
default:
if (_active_collection && _active_multi_ex)
Expand Down
1 change: 1 addition & 0 deletions vowpalwabbit/fb_parser/src/parse_label.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#include "vw/core/api_status.h"
rajan-chari marked this conversation as resolved.
Show resolved Hide resolved
#include "vw/core/action_score.h"
#include "vw/core/best_constant.h"
#include "vw/core/cb.h"
Expand Down
86 changes: 86 additions & 0 deletions vowpalwabbit/fb_parser/tests/example_data_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,25 @@
#pragma once

#include "flatbuffers/flatbuffers.h"
#include "vw/fb_parser/generated/example_generated.h"

#include "prototype_example.h"
#include "prototype_example_root.h"
#include "prototype_label.h"
#include "prototype_namespace.h"
#include "vw/common/hash.h"
#include "vw/common/random.h"
#include "vw/common/future_compat.h"

#include "vw/core/error_constants.h"

#include <vector>

USE_PROTOTYPE_MNEMONICS_EX

using namespace flatbuffers;
namespace fb = VW::parsers::flatbuffer;

namespace vwtest
{

Expand All @@ -40,8 +48,86 @@ class example_data_generator
prototype_example_collection_t create_simple_log(
uint8_t num_examples, uint8_t numeric_features, uint8_t string_features);

public:
/// Bit flags selecting which defects create_bad_namespace() injects into the
/// namespace it serializes. Each non-zero value occupies its own bit, and
/// create_bad_namespace() tests them individually with `&`, so flags may be
/// OR-ed together.
enum NamespaceErrors
{
  /// No defect requested: create_bad_namespace() emits the namespace unchanged.
  BAD_NAMESPACE_NO_ERROR = 0,
  /// Leave out the namespace name and its full hash.
  BAD_NAMESPACE_NAME_HASH_MISSING = 1 << 0,
  /// Leave out the feature-values vector.
  BAD_NAMESPACE_FEATURE_VALUES_MISSING = 1 << 1,
  /// Emit fewer feature hashes than features (only the hash at index 0 is kept).
  BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH = 1 << 2,
  /// Emit fewer feature names than features (only the name at index 1 is kept).
  BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH = 1 << 3,
  /// Leave out both the feature-hashes and the feature-names vectors.
  BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING = 1 << 4,
};

template <NamespaceErrors errors = NamespaceErrors::BAD_NAMESPACE_NO_ERROR>
Offset<fb::Namespace> create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w);

private:
VW::rand_state rng;
};

/// Serializes a namespace named "BadNamespace" into `builder`, deliberately
/// omitting or truncating fields according to the `errors` bit flags, so that
/// tests can feed malformed namespaces to the flatbuffer parser.
///
/// @tparam errors  Bitwise combination of NamespaceErrors flags. With
///                 BAD_NAMESPACE_NO_ERROR the namespace is serialized through
///                 its regular create_flatbuffer() path.
/// @param builder  FlatBufferBuilder that receives the serialized namespace.
/// @param w        Workspace passed to VW::hash_space() / create_flatbuffer().
/// @return Offset of the finished fb::Namespace table inside `builder`.
///
/// NOTE(review): the "mismatch" flags keep only the entry at index 0 (hashes)
/// or index 1 (names); this presumes create_namespace("BadNamespace", 1, 1)
/// yields at least two features -- confirm against create_namespace().
template <example_data_generator::NamespaceErrors errors>
Offset<fb::Namespace> example_data_generator::create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w)
{
  prototype_namespace_t ns = create_namespace("BadNamespace", 1, 1);
  if VW_STD17_CONSTEXPR (errors == NamespaceErrors::BAD_NAMESPACE_NO_ERROR) return ns.create_flatbuffer(builder, w);

  constexpr bool include_ns_name_hash = !(errors & NamespaceErrors::BAD_NAMESPACE_NAME_HASH_MISSING);
  constexpr bool include_feature_values = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_MISSING);

  constexpr bool include_feature_hashes = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING);
  constexpr bool skip_a_feature_hash = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH);
  static_assert(!skip_a_feature_hash || include_feature_hashes, "Cannot skip a feature hash if they are not included");

  constexpr bool include_feature_names = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING);
  constexpr bool skip_a_feature_name = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH);
  static_assert(!skip_a_feature_name || include_feature_names, "Cannot skip a feature name if they are not included");

  std::vector<Offset<String>> feature_names;
  std::vector<float> feature_values;
  std::vector<uint64_t> feature_hashes;

  for (size_t i = 0; i < ns.features.size(); i++)
  {
    const auto& f = ns.features[i];

    // The loop index is a runtime value, so it must not appear inside the
    // (possibly `if constexpr`) compile-time condition: only the template-derived
    // flag goes there, and the index test is evaluated in an ordinary `if`.
    if VW_STD17_CONSTEXPR (include_feature_names)
    {
      const bool keep_this_name = !skip_a_feature_name || i == 1;
      if (keep_this_name) { feature_names.push_back(builder.CreateString(f.name)); }
    }

    if VW_STD17_CONSTEXPR (include_feature_values) feature_values.push_back(f.value);

    if VW_STD17_CONSTEXPR (include_feature_hashes)
    {
      const bool keep_this_hash = !skip_a_feature_hash || i == 0;
      if (keep_this_hash) { feature_hashes.push_back(f.hash); }
    }
  }

  Offset<String> name_offset = Offset<String>();
  if VW_STD17_CONSTEXPR (include_ns_name_hash) { name_offset = builder.CreateString(ns.name); }

  // This function attempts to, insofar as possible, generate a layout that looks like it could have
  // been created using the normal serialization code: In this case, that means that the strings for
  // the feature names are serialized into the builder before a call to CreateNamespaceDirect is made,
  // which is where the feature_names vector is allocated.
  Offset<Vector<Offset<String>>> feature_names_offset =
      include_feature_names ? builder.CreateVector(feature_names) : Offset<Vector<Offset<String>>>();
  Offset<Vector<float>> feature_values_offset =
      include_feature_values ? builder.CreateVector(feature_values) : Offset<Vector<float>>();
  Offset<Vector<uint64_t>> feature_hashes_offset =
      include_feature_hashes ? builder.CreateVector(feature_hashes) : Offset<Vector<uint64_t>>();

  // All child objects (strings, vectors) are created above, before the table builder starts.
  fb::NamespaceBuilder ns_builder(builder);

  if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_full_hash(VW::hash_space(w, ns.name));
  if VW_STD17_CONSTEXPR (include_feature_hashes) ns_builder.add_feature_hashes(feature_hashes_offset);
  if VW_STD17_CONSTEXPR (include_feature_values) ns_builder.add_feature_values(feature_values_offset);
  if VW_STD17_CONSTEXPR (include_feature_names) ns_builder.add_feature_names(feature_names_offset);
  if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_name(name_offset);

  ns_builder.add_hash(ns.feature_group);
  return ns_builder.Finish();
}

} // namespace vwtest
1 change: 1 addition & 0 deletions vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "prototype_namespace.h"
#include "vw/common/future_compat.h"
#include "vw/common/string_view.h"
#include "vw/core/api_status.h"
#include "vw/core/constant.h"
#include "vw/core/error_constants.h"
#include "vw/core/example.h"
Expand Down
6 changes: 6 additions & 0 deletions vowpalwabbit/fb_parser/tests/prototype_typemappings.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,24 @@ template <>
// Type mapping for prototype_example_t: exposes the generated flatbuffer table
// type as `type` and the matching fb::ExampleType enumerator as `root_type`.
struct fb_type<prototype_example_t>
{
  static constexpr fb::ExampleType root_type = fb::ExampleType::ExampleType_Example;

  using type = VW::parsers::flatbuffer::Example;
};

// Type mapping for prototype_multiexample_t: exposes the generated flatbuffer
// table type as `type` and the matching fb::ExampleType enumerator as `root_type`.
template <>
struct fb_type<prototype_multiexample_t>
{
  static constexpr fb::ExampleType root_type = fb::ExampleType::ExampleType_MultiExample;

  using type = VW::parsers::flatbuffer::MultiExample;
};

// Type mapping for prototype_example_collection_t: exposes the generated
// flatbuffer table type as `type` and the matching fb::ExampleType enumerator
// as `root_type`.
template <>
struct fb_type<prototype_example_collection_t>
{
  static constexpr fb::ExampleType root_type = fb::ExampleType::ExampleType_ExampleCollection;

  using type = VW::parsers::flatbuffer::ExampleCollection;
};

using union_t = void;
Expand Down
Loading
Loading