Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add STDEV() aggregate function #1553

Merged
merged 14 commits into from
Nov 13, 2024
1 change: 0 additions & 1 deletion src/engine/GroupBy.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ class GroupBy : public Operation {
MAX,
SUM,
GROUP_CONCAT,
STDEV,
SAMPLE
};

Expand Down
1 change: 1 addition & 0 deletions src/engine/sparqlExpressions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_library(sparqlExpressions
SampleExpression.cpp
RelationalExpressions.cpp
AggregateExpression.cpp
StdevExpression.cpp
RegexExpression.cpp
NumericUnaryExpressions.cpp
NumericBinaryExpressions.cpp
Expand Down
74 changes: 74 additions & 0 deletions src/engine/sparqlExpressions/StdevExpression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2024, University of Freiburg,
// Chair of Algorithms and Data Structures.
// Author: Christoph Ullinger <ullingec@informatik.uni-freiburg.de>

#include "engine/sparqlExpressions/StdevExpression.h"

namespace sparqlExpression {

namespace detail {

// _____________________________________________________________________________
ExpressionResult DeviationExpression::evaluate(
EvaluationContext* context) const {
auto impl = [context](SingleExpressionResult auto&& el) -> ExpressionResult {
// Prepare space for result
VectorWithMemoryLimit<IdOrLiteralOrIri> exprResult{context->_allocator};
std::fill_n(std::back_inserter(exprResult), context->size(),
IdOrLiteralOrIri{Id::makeUndefined()});
bool undef = false;

auto devImpl = [&undef, &exprResult, context](auto generator) {
double sum = 0.0;
// Intermediate storage of the results returned from the child
// expression
VectorWithMemoryLimit<double> childResults{context->_allocator};

// Collect values as doubles
for (auto& inp : generator) {
const auto& n = detail::NumericValueGetter{}(std::move(inp), context);
auto v = std::visit(
[]<typename T>(T&& value) -> std::optional<double> {
if constexpr (ad_utility::isSimilar<T, double> ||
ad_utility::isSimilar<T, int64_t>) {
return static_cast<double>(value);
} else {
return std::nullopt;
}
},
n);
if (v.has_value()) {
childResults.push_back(v.value());
sum += v.value();
} else {
// There is a non-numeric value in the input. Therefore the entire
// result will be undef.
undef = true;
return;
}
context->cancellationHandle_->throwIfCancelled();
}

// Calculate squared deviation and save for result
double avg = sum / static_cast<double>(context->size());
for (size_t i = 0; i < childResults.size(); i++) {
exprResult.at(i) = IdOrLiteralOrIri{
ValueId::makeFromDouble(std::pow(childResults.at(i) - avg, 2))};
}
};

auto generator =
detail::makeGenerator(AD_FWD(el), context->size(), context);
devImpl(std::move(generator));

if (undef) {
return IdOrLiteralOrIri{Id::makeUndefined()};
}
return exprResult;
};
auto childRes = child_->evaluate(context);
return std::visit(impl, std::move(childRes));
};

} // namespace detail
} // namespace sparqlExpression
63 changes: 1 addition & 62 deletions src/engine/sparqlExpressions/StdevExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,6 @@

/// The STDEV Expression

// Helper function to extract a double from a NumericValue variant
auto inline numValToDouble =
[]<typename T>(T&& value) -> std::optional<double> {
if constexpr (ad_utility::isSimilar<T, double> ||
ad_utility::isSimilar<T, int64_t>) {
return static_cast<double>(value);
} else {
return std::nullopt;
}
};

// Helper expression: The individual deviation squares. A DeviationExpression
// over X corresponds to the value (X - AVG(X))^2.
class DeviationExpression : public SparqlExpression {
Expand All @@ -44,57 +33,7 @@
DeviationExpression(Ptr&& child) : child_{std::move(child)} {}

// __________________________________________________________________________
ExpressionResult evaluate(EvaluationContext* context) const override {
auto impl =
[context](SingleExpressionResult auto&& el) -> ExpressionResult {
// Prepare space for result
VectorWithMemoryLimit<IdOrLiteralOrIri> exprResult{context->_allocator};
std::fill_n(std::back_inserter(exprResult), context->size(),
IdOrLiteralOrIri{Id::makeUndefined()});
bool undef = false;

auto devImpl = [&undef, &exprResult, context](auto generator) {
double sum = 0.0;
// Intermediate storage of the results returned from the child
// expression
VectorWithMemoryLimit<double> childResults{context->_allocator};

// Collect values as doubles
for (auto& inp : generator) {
const auto& n = detail::NumericValueGetter{}(std::move(inp), context);
auto v = std::visit(numValToDouble, n);
if (v.has_value()) {
childResults.push_back(v.value());
sum += v.value();
} else {
// There is a non-numeric value in the input. Therefore the entire
// result will be undef.
undef = true;
return;
}
context->cancellationHandle_->throwIfCancelled();
}

// Calculate squared deviation and save for result
double avg = sum / static_cast<double>(context->size());
for (size_t i = 0; i < childResults.size(); i++) {
exprResult.at(i) = IdOrLiteralOrIri{
ValueId::makeFromDouble(std::pow(childResults.at(i) - avg, 2))};
}
};

auto generator =
detail::makeGenerator(AD_FWD(el), context->size(), context);
devImpl(std::move(generator));

if (undef) {
return IdOrLiteralOrIri{Id::makeUndefined()};
}
return exprResult;
};
auto childRes = child_->evaluate(context);
return std::visit(impl, std::move(childRes));
};
ExpressionResult evaluate(EvaluationContext* context) const override;

// __________________________________________________________________________
AggregateStatus isAggregate() const override {
Expand All @@ -103,9 +42,9 @@

// __________________________________________________________________________
[[nodiscard]] string getCacheKey(
const VariableToColumnMap& varColMap) const override {
return absl::StrCat("[ SQ.DEVIATION ]", child_->getCacheKey(varColMap));
}

Check warning on line 47 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L45-L47

Added lines #L45 - L47 were not covered by tests

private:
// _________________________________________________________________________
Expand Down
33 changes: 25 additions & 8 deletions test/SparqlAntlrParserTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,22 @@ ::testing::Matcher<const SparqlExpression::Ptr&> matchAggregate(
AD_PROPERTY(Exp, children, ElementsAre(variableExpressionMatcher(child))),
WhenDynamicCastTo<const AggregateExpr&>(innerMatcher)));
}

// Return a matcher that checks whether a given `SparqlExpression::Ptr` actually
// points to an `AggregateExpr` and that the distinctness of the aggregate
// expression matches. It does not check the child. This is required to test
// aggregates that implicitly replace their child, like `StdevExpression`.
template <typename AggregateExpr>
::testing::Matcher<const SparqlExpression::Ptr&> matchAggregateWithoutChild(
bool distinct) {
using namespace ::testing;
using namespace builtInCallTestHelpers;
using Exp = SparqlExpression;

using enum SparqlExpression::AggregateStatus;
auto aggregateStatus = distinct ? DistinctAggregate : NonDistinctAggregate;
return Pointee(AD_PROPERTY(Exp, isAggregate, Eq(aggregateStatus)));
ullingerc marked this conversation as resolved.
Show resolved Hide resolved
}
} // namespace aggregateTestHelpers

