diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index d5c1538b77..933e68085a 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -161,8 +161,9 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext Expression arg = node.getField().accept(this, context); Aggregator aggregator = (Aggregator) repository.compile( builtinFunctionName.get().getName(), Collections.singletonList(arg)); - if (node.getCondition() != null) { - aggregator.condition(analyze(node.getCondition(), context)); + aggregator.distinct(node.getDistinct()); + if (node.condition() != null) { + aggregator.condition(analyze(node.condition(), context)); } return aggregator; } else { diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 7400ae20e6..3b78483736 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -211,7 +211,16 @@ public static UnresolvedExpression aggregate( public static UnresolvedExpression filteredAggregate( String func, UnresolvedExpression field, UnresolvedExpression condition) { - return new AggregateFunction(func, field, condition); + return new AggregateFunction(func, field).condition(condition); + } + + public static UnresolvedExpression distinctAggregate(String func, UnresolvedExpression field) { + return new AggregateFunction(func, field, true); + } + + public static UnresolvedExpression filteredDistinctCount( + String func, UnresolvedExpression field, UnresolvedExpression condition) { + return new AggregateFunction(func, field, true).condition(condition); } public static Function function(String funcName, UnresolvedExpression... funcArgs) { diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java index 8753e35ed9..e909c46ee7 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java @@ -28,9 +28,12 @@ import java.util.Collections; import java.util.List; +import javax.annotation.Nullable; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.common.utils.StringUtils; @@ -45,7 +48,10 @@ public class AggregateFunction extends UnresolvedExpression { private final String funcName; private final UnresolvedExpression field; private final List argList; + @Setter + @Accessors(fluent = true) private UnresolvedExpression condition; + private Boolean distinct = false; /** * Constructor. @@ -62,14 +68,13 @@ public AggregateFunction(String funcName, UnresolvedExpression field) { * Constructor. * @param funcName function name. * @param field {@link UnresolvedExpression}. - * @param condition condition in aggregation filter. + * @param distinct whether distinct field is specified or not. */ - public AggregateFunction(String funcName, UnresolvedExpression field, - UnresolvedExpression condition) { + public AggregateFunction(String funcName, UnresolvedExpression field, Boolean distinct) { this.funcName = funcName; this.field = field; this.argList = Collections.emptyList(); - this.condition = condition; + this.distinct = distinct; } @Override diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java b/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java index e2c5fb6a39..b2172e54f1 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java @@ -157,6 +157,14 @@ public static ExprValue fromObjectValue(Object o, ExprCoreType type) { } } + public static Byte getByteValue(ExprValue exprValue) { + return exprValue.byteValue(); + } + + public static Short getShortValue(ExprValue exprValue) { + return exprValue.shortValue(); + } + public static Integer getIntegerValue(ExprValue exprValue) { return exprValue.integerValue(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index f59eb6d34a..af51d0898a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -500,6 +500,10 @@ public Aggregator count(Expression... expressions) { return aggregate(BuiltinFunctionName.COUNT, expressions); } + public Aggregator distinctCount(Expression... expressions) { + return count(expressions).distinct(true); + } + public Aggregator varSamp(Expression... expressions) { return aggregate(BuiltinFunctionName.VARSAMP, expressions); } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java index 80944172ea..5328e11aad 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java @@ -64,6 +64,10 @@ public abstract class Aggregator @Getter @Accessors(fluent = true) protected Expression condition; + @Setter + @Getter + @Accessors(fluent = true) + protected Boolean distinct = false; /** * Create an {@link AggregationState} which will be used for aggregation. diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index 3195bf3941..f1bd088967 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -28,8 +28,10 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; +import java.util.HashSet; import java.util.List; import java.util.Locale; +import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; @@ -45,33 +47,51 @@ public CountAggregator(List arguments, ExprCoreType returnType) { @Override public CountAggregator.CountState create() { - return new CountState(); + return distinct ? new DistinctCountState() : new CountState(); } @Override protected CountState iterate(ExprValue value, CountState state) { - state.count++; + state.count(value); return state; } @Override public String toString() { - return String.format(Locale.ROOT, "count(%s)", format(getArguments())); + return distinct + ? String.format(Locale.ROOT, "count(distinct %s)", format(getArguments())) + : String.format(Locale.ROOT, "count(%s)", format(getArguments())); } /** * Count State. */ protected static class CountState implements AggregationState { - private int count; + protected int count; CountState() { this.count = 0; } + public void count(ExprValue value) { + count++; + } + @Override public ExprValue result() { return ExprValueUtils.integerValue(count); } } + + protected static class DistinctCountState extends CountState { + private final Set distinctValues = new HashSet<>(); + + @Override + public void count(ExprValue value) { + if (!distinctValues.contains(value)) { + distinctValues.add(value); + count++; + } + } + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java index 346bd2d28c..92176e8648 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java @@ -54,8 +54,8 @@ public class NamedAggregator extends Aggregator { /** * NamedAggregator. - * The aggregator properties {@link #condition} is inherited by named aggregator - * to avoid errors introduced by the property inconsistency. + * The aggregator properties {@link #condition} and {@link #distinct} + * are inherited by named aggregator to avoid errors introduced by the property inconsistency. * * @param name name * @param delegated delegated @@ -67,6 +67,7 @@ public NamedAggregator( this.name = name; this.delegated = delegated; this.condition = delegated.condition; + this.distinct = delegated.distinct; } @Override diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 7d2f033334..b0b1e7e773 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -300,6 +300,24 @@ public void variance_mapto_varPop() { ); } + @Test + public void distinct_count() { + assertAnalyzeEqual( + dsl.distinctCount(DSL.ref("integer_value", INTEGER)), + AstDSL.distinctAggregate("count", qualifiedName("integer_value")) + ); + } + + @Test + public void filtered_distinct_count() { + assertAnalyzeEqual( + dsl.distinctCount(DSL.ref("integer_value", INTEGER)) + .condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), + AstDSL.filteredDistinctCount("count", qualifiedName("integer_value"), function( + ">", qualifiedName("integer_value"), intLiteral(1))) + ); + } + protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } diff --git a/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java b/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java index a27d90f35d..af2dbf22fc 100644 --- a/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java +++ b/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java @@ -96,8 +96,8 @@ public class ExprValueUtilsTest { Lists.newArrayList(Iterables.concat(numberValues, nonNumberValues)); private static List> numberValueExtractor = Arrays.asList( - ExprValue::byteValue, - ExprValue::shortValue, + ExprValueUtils::getByteValue, + ExprValueUtils::getShortValue, ExprValueUtils::getIntegerValue, ExprValueUtils::getLongValue, ExprValueUtils::getFloatValue, diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java index cc2825858a..1db33ac9d5 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java @@ -116,6 +116,29 @@ public class AggregationTest extends ExpressionTestBase { "timestamp_value", "2040-01-01 07:00:00"))); + protected static List tuples_with_duplicates = + Arrays.asList( + ExprValueUtils.tupleValue(ImmutableMap.of( + "integer_value", 1, + "double_value", 4d, + "struct_value", ImmutableMap.of("str", 1), + "array_value", ImmutableList.of(1))), + ExprValueUtils.tupleValue(ImmutableMap.of( + "integer_value", 1, + "double_value", 3d, + "struct_value", ImmutableMap.of("str", 1), + "array_value", ImmutableList.of(1))), + ExprValueUtils.tupleValue(ImmutableMap.of( + "integer_value", 2, + "double_value", 2d, + "struct_value", ImmutableMap.of("str", 2), + "array_value", ImmutableList.of(2))), + ExprValueUtils.tupleValue(ImmutableMap.of( + "integer_value", 3, + "double_value", 1d, + "struct_value", ImmutableMap.of("str1", 1), + "array_value", ImmutableList.of(1, 2)))); + protected static List tuples_with_null_and_missing = Arrays.asList( ExprValueUtils.tupleValue( diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java index 0fdadfc692..ee183dafce 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java @@ -129,6 +129,35 @@ public void filtered_count() { assertEquals(3, result.value()); } + @Test + public void distinct_count() { + ExprValue result = aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER)), + tuples_with_duplicates); + assertEquals(3, result.value()); + } + + @Test + public void filtered_distinct_count() { + ExprValue result = aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER)) + .condition(dsl.greater(DSL.ref("double_value", DOUBLE), DSL.literal(1d))), + tuples_with_duplicates); + assertEquals(2, result.value()); + } + + @Test + public void distinct_count_map() { + ExprValue result = aggregation(dsl.distinctCount(DSL.ref("struct_value", STRUCT)), + tuples_with_duplicates); + assertEquals(3, result.value()); + } + + @Test + public void distinct_count_array() { + ExprValue result = aggregation(dsl.distinctCount(DSL.ref("array_value", ARRAY)), + tuples_with_duplicates); + assertEquals(3, result.value()); + } + @Test public void count_with_missing() { ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)), @@ -166,6 +195,9 @@ public void valueOf() { public void test_to_string() { Aggregator countAggregator = dsl.count(DSL.ref("integer_value", INTEGER)); assertEquals("count(integer_value)", countAggregator.toString()); + + countAggregator = dsl.distinctCount(DSL.ref("integer_value", INTEGER)); + assertEquals("count(distinct integer_value)", countAggregator.toString()); } @Test diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 1d6d172981..275666e7ba 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -357,6 +357,19 @@ Example:: | 2.8613807855648994 | +--------------------+ +DISTINCT COUNT Aggregation +-------------------------- + +To get the count of distinct values of a field, you can add a keyword ``DISTINCT`` before the field in the count aggregation. Example:: + + os> SELECT COUNT(DISTINCT gender), COUNT(gender) FROM accounts; + fetched rows / total rows = 1/1 + +--------------------------+-----------------+ + | COUNT(DISTINCT gender) | COUNT(gender) | + |--------------------------+-----------------| + | 2 | 4 | + +--------------------------+-----------------+ + HAVING Clause ============= @@ -456,3 +469,16 @@ The ``FILTER`` clause can be used in aggregation functions without GROUP BY as w | 4 | 1 | +--------------+------------+ +Distinct count aggregate with FILTER +------------------------------------ + +The ``FILTER`` clause is also used in distinct count to do the filtering before count the distinct values of specific field. For example:: + + os> SELECT COUNT(DISTINCT firstname) FILTER(WHERE age > 30) AS distinct_count FROM accounts + fetched rows / total rows = 1/1 + +------------------+ + | distinct_count | + |------------------| + | 3 | + +------------------+ + diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index f6dad255ef..ee91d86d14 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -302,3 +302,18 @@ PPL query:: | 36 | 32 | M | +------------+------------+----------+ +Example 7: Calculate the distinct count of a field +================================================== + +To get the count of distinct values of a field, you can use ``DISTINCT_COUNT`` (or ``DC``) function instead of ``COUNT``. The example calculates both the count and the distinct count of gender field of all the accounts. + +PPL query:: + + os> source=accounts | stats count(gender), distinct_count(gender); + fetched rows / total rows = 1/1 + +-----------------+--------------------------+ + | count(gender) | distinct_count(gender) | + |-----------------+--------------------------| + | 4 | 2 | + +-----------------+--------------------------+ + diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java index ff3ad2a6c8..4a9603fe6b 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java @@ -77,6 +77,19 @@ public void testStatsCountAll() throws IOException { verifyDataRows(response, rows(1000)); } + @Test + public void testStatsDistinctCount() throws IOException { + JSONObject response = + executeQuery(String.format("source=%s | stats distinct_count(gender)", TEST_INDEX_ACCOUNT)); + verifySchema(response, schema("distinct_count(gender)", null, "integer")); + verifyDataRows(response, rows(2)); + + response = + executeQuery(String.format("source=%s | stats dc(age)", TEST_INDEX_ACCOUNT)); + verifySchema(response, schema("dc(age)", null, "integer")); + verifyDataRows(response, rows(21)); + } + @Test public void testStatsMin() throws IOException { JSONObject response = executeQuery(String.format( diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java index 3cbb222afe..33cddc6f1f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java @@ -30,7 +30,15 @@ protected void init() throws Exception { } @Test - void filteredAggregateWithSubquery() throws IOException { + void filteredAggregatePushedDown() throws IOException { + JSONObject response = executeQuery( + "SELECT COUNT(*) FILTER(WHERE age > 35) FROM " + TEST_INDEX_BANK); + verifySchema(response, schema("COUNT(*)", null, "integer")); + verifyDataRows(response, rows(3)); + } + + @Test + void filteredAggregateNotPushedDown() throws IOException { JSONObject response = executeQuery( "SELECT COUNT(*) FILTER(WHERE age > 35) FROM (SELECT * FROM " + TEST_INDEX_BANK + ") AS a"); diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java index b92ca17238..52373a72e3 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java @@ -29,6 +29,8 @@ import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRowsInOrder; + import org.json.JSONObject; import org.junit.Test; @@ -40,6 +42,7 @@ public class WindowFunctionIT extends SQLIntegTestCase { @Override protected void init() throws Exception { loadIndex(Index.BANK_WITH_NULL_VALUES); + loadIndex(Index.BANK); } @Test @@ -74,4 +77,49 @@ public void testOrderByNullLast() { rows(null, 7)); } + @Test + public void testDistinctCountOverNull() { + JSONObject response = new JSONObject(executeQuery( + "SELECT lastname, COUNT(DISTINCT gender) OVER() " + + "FROM " + TestsConstants.TEST_INDEX_BANK, "jdbc")); + verifyDataRows(response, + rows("Duke Willmington", 2), + rows("Bond", 2), + rows("Bates", 2), + rows("Adams", 2), + rows("Ratliff", 2), + rows("Ayala", 2), + rows("Mcpherson", 2)); + } + + @Test + public void testDistinctCountOver() { + JSONObject response = new JSONObject(executeQuery( + "SELECT lastname, COUNT(DISTINCT gender) OVER(ORDER BY lastname) " + + "FROM " + TestsConstants.TEST_INDEX_BANK, "jdbc")); + verifyDataRowsInOrder(response, + rows("Adams", 1), + rows("Ayala", 2), + rows("Bates", 2), + rows("Bond", 2), + rows("Duke Willmington", 2), + rows("Mcpherson", 2), + rows("Ratliff", 2)); + } + + @Test + public void testDistinctCountPartition() { + JSONObject response = new JSONObject(executeQuery( + "SELECT lastname, COUNT(DISTINCT gender) OVER(PARTITION BY gender ORDER BY lastname) " + + "FROM " + TestsConstants.TEST_INDEX_BANK, "jdbc")); + verifyDataRowsInOrder(response, + rows("Ayala", 1), + rows("Bates", 1), + rows("Mcpherson", 1), + rows("Adams", 1), + rows("Bond", 1), + rows("Duke Willmington", 1), + rows("Ratliff", 1)); + } + } diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index 45aa658783..0c0648a937 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -9,4 +9,6 @@ SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights SELECT VAR_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT VAR_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights -SELECT STDDEV_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file +SELECT STDDEV_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights +SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights +SELECT COUNT(DISTINCT Origin) FROM (SELECT * FROM opensearch_dashboards_sample_data_flights) AS flights \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java index 73d58d793e..cd793c9046 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java @@ -43,11 +43,9 @@ /** * Abstract Aggregation Builder. - * - * @param type of the actual AggregationBuilder to be built. */ @RequiredArgsConstructor -public class AggregationBuilderHelper { +public class AggregationBuilderHelper { private final ExpressionSerializer serializer; @@ -57,7 +55,7 @@ public class AggregationBuilderHelper { * @param expression Expression * @return AggregationBuilder */ - public T build(Expression expression, Function fieldBuilder, + public T build(Expression expression, Function fieldBuilder, Function scriptBuilder) { if (expression instanceof ReferenceExpression) { String fieldName = ((ReferenceExpression) expression).getAttr(); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java index b1aff2c5b4..d137cce75d 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java @@ -42,11 +42,11 @@ */ public class BucketAggregationBuilder { - private final AggregationBuilderHelper> helper; + private final AggregationBuilderHelper helper; public BucketAggregationBuilder( ExpressionSerializer serializer) { - this.helper = new AggregationBuilderHelper<>(serializer); + this.helper = new AggregationBuilderHelper(serializer); } /** @@ -62,8 +62,8 @@ public List> build( .missingBucket(true) .order(groupPair.getRight()); resultBuilder - .add(helper.build(groupPair.getLeft().getDelegated(), valuesSourceBuilder::field, - valuesSourceBuilder::script)); + .add((CompositeValuesSourceBuilder) helper.build(groupPair.getLeft().getDelegated(), + valuesSourceBuilder::field, valuesSourceBuilder::script)); } return resultBuilder.build(); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 3d40258288..754da49862 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -32,11 +32,13 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; +import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; import org.opensearch.search.aggregations.metrics.ExtendedStats; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.sql.expression.Expression; @@ -57,11 +59,14 @@ public class MetricAggregationBuilder extends ExpressionNodeVisitor, Object> { - private final AggregationBuilderHelper> helper; + private final AggregationBuilderHelper helper; private final FilterQueryBuilder filterBuilder; + /** + * Constructor. + */ public MetricAggregationBuilder(ExpressionSerializer serializer) { - this.helper = new AggregationBuilderHelper<>(serializer); + this.helper = new AggregationBuilderHelper(serializer); this.filterBuilder = new FilterQueryBuilder(serializer); } @@ -88,9 +93,26 @@ public Pair visitNamedAggregator( NamedAggregator node, Object context) { Expression expression = node.getArguments().get(0); Expression condition = node.getDelegated().condition(); + Boolean distinct = node.getDelegated().distinct(); String name = node.getName(); + String functionName = node.getFunctionName().getFunctionName().toLowerCase(Locale.ROOT); + + if (distinct) { + switch (functionName) { + case "count": + return make( + AggregationBuilders.cardinality(name), + expression, + condition, + name, + new SingleValueParser(name)); + default: + throw new IllegalStateException(String.format( + "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); + } + } - switch (node.getFunctionName().getFunctionName()) { + switch (functionName) { case "avg": return make( AggregationBuilders.avg(name), @@ -176,6 +198,24 @@ private Pair make( return Pair.of(aggregationBuilder, parser); } + /** + * Make {@link CardinalityAggregationBuilder} for distinct count aggregations. + */ + private Pair make(CardinalityAggregationBuilder builder, + Expression expression, + Expression condition, + String name, + MetricParser parser) { + CardinalityAggregationBuilder aggregationBuilder = + helper.build(expression, builder::field, builder::script); + if (condition != null) { + return Pair.of( + makeFilterAggregation(aggregationBuilder, condition, name), + FilterParser.builder().name(name).metricsParser(parser).build()); + } + return Pair.of(aggregationBuilder, parser); + } + /** * Replace star or literal with OpenSearch metadata field "_index". Because: 1) Analyzer already * converts * to string literal, literal check here can handle both COUNT(*) and COUNT(1). 2) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index 95a2383475..129814d45f 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -32,6 +32,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; @@ -42,6 +43,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import java.util.Arrays; +import java.util.Collections; import java.util.List; import lombok.SneakyThrows; import org.junit.jupiter.api.BeforeEach; @@ -51,19 +53,21 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.aggregation.AvgAggregator; import org.opensearch.sql.expression.aggregation.CountAggregator; import org.opensearch.sql.expression.aggregation.MaxAggregator; import org.opensearch.sql.expression.aggregation.MinAggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.aggregation.SumAggregator; -import org.opensearch.sql.expression.aggregation.VarianceAggregator; +import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) class MetricAggregationBuilderTest { + private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); @Mock private ExpressionSerializer serializer; @@ -258,6 +262,61 @@ void should_build_stddevSamp_aggregation() { stddevSample(Arrays.asList(ref("age", INTEGER)), INTEGER))))); } + @Test + void should_build_cardinality_aggregation() { + assertEquals( + "{\n" + + " \"count(distinct name)\" : {\n" + + " \"cardinality\" : {\n" + + " \"field\" : \"name\"\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Collections.singletonList(named("count(distinct name)", new CountAggregator( + Collections.singletonList(ref("name", STRING)), INTEGER).distinct(true))))); + } + + @Test + void should_build_filtered_cardinality_aggregation() { + assertEquals( + "{\n" + + " \"count(distinct name) filter(where age > 30)\" : {\n" + + " \"filter\" : {\n" + + " \"range\" : {\n" + + " \"age\" : {\n" + + " \"from\" : 30,\n" + + " \"to\" : null,\n" + + " \"include_lower\" : false,\n" + + " \"include_upper\" : true,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + " },\n" + + " \"aggregations\" : {\n" + + " \"count(distinct name) filter(where age > 30)\" : {\n" + + " \"cardinality\" : {\n" + + " \"field\" : \"name\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}", + buildQuery(Collections.singletonList(named( + "count(distinct name) filter(where age > 30)", + new CountAggregator(Collections.singletonList(ref("name", STRING)), INTEGER) + .condition(dsl.greater(ref("age", INTEGER), literal(30))) + .distinct(true))))); + } + + @Test + void should_throw_exception_for_unsupported_distinct_aggregator() { + assertThrows(IllegalStateException.class, + () -> buildQuery(Collections.singletonList(named("avg(distinct age)", new AvgAggregator( + Collections.singletonList(ref("name", STRING)), STRING).distinct(true)))), + "unsupported distinct aggregator avg"); + } + @Test void should_throw_exception_for_unsupported_aggregator() { when(aggregator.getFunctionName()).thenReturn(new FunctionName("unsupported_agg")); diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index d552ad0756..fa3c5b64b7 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -135,6 +135,7 @@ statsAggTerm statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS #statsFunctionCall | COUNT LT_PRTHS RT_PRTHS #countAllFunctionCall + | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS #distinctCountFunctionCall | percentileAggFunction #percentileAggFunctionCall ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 9fdf8d636d..7da4f90cf0 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -35,6 +35,7 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldExpressionContext; @@ -203,6 +204,11 @@ public UnresolvedExpression visitCountAllFunctionCall(CountAllFunctionCallContex return new AggregateFunction("count", AllFields.of()); } + @Override + public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { + return new AggregateFunction("count", visit(ctx.valueExpression()), true); + } + @Override public UnresolvedExpression visitPercentileAggFunction(PercentileAggFunctionContext ctx) { return new AggregateFunction(ctx.PERCENTILE().getText(), visit(ctx.aggField), diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index 71ef692abf..bfb975ba74 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -37,6 +37,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.defaultFieldsArgs; import static org.opensearch.sql.ast.dsl.AstDSL.defaultSortFieldArgs; import static org.opensearch.sql.ast.dsl.AstDSL.defaultStatsArgs; +import static org.opensearch.sql.ast.dsl.AstDSL.distinctAggregate; import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.equalTo; import static org.opensearch.sql.ast.dsl.AstDSL.eval; @@ -460,6 +461,19 @@ public void testCountFuncCallExpr() { )); } + @Test + public void testDistinctCount() { + assertEqual("source=t | stats distinct_count(a)", + agg( + relation("t"), + exprList( + alias("distinct_count(a)", + distinctAggregate("count", field("a")))), + emptyList(), + emptyList(), + defaultStatsArgs())); + } + @Test public void testEvalFuncCallExpr() { assertEqual("source=t | eval f=abs(a)", diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index fe5526621f..61d2f5990f 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -336,8 +336,10 @@ caseFuncAlternative ; aggregateFunction - : functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET #regularAggregateFunctionCall - | COUNT LR_BRACKET STAR RR_BRACKET #countStarFunctionCall + : functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET + #regularAggregateFunctionCall + | COUNT LR_BRACKET STAR RR_BRACKET #countStarFunctionCall + | COUNT LR_BRACKET DISTINCT functionArg RR_BRACKET #distinctCountFunctionCall ; filterClause diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index b1630aed50..8dda63b750 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -43,6 +43,7 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.CountStarFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DataTypeFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DateLiteralContext; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DistinctCountFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IsNullPredicateContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.LikePredicateContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MathExpressionAtomContext; @@ -171,7 +172,7 @@ public UnresolvedExpression visitShowDescribePattern( public UnresolvedExpression visitFilteredAggregationFunctionCall( OpenSearchSQLParser.FilteredAggregationFunctionCallContext ctx) { AggregateFunction agg = (AggregateFunction) visit(ctx.aggregateFunction()); - return new AggregateFunction(agg.getFuncName(), agg.getField(), visit(ctx.filterClause())); + return agg.condition(visit(ctx.filterClause())); } @Override @@ -212,6 +213,14 @@ public UnresolvedExpression visitRegularAggregateFunctionCall( visitFunctionArg(ctx.functionArg())); } + @Override + public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { + return new AggregateFunction( + ctx.COUNT().getText(), + visitFunctionArg(ctx.functionArg()), + true); + } + @Override public UnresolvedExpression visitCountStarFunctionCall(CountStarFunctionCallContext ctx) { return new AggregateFunction("COUNT", AllFields.of()); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java index 1d9516f816..44c84495c2 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java @@ -36,6 +36,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.ast.dsl.AstDSL.aggregate; import static org.opensearch.sql.ast.dsl.AstDSL.alias; +import static org.opensearch.sql.ast.dsl.AstDSL.distinctAggregate; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; @@ -50,6 +51,7 @@ import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.UnresolvedPlan; @@ -167,6 +169,17 @@ void can_build_implicit_group_by_for_aggregator_in_having_clause() { alias("AVG(age)", aggregate("AVG", qualifiedName("age")))))); } + @Test + void can_build_distinct_aggregator() { + assertThat( + buildAggregation("SELECT COUNT(DISTINCT name) FROM test group by age"), + allOf( + hasGroupByItems(alias("age", qualifiedName("age"))), + hasAggregators( + alias("COUNT(DISTINCT name)", distinctAggregate("COUNT", qualifiedName( + "name")))))); + } + @Test void should_build_nothing_if_no_group_by_and_no_aggregators_in_select() { assertNull(buildAggregation("SELECT name FROM test")); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index e4e8028f05..e101eb9404 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -431,6 +431,23 @@ public void canBuildVariance() { buildExprAst("variance(age)")); } + @Test + public void distinctCount() { + assertEquals( + AstDSL.distinctAggregate("count", qualifiedName("name")), + buildExprAst("count(distinct name)") + ); + } + + @Test + public void filteredDistinctCount() { + assertEquals( + AstDSL.filteredDistinctCount("count", qualifiedName("name"), function( + ">", qualifiedName("age"), intLiteral(30))), + buildExprAst("count(distinct name) filter(where age > 30)") + ); + } + private Node buildExprAst(String expr) { OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(expr)); OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer));