Skip to content

Add returnsZeroOnEmptyInput to AggregationFunctionMetadata #17963

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2416,6 +2416,10 @@ public AggregationFunctionMetadata getAggregationFunctionMetadata(Session sessio
builder.orderSensitive();
}

if (aggregationFunctionMetadata.returnsZeroOnEmptyInput()) {
builder.returnsZeroOnEmptyInput();
}

if (!aggregationFunctionMetadata.getIntermediateTypes().isEmpty()) {
FunctionBinding functionBinding = toFunctionBinding(resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), functionSignature);
aggregationFunctionMetadata.getIntermediateTypes().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ private static AggregationHeader parseHeader(AnnotatedElement aggregationDefinit
parseDescription(aggregationDefinition, outputFunction),
aggregationAnnotation.decomposable(),
aggregationAnnotation.isOrderSensitive(),
aggregationAnnotation.returnsZeroOnEmptyInput(),
aggregationAnnotation.hidden(),
aggregationDefinition.getAnnotationsByType(Deprecated.class).length > 0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,17 @@ public class AggregationHeader
private final Optional<String> description;
private final boolean decomposable;
private final boolean orderSensitive;
private final boolean returnsZeroOnEmptyInput;
private final boolean hidden;
private final boolean deprecated;

public AggregationHeader(String name, Optional<String> description, boolean decomposable, boolean orderSensitive, boolean hidden, boolean deprecated)
public AggregationHeader(String name, Optional<String> description, boolean decomposable, boolean orderSensitive, boolean returnsZeroOnEmptyInput, boolean hidden, boolean deprecated)
{
this.name = requireNonNull(name, "name cannot be null");
this.description = requireNonNull(description, "description cannot be null");
this.decomposable = decomposable;
this.orderSensitive = orderSensitive;
this.returnsZeroOnEmptyInput = returnsZeroOnEmptyInput;
this.hidden = hidden;
this.deprecated = deprecated;
}
Expand All @@ -57,6 +59,11 @@ public boolean isOrderSensitive()
return orderSensitive;
}

public boolean returnsZeroOnEmptyInput()
{
return returnsZeroOnEmptyInput;
}

