Skip to content

Commit

Permalink
[carnot] Output column per UDA in a partial aggregate (pixie-io#1442)
Browse files Browse the repository at this point in the history
Summary: Prior to this change, the compiler would output a single column
called `serialized_expressions` for any partial agg. This isn't
particularly convenient for the execution side. Instead, the compiler
will now output a column per UDA involved in the agg node. For example,
if you have an agg node that has `min` and `max` outputs, the partial
agg will now output two columns (plus any group columns) called
`serialized_min` and `serialized_max`, whereas before it would output a
single unspecified `serialized_expressions` column.

Relevant Issues: pixie-io#1440

Type of change: /kind cleanup

Test Plan: Modified existing tests to test the new behaviour. Also
tested as part of broader partial aggregate changes.

Signed-off-by: James Bartlett <jamesbartlett@pixielabs.ai>
  • Loading branch information
JamesMBartlett authored Jun 7, 2023
1 parent a414a78 commit 495311c
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 27 deletions.
26 changes: 13 additions & 13 deletions src/carnot/plan/operators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include <absl/strings/str_join.h>
#include <absl/strings/substitute.h>
#include <google/protobuf/text_format.h>
#include <magic_enum.hpp>

#include "src/carnot/plan/scalar_expression.h"
Expand Down Expand Up @@ -193,9 +194,12 @@ std::string AggregateOperator::DebugString() const {
const auto& g = groups();
std::vector<std::string> group_names(g.size());
std::transform(begin(g), end(g), begin(group_names), [](auto val) { return val.name; });

return absl::Substitute("Op:Aggregate(values=($0), groups=($1))",
absl::StrJoin(value_names, ", "), absl::StrJoin(group_names, ", "));
std::string out;
::google::protobuf::TextFormat::PrintToString(pb_, &out);
return absl::Substitute(
"Op:Aggregate(values=($0), groups=($1), partial=($2), finalize=($3)):\n$4",
absl::StrJoin(value_names, ", "), absl::StrJoin(group_names, ", "), partial_agg(),
finalize_results(), out);
}

Status AggregateOperator::Init(const planpb::AggregateOperator& pb) {
Expand Down Expand Up @@ -251,17 +255,13 @@ StatusOr<table_store::schema::Relation> AggregateOperator::OutputRelation(
output_relation.AddColumn(input_relation.GetColumnType(col_idx), pb_.group_names(idx));
}

// If this node is a partial aggregate we output a simple schema where the last column has
// serialized aggregates.
// TODO(philkuz) need the column name and maybe type from somewhere else.
if (pb_.partial_agg() && !pb_.finalize_results()) {
output_relation.AddColumn(types::STRING, "serialized_expressions");
return output_relation;
}

for (const auto& [i, value] : Enumerate(values_)) {
PX_ASSIGN_OR_RETURN(auto dt, value->OutputDataType(state, schema));
output_relation.AddColumn(dt, pb_.value_names(i));
if (pb_.finalize_results()) {
PX_ASSIGN_OR_RETURN(auto dt, value->OutputDataType(state, schema));
output_relation.AddColumn(dt, pb_.value_names(i));
} else {
output_relation.AddColumn(types::STRING, pb_.value_names(i));
}
}
return output_relation;
}
Expand Down
2 changes: 2 additions & 0 deletions src/carnot/plan/operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ class AggregateOperator : public Operator {
const std::vector<GroupInfo>& groups() const { return groups_; }
const std::vector<std::shared_ptr<AggregateExpression>>& values() const { return values_; }
bool windowed() const { return pb_.windowed(); }
bool partial_agg() const { return pb_.partial_agg(); }
bool finalize_results() const { return pb_.finalize_results(); }

private:
std::vector<std::shared_ptr<AggregateExpression>> values_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ StatusOr<OperatorIR*> AggOperatorMgr::CreatePrepareOperator(IR* plan, OperatorIR
new_type->AddColumn(group->col_name(), group->resolved_type());
}

// Add column for the serialized expression
new_type->AddColumn("serialized_expressions", ValueType::Create(types::STRING, types::ST_NONE));
for (const auto& col_expr : agg->aggregate_expressions()) {
new_type->AddColumn("serialized_" + col_expr.name,
ValueType::Create(types::STRING, types::ST_NONE));
}
PX_RETURN_IF_ERROR(new_agg->SetResolvedType(new_type));

DCHECK(Match(new_agg, PartialAgg()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@ TEST_F(PartialOpMgrTest, agg_test) {
auto service_col = MakeColumn("service", 0);
EXPECT_OK(service_col->SetResolvedType(ValueType::Create(types::STRING, types::ST_NONE)));
auto mean_func = MakeMeanFunc(MakeColumn("count", 0));
auto agg = MakeBlockingAgg(mem_src, {count_col, service_col}, {{"mean", mean_func}});
Relation agg_relation({types::INT64, types::STRING, types::FLOAT64},
{"count", "service", "mean"});
auto agg = MakeBlockingAgg(mem_src, {count_col, service_col},
{{"mean", mean_func}, {"mean2", mean_func}});
Relation agg_relation({types::INT64, types::STRING, types::FLOAT64, types::FLOAT64},
{"count", "service", "mean", "mean2"});
MakeMemSink(agg, "out");

ResolveTypesRule type_rule(compiler_state_.get());
Expand Down Expand Up @@ -118,8 +119,8 @@ TEST_F(PartialOpMgrTest, agg_test) {
}
// Confirm that the relations are good.
EXPECT_THAT(*prepare_agg->resolved_table_type(),
IsTableType(Relation({types::INT64, types::STRING, types::STRING},
{"count", "service", "serialized_expressions"})));
IsTableType(Relation({types::INT64, types::STRING, types::STRING, types::STRING},
{"count", "service", "serialized_mean", "serialized_mean2"})));

EXPECT_THAT(*merge_agg->resolved_table_type(), IsTableType(agg_relation));
}
Expand Down
10 changes: 4 additions & 6 deletions src/carnot/planner/distributed/splitter/splitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,10 @@ TEST_F(SplitterTest, partial_agg_test) {
EXPECT_EQ(grpc_sink->destination_id(), grpc_source->source_id());

// Confirm that the relations have serialized in their relation.
EXPECT_THAT(
*grpc_sink->resolved_table_type(),
IsTableType(Relation({types::INT64, types::STRING}, {"count", "serialized_expressions"})));
EXPECT_THAT(
*grpc_source->resolved_table_type(),
IsTableType(Relation({types::INT64, types::STRING}, {"count", "serialized_expressions"})));
EXPECT_THAT(*grpc_sink->resolved_table_type(),
IsTableType(Relation({types::INT64, types::STRING}, {"count", "serialized_mean"})));
EXPECT_THAT(*grpc_source->resolved_table_type(),
IsTableType(Relation({types::INT64, types::STRING}, {"count", "serialized_mean"})));

// Verify that the aggregate connects back into the original group.
ASSERT_EQ(finalize_agg->Children().size(), 1);
Expand Down
2 changes: 1 addition & 1 deletion src/carnot/planner/ir/blocking_agg_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ Status BlockingAggIR::EvaluateAggregateExpression(planpb::AggregateExpression* e

Status BlockingAggIR::ToProto(planpb::Operator* op) const {
auto pb = op->mutable_agg_op();
if (finalize_results_ && !partial_agg_) {
if (!partial_agg_) {
(*pb->mutable_values()) = pre_split_proto_.values();
(*pb->mutable_value_names()) = pre_split_proto_.value_names();
} else {
Expand Down
4 changes: 4 additions & 0 deletions src/carnot/planpb/test_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ groups {
}
group_names: "group1"
value_names: "value1"
partial_agg: true
finalize_results: true
)";

constexpr char kWindowedAggOperator1[] = R"(
Expand All @@ -232,6 +234,8 @@ groups {
}
group_names: "group1"
value_names: "value1"
partial_agg: true
finalize_results: true
)";

constexpr char kFilterOperator1[] = R"(
Expand Down

0 comments on commit 495311c

Please sign in to comment.