// ___________________________________________________________
Expand Down Expand Up @@ -1927,14 +1943,15 @@ TEST(SparqlParser, aggregateExpressions) {
matchAggregate<GroupConcatExpression>(true, V{"?x"}, separator(";")));

// The STDEV expression
// TODO<ullingec> Test failing because StdevExpression replaces its child

// expectAggregate("STDEV(?x)", matchAggregate<StdevExpression>(false,
// V{"?x"})); expectAggregate("stdev(?x)",
// matchAggregate<StdevExpression>(false, V{"?x"})); A distinct stdev is
// probably not very useful, but should be possible anyway
// expectAggregate("STDEV(DISTINCT ?x)",
// matchAggregate<StdevExpression>(true, V{"?x"}));
// Here we don't match the child, because StdevExpression replaces it with a
// DeviationExpression.
expectAggregate("STDEV(?x)",
matchAggregateWithoutChild<StdevExpression>(false));
expectAggregate("stdev(?x)",
matchAggregateWithoutChild<StdevExpression>(false));
// A distinct stdev is probably not very useful, but should be possible anyway
expectAggregate("STDEV(DISTINCT ?x)",
matchAggregateWithoutChild<StdevExpression>(true));
}

TEST(SparqlParser, Quads) {
Expand Down
Loading