public boolean isHidden()
{
return hidden;
Expand All @@ -75,6 +82,7 @@ public String toString()
.add("description", description)
.add("decomposable", decomposable)
.add("orderSensitive", orderSensitive)
.add("returnsZeroOnEmptyInput", returnsZeroOnEmptyInput)
.add("hidden", hidden)
.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import static io.trino.util.Failures.checkCondition;
import static io.trino.util.Failures.internalError;

@AggregationFunction("approx_distinct")
@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true)
public final class ApproximateCountDistinctAggregation
{
private static final double LOWEST_MAX_STANDARD_ERROR = 0.0040625;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import static io.trino.spi.type.BigintType.BIGINT;

@AggregationFunction("approx_distinct")
@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true)
public final class BooleanApproximateCountDistinctAggregation
{
private BooleanApproximateCountDistinctAggregation() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import io.trino.spi.function.SqlType;
import io.trino.spi.type.StandardTypes;

@AggregationFunction("approx_distinct")
@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true)
public final class BooleanDefaultApproximateCountDistinctAggregation
{
// this value is ignored for boolean, but this is left here for completeness
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import static io.trino.spi.type.BigintType.BIGINT;

@AggregationFunction("count")
@AggregationFunction(value = "count", returnsZeroOnEmptyInput = true)
public final class CountAggregation
{
private CountAggregation() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import static io.trino.spi.type.BigintType.BIGINT;

@AggregationFunction("count")
@AggregationFunction(value = "count", returnsZeroOnEmptyInput = true)
@Description("Counts the non-null values")
public final class CountColumn
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import static io.trino.spi.type.BigintType.BIGINT;

@AggregationFunction("count_if")
@AggregationFunction(value = "count_if", returnsZeroOnEmptyInput = true)
public final class CountIfAggregation
{
private CountIfAggregation() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.OperatorType.XX_HASH_64;

@AggregationFunction("approx_distinct")
@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true)
public final class DefaultApproximateCountDistinctAggregation
{
private static final double DEFAULT_STANDARD_ERROR = 0.023;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ private static AggregationFunctionMetadata createAggregationFunctionMetadata(Agg
if (details.isOrderSensitive()) {
builder.orderSensitive();
}
if (details.returnsZeroOnEmptyInput()) {
builder.returnsZeroOnEmptyInput();
}
if (details.isDecomposable()) {
for (AccumulatorStateDetails<?> stateDetail : stateDetails) {
builder.intermediateType(stateDetail.getSerializedType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.execution.querystats.PlanOptimizersStatsCollector;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.Metadata;
import io.trino.spi.function.AggregationFunctionMetadata;
import io.trino.spi.type.Type;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
Expand Down Expand Up @@ -189,8 +190,9 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext<Optional<A
Optional.empty(),
Optional.empty(),
Optional.empty());
String signatureName = aggregation.getResolvedFunction().getSignature().getName();
if (signatureName.equals("count") || signatureName.equals("count_if") || signatureName.equals("approx_distinct")) {

AggregationFunctionMetadata aggregationFunctionMetadata = metadata.getAggregationFunctionMetadata(session, aggregation.getResolvedFunction());
if (aggregationFunctionMetadata.returnsZeroOnEmptyInput()) {
Symbol newSymbol = symbolAllocator.newSymbol("expr", symbolAllocator.getTypes().get(entry.getKey()));
aggregations.put(newSymbol, newAggregation);
coalesceSymbolsBuilder.put(newSymbol, entry.getKey());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
*/
boolean isOrderSensitive() default false;

/**
* Indicates whether the function returns value BIGINT 0 when no input is provided.
* The SQL specification demands it for COUNT function.
* In trino, COUNT_IF and APPROX_DISTINCT also have this characteristic.
*/
boolean returnsZeroOnEmptyInput() default false;

boolean hidden() default false;

String[] alias() default {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
public class AggregationFunctionMetadata
{
private final boolean orderSensitive;

private final boolean returnsZeroOnEmptyInput;
private final List<TypeSignature> intermediateTypes;

private AggregationFunctionMetadata(boolean orderSensitive, List<TypeSignature> intermediateTypes)
private AggregationFunctionMetadata(boolean orderSensitive, boolean returnsZeroOnEmptyInput, List<TypeSignature> intermediateTypes)
{
this.orderSensitive = orderSensitive;
this.returnsZeroOnEmptyInput = returnsZeroOnEmptyInput;
this.intermediateTypes = List.copyOf(requireNonNull(intermediateTypes, "intermediateTypes is null"));
}

Expand All @@ -40,6 +43,11 @@ public boolean isOrderSensitive()
return orderSensitive;
}

public boolean returnsZeroOnEmptyInput()
{
return returnsZeroOnEmptyInput;
}

public boolean isDecomposable()
{
return !intermediateTypes.isEmpty();
Expand All @@ -55,6 +63,7 @@ public String toString()
{
return new StringJoiner(", ", AggregationFunctionMetadata.class.getSimpleName() + "[", "]")
.add("orderSensitive=" + orderSensitive)
.add("returnsZeroOnEmptyInput=" + returnsZeroOnEmptyInput)
.add("intermediateTypes=" + intermediateTypes)
.toString();
}
Expand All @@ -68,6 +77,7 @@ public static class AggregationFunctionMetadataBuilder
{
private boolean orderSensitive;
private final List<TypeSignature> intermediateTypes = new ArrayList<>();
private boolean returnsZeroOnEmptyInput;

private AggregationFunctionMetadataBuilder() {}

Expand All @@ -77,6 +87,12 @@ public AggregationFunctionMetadataBuilder orderSensitive()
return this;
}

public AggregationFunctionMetadataBuilder returnsZeroOnEmptyInput()
{
this.returnsZeroOnEmptyInput = true;
return this;
}

public AggregationFunctionMetadataBuilder intermediateType(Type type)
{
this.intermediateTypes.add(type.getTypeSignature());
Expand All @@ -91,7 +107,7 @@ public AggregationFunctionMetadataBuilder intermediateType(TypeSignature type)

public AggregationFunctionMetadata build()
{
return new AggregationFunctionMetadata(orderSensitive, intermediateTypes);
return new AggregationFunctionMetadata(orderSensitive, returnsZeroOnEmptyInput, intermediateTypes);
}
}
}