Skip to content

Commit

Permalink
[carnot] Add Serialize/Deserialize to UDFDefinition (#1441)
Browse files Browse the repository at this point in the history
Summary: In preparation for implementing partial aggregates on the exec
side, this PR adds Serialize/Deserialize methods to the UDADefinition.
This will allow the exec side to execute udas' Serialize and Deserialize
methods in a type erased way, just as we do for Update, Merge, Finalize,
etc.

Relevant Issues: #1440

Type of change: /kind cleanup

Test Plan: Added a test for the SerializeArrow and Deserialize methods.
Also tested as part of broader partial aggregate implmentation.

Signed-off-by: James Bartlett <jamesbartlett@pixielabs.ai>
  • Loading branch information
JamesMBartlett authored Jun 7, 2023
1 parent ba9c08a commit a414a78
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/carnot/udf/udf_definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ class UDADefinition : public UDFDefinition {
finalize_value_fn = UDAWrapper<T>::FinalizeValue;

supports_partial_ = UDAWrapper<T>::SupportsPartial;

deserialize_fn_ = UDAWrapper<T>::Deserialize;
serialize_arrow_fn_ = UDAWrapper<T>::SerializeArrow;
return Status::OK();
}

Expand Down Expand Up @@ -245,6 +248,12 @@ class UDADefinition : public UDFDefinition {
Status FinalizeArrow(UDA* uda, FunctionContext* ctx, arrow::ArrayBuilder* output) {
return finalize_arrow_fn_(uda, ctx, output);
}
Status Deserialize(UDA* uda, FunctionContext* ctx, const types::StringValue& serialized) {
return deserialize_fn_(uda, ctx, serialized);
}
Status SerializeArrow(UDA* uda, FunctionContext* ctx, arrow::ArrayBuilder* output) {
return serialize_arrow_fn_(uda, ctx, output);
}

private:
std::vector<types::DataType> init_arguments_;
Expand All @@ -270,6 +279,10 @@ class UDADefinition : public UDFDefinition {
std::function<Status(UDA* uda, FunctionContext* ctx,
const std::vector<std::shared_ptr<types::BaseValueType>>& inputs)>
init_wrapper_fn_;
std::function<Status(UDA* uda, FunctionContext* ctx, const types::StringValue& serialized)>
deserialize_fn_;
std::function<Status(UDA* uda, FunctionContext* ctx, arrow::ArrayBuilder* output)>
serialize_arrow_fn_;
};

class UDTFDefinition : public UDFDefinition {
Expand Down
44 changes: 44 additions & 0 deletions src/carnot/udf/udf_definition_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "src/carnot/udf/udf_definition.h"
#include "src/common/testing/testing.h"
#include "src/shared/types/column_wrapper.h"
#include "src/shared/types/types.h"

namespace px {
namespace carnot {
Expand Down Expand Up @@ -198,6 +199,24 @@ class InitArgUDA : public udf::UDA {
std::vector<std::string> updates_;
};

class SerdeUDA : public UDA {
public:
void Update(FunctionContext*, types::Int64Value v) { sum_ += v.val; }
void Merge(FunctionContext*, const SerdeUDA& other) { sum_ = sum_ + other.sum_; }
Int64Value Finalize(FunctionContext*) { return sum_; }
StringValue Serialize(FunctionContext*) { return absl::StrCat(sum_); }

Status Deserialize(FunctionContext*, const StringValue& data) {
if (!absl::SimpleAtoi(data, &sum_)) {
return Status{statuspb::Code::INVALID_ARGUMENT, "invalid serialized"};
}
return Status::OK();
}

private:
int64_t sum_ = 0;
};

TEST(UDADefinition, without_merge) {
auto ctx = FunctionContext(nullptr, nullptr);
UDADefinition def("minsum");
Expand Down Expand Up @@ -277,6 +296,31 @@ TEST(UDADefinition, init_args) {
EXPECT_EQ("123, init_arg, true, [1, 2, 3]", out);
}

TEST(UDADefinition, serialize_deserialize) {
auto ctx = FunctionContext(nullptr, nullptr);
UDADefinition def("serdeuda");
EXPECT_OK(def.Init<SerdeUDA>());

auto uda = def.Make();
types::Int64ValueColumnWrapper v1({1, 2, 3});
EXPECT_OK(def.ExecBatchUpdate(uda.get(), &ctx, {&v1}));

auto output_builder = std::make_shared<arrow::StringBuilder>();
EXPECT_OK(def.SerializeArrow(uda.get(), &ctx, output_builder.get()));
std::shared_ptr<arrow::Array> ser;
EXPECT_TRUE(output_builder->Finish(&ser).ok());
EXPECT_EQ(1, ser->length());
auto casted = static_cast<arrow::StringArray*>(ser.get());
EXPECT_EQ("6", casted->GetView(0));

auto uda2 = def.Make();
EXPECT_OK(def.Deserialize(uda2.get(), &ctx, "100"));

types::Int64Value out;
EXPECT_OK(def.FinalizeValue(uda2.get(), &ctx, &out));
EXPECT_EQ(100, out.val);
}

} // namespace udf
} // namespace carnot
} // namespace px
54 changes: 54 additions & 0 deletions src/carnot/udf/udf_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "src/common/base/base.h"
#include "src/shared/types/arrow_adapter.h"
#include "src/shared/types/column_wrapper.h"
#include "src/shared/types/types.h"

namespace px {
namespace carnot {
Expand Down Expand Up @@ -435,6 +436,59 @@ struct UDAWrapper {
*casted_output = casted_uda->Finalize(ctx);
return Status::OK();
}

/**
* Call the UDA's Serialize method.
*
* @param uda a pointer to the UDA
* @param ctx The function context.
* @param output An arrow array builder to put the serialized output in.
* @return Status from the uda's serialize function.
*/
template <typename Q = TUDA, std::enable_if_t<UDATraits<Q>::SupportsPartial(), void>* = nullptr>
static Status SerializeArrowImpl(UDA* uda, FunctionContext* ctx, arrow::ArrayBuilder* output) {
auto* casted_builder = static_cast<arrow::StringBuilder*>(output);
auto* casted_uda = static_cast<TUDA*>(uda);
PX_RETURN_IF_ERROR(casted_builder->Append(UnWrap(casted_uda->Serialize(ctx))));
return Status::OK();
}

/**
* Return Status::OK, if the UDA doesn't have a Serialize method.
*/
template <typename Q = TUDA, std::enable_if_t<!UDATraits<Q>::SupportsPartial(), void>* = nullptr>
static Status SerializeArrowImpl(UDA*, FunctionContext*, arrow::ArrayBuilder*) {
return Status::OK();
}
static Status SerializeArrow(UDA* uda, FunctionContext* ctx, arrow::ArrayBuilder* output) {
return SerializeArrowImpl(uda, ctx, output);
}
/**
* Call the UDA's Deserialize method.
*
* @param uda a pointer to the UDA
* @param ctx The function context.
* @param serialized A StringValue holding the data to deserialize.
* @return Status from the uda's deserialize function.
*/
template <typename Q = TUDA, std::enable_if_t<UDATraits<Q>::SupportsPartial(), void>* = nullptr>
static Status DeserializeImpl(UDA* uda, FunctionContext* ctx,
const types::StringValue& serialized) {
auto* casted_uda = static_cast<TUDA*>(uda);
PX_RETURN_IF_ERROR(casted_uda->Deserialize(ctx, serialized));
return Status::OK();
}

/**
* Return Status::OK, if the UDA doesn't have a Deserialize method.
*/
template <typename Q = TUDA, std::enable_if_t<!UDATraits<Q>::SupportsPartial(), void>* = nullptr>
static Status DeserializeImpl(UDA*, FunctionContext*, const types::StringValue&) {
return Status::OK();
}
static Status Deserialize(UDA* uda, FunctionContext* ctx, const types::StringValue& serialized) {
return DeserializeImpl(uda, ctx, serialized);
}
};

/**
Expand Down

0 comments on commit a414a78

Please sign in to comment.