Support distinct count aggregation #167

Merged
merged 26 commits into from
Jul 29, 2021
Changes from 25 commits
Commits
b39e7b6
Support construct AggregationResponseParser during Aggregator build s…
penghuo Jun 8, 2021
6f5350d
support distinct count aggregation
chloe-zh Jun 9, 2021
e30b685
fixed tests
chloe-zh Jun 9, 2021
43dad8d
Merge remote-tracking branch 'upstream/develop' into issue/#100
chloe-zh Jun 9, 2021
866d71d
Merge remote-tracking branch 'upstream/develop' into issue/#100
chloe-zh Jun 9, 2021
8a6ca20
update
chloe-zh Jun 9, 2021
43cbd17
updated user doc
chloe-zh Jun 9, 2021
392c96c
Update: support only count for distinct aggregations
chloe-zh Jun 11, 2021
078eae7
Update doc; removed distinct start
chloe-zh Jun 11, 2021
f10b282
Removed unnecessary methods
chloe-zh Jun 11, 2021
df81cfa
update
chloe-zh Jun 11, 2021
8632f80
Impl stddev and variance function in SQL and PPL (#115)
penghuo Jun 11, 2021
9ff2793
Fix the aggregation filter missing in named aggregators (#123)
chloe-zh Jun 11, 2021
5b02a43
Merge remote-tracking branch 'upstream/develop' into issue/#100
chloe-zh Jun 11, 2021
94a045f
update
chloe-zh Jun 11, 2021
11a9758
modified comparison test
chloe-zh Jun 14, 2021
d5dc9eb
removed a comparison test and added it to aggregationIT
chloe-zh Jun 15, 2021
684a742
added ppl IT test cases; added window function test cases
chloe-zh Jun 15, 2021
c750f59
moved distinct window function test cases to WindowsIT
chloe-zh Jun 16, 2021
9fa771d
added ut
chloe-zh Jun 16, 2021
5d42554
update
chloe-zh Jun 16, 2021
80a4c61
update
chloe-zh Jun 17, 2021
df04751
Merge remote-tracking branch 'upstream/main' into issue/#100
chloe-zh Jul 23, 2021
86db16d
addressed comments
chloe-zh Jul 26, 2021
f5cece5
added test cases to meet the coverage requirement
chloe-zh Jul 26, 2021
832e019
added test cases for distinct count map and array types
chloe-zh Jul 27, 2021
@@ -161,8 +161,9 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
       Expression arg = node.getField().accept(this, context);
       Aggregator aggregator = (Aggregator) repository.compile(
           builtinFunctionName.get().getName(), Collections.singletonList(arg));
-      if (node.getCondition() != null) {
-        aggregator.condition(analyze(node.getCondition(), context));
+      aggregator.distinct(node.getDistinct());
+      if (node.condition() != null) {
+        aggregator.condition(analyze(node.condition(), context));
       }
       return aggregator;
     } else {
11 changes: 10 additions & 1 deletion core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
@@ -211,7 +211,16 @@ public static UnresolvedExpression aggregate(

   public static UnresolvedExpression filteredAggregate(
       String func, UnresolvedExpression field, UnresolvedExpression condition) {
-    return new AggregateFunction(func, field, condition);
+    return new AggregateFunction(func, field).condition(condition);
   }

+  public static UnresolvedExpression distinctAggregate(String func, UnresolvedExpression field) {
+    return new AggregateFunction(func, field, true);
+  }
+
+  public static UnresolvedExpression filteredDistinctCount(
+      String func, UnresolvedExpression field, UnresolvedExpression condition) {
+    return new AggregateFunction(func, field, true).condition(condition);
+  }
+
   public static Function function(String funcName, UnresolvedExpression... funcArgs) {
@@ -28,9 +28,12 @@

 import java.util.Collections;
 import java.util.List;
+import javax.annotation.Nullable;
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
 import lombok.RequiredArgsConstructor;
+import lombok.Setter;
+import lombok.experimental.Accessors;
 import org.opensearch.sql.ast.AbstractNodeVisitor;
 import org.opensearch.sql.common.utils.StringUtils;

@@ -45,7 +48,10 @@ public class AggregateFunction extends UnresolvedExpression {
   private final String funcName;
   private final UnresolvedExpression field;
   private final List<UnresolvedExpression> argList;
+  @Setter
+  @Accessors(fluent = true)
   private UnresolvedExpression condition;
+  private Boolean distinct = false;

   /**
    * Constructor.
@@ -62,14 +68,13 @@ public AggregateFunction(String funcName, UnresolvedExpression field) {
    * Constructor.
    * @param funcName function name.
    * @param field {@link UnresolvedExpression}.
-   * @param condition condition in aggregation filter.
+   * @param distinct whether distinct field is specified or not.
    */
-  public AggregateFunction(String funcName, UnresolvedExpression field,
-                           UnresolvedExpression condition) {
+  public AggregateFunction(String funcName, UnresolvedExpression field, Boolean distinct) {
     this.funcName = funcName;
     this.field = field;
     this.argList = Collections.emptyList();
-    this.condition = condition;
+    this.distinct = distinct;
   }

   @Override
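Review note: the filter condition moves from a constructor argument to a fluent setter, which is why `AstDSL` can now write `new AggregateFunction(func, field).condition(condition)`. The following is a hand-written sketch of what Lombok's `@Setter` with `@Accessors(fluent = true)` is assumed to generate for such a field. `FluentNode` and its `String` fields are hypothetical stand-ins for illustration, not code from this PR.

```java
// Hypothetical stand-in for AggregateFunction; plain Java, no Lombok required.
public class FluentNode {
  private final String funcName;
  private String condition;          // stands in for UnresolvedExpression
  private Boolean distinct = false;

  public FluentNode(String funcName, Boolean distinct) {
    this.funcName = funcName;
    this.distinct = distinct;
  }

  // Fluent setter: named after the field and returns `this` so calls can be chained.
  public FluentNode condition(String condition) {
    this.condition = condition;
    return this;
  }

  // Fluent getter: same name as the field, no "get" prefix.
  public String condition() {
    return condition;
  }

  public String getFuncName() {
    return funcName;
  }

  public Boolean getDistinct() {
    return distinct;
  }

  public static void main(String[] args) {
    FluentNode node = new FluentNode("count", true).condition("age > 30");
    System.out.println(node.getFuncName() + " / " + node.condition()
        + " / distinct=" + node.getDistinct());
  }
}
```

Because the setter returns the node itself, the filtered and distinct variants can share one constructor and differ only in the chained call.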
@@ -157,6 +157,14 @@ public static ExprValue fromObjectValue(Object o, ExprCoreType type) {
     }
   }

+  public static Byte getByteValue(ExprValue exprValue) {
+    return exprValue.byteValue();
+  }
+
+  public static Short getShortValue(ExprValue exprValue) {
+    return exprValue.shortValue();
+  }
+
   public static Integer getIntegerValue(ExprValue exprValue) {
     return exprValue.integerValue();
   }
4 changes: 4 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
@@ -500,6 +500,10 @@ public Aggregator count(Expression... expressions) {
     return aggregate(BuiltinFunctionName.COUNT, expressions);
   }

+  public Aggregator distinctCount(Expression... expressions) {
+    return count(expressions).distinct(true);
+  }
+
   public Aggregator varSamp(Expression... expressions) {
     return aggregate(BuiltinFunctionName.VARSAMP, expressions);
   }
@@ -64,6 +64,10 @@ public abstract class Aggregator<S extends AggregationState>
   @Getter
   @Accessors(fluent = true)
   protected Expression condition;
+  @Setter
+  @Getter
+  @Accessors(fluent = true)
+  protected Boolean distinct = false;

   /**
    * Create an {@link AggregationState} which will be used for aggregation.
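Review note: `distinct` is declared as a boxed `Boolean`, and `CountAggregator` later uses it directly in a conditional expression (`distinct ? ... : ...`). A small sketch of why the boxed type deserves care, assuming some caller could pass `null`. Nothing shown in this PR does so, and `describe` and `describeSafely` are hypothetical helpers.

```java
public class BoxedFlagSketch {
  // Mirrors the shape of `distinct ? "count(distinct x)" : "count(x)"`.
  public static String describe(Boolean distinct) {
    return distinct ? "count(distinct x)" : "count(x)";  // auto-unboxes `distinct`
  }

  // Null-safe variant: treats null as "not distinct".
  public static String describeSafely(Boolean distinct) {
    return Boolean.TRUE.equals(distinct) ? "count(distinct x)" : "count(x)";
  }

  public static void main(String[] args) {
    System.out.println(describe(true));
    System.out.println(describeSafely(null));
    try {
      describe(null);
    } catch (NullPointerException e) {
      System.out.println("unboxing a null Boolean threw NullPointerException");
    }
  }
}
```

A primitive `boolean` field would rule the null case out at compile time. Whether that fits the Lombok accessors used here is a judgment for the maintainers.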
@@ -28,8 +28,10 @@

 import static org.opensearch.sql.utils.ExpressionUtils.format;

+import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
+import java.util.Set;
 import org.opensearch.sql.data.model.ExprValue;
 import org.opensearch.sql.data.model.ExprValueUtils;
 import org.opensearch.sql.data.type.ExprCoreType;
@@ -45,33 +47,51 @@ public CountAggregator(List<Expression> arguments, ExprCoreType returnType) {

   @Override
   public CountAggregator.CountState create() {
-    return new CountState();
+    return distinct ? new DistinctCountState() : new CountState();
   }

   @Override
   protected CountState iterate(ExprValue value, CountState state) {
-    state.count++;
+    state.count(value);
     return state;
   }

   @Override
   public String toString() {
-    return String.format(Locale.ROOT, "count(%s)", format(getArguments()));
+    return distinct
+        ? String.format(Locale.ROOT, "count(distinct %s)", format(getArguments()))
+        : String.format(Locale.ROOT, "count(%s)", format(getArguments()));
   }

   /**
    * Count State.
    */
   protected static class CountState implements AggregationState {
-    private int count;
+    protected int count;

     CountState() {
       this.count = 0;
     }

+    public void count(ExprValue value) {
+      count++;
+    }
+
     @Override
     public ExprValue result() {
       return ExprValueUtils.integerValue(count);
     }
   }
+
+  protected static class DistinctCountState extends CountState {
+    private final Set<ExprValue> distinctValues = new HashSet<>();
+
+    @Override
+    public void count(ExprValue value) {
+      if (!distinctValues.contains(value)) {
+        distinctValues.add(value);
+        count++;
+      }
+    }
+  }
 }
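Review note: the hunk above swaps the aggregation state depending on the `distinct` flag, and the distinct state remembers every value it has seen. A minimal standalone sketch of the same idea, using `Object` in place of `ExprValue` and the hypothetical class name `DistinctCountSketch` so that it runs without the project on the classpath:

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class DistinctCountSketch {
  /** Plain counter: every value increments the count. */
  static class CountState {
    protected int count = 0;

    public void count(Object value) {
      count++;
    }

    public int result() {
      return count;
    }
  }

  /** Distinct counter: only values not seen before increment the count. */
  static class DistinctCountState extends CountState {
    private final Set<Object> seen = new HashSet<>();

    @Override
    public void count(Object value) {
      // Set.add returns false when the value is already present,
      // so a separate contains() check is not needed.
      if (seen.add(value)) {
        count++;
      }
    }
  }

  /** Picks the state from the flag, then feeds every value through it. */
  public static int aggregate(List<?> values, boolean distinct) {
    CountState state = distinct ? new DistinctCountState() : new CountState();
    for (Object value : values) {
      state.count(value);
    }
    return state.result();
  }

  public static void main(String[] args) {
    // Same integer_value column as tuples_with_duplicates in the tests: 1, 1, 2, 3.
    List<Integer> values = Arrays.asList(1, 1, 2, 3);
    System.out.println(aggregate(values, false));  // prints 4
    System.out.println(aggregate(values, true));   // prints 3
  }
}
```

Two properties of this approach are worth keeping in mind: the set grows with the number of distinct values, and correctness depends on the counted type implementing `equals` and `hashCode` consistently.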
@@ -54,8 +54,8 @@ public class NamedAggregator extends Aggregator<AggregationState> {

   /**
    * NamedAggregator.
-   * The aggregator properties {@link #condition} is inherited by named aggregator
-   * to avoid errors introduced by the property inconsistency.
+   * The aggregator properties {@link #condition} and {@link #distinct}
+   * are inherited by named aggregator to avoid errors introduced by the property inconsistency.
    *
    * @param name name
    * @param delegated delegated
@@ -67,6 +67,7 @@ public NamedAggregator(
     this.name = name;
     this.delegated = delegated;
     this.condition = delegated.condition;
+    this.distinct = delegated.distinct;
   }

   @Override
@@ -300,6 +300,24 @@ public void variance_mapto_varPop() {
     );
   }

+  @Test
+  public void distinct_count() {
+    assertAnalyzeEqual(
+        dsl.distinctCount(DSL.ref("integer_value", INTEGER)),
+        AstDSL.distinctAggregate("count", qualifiedName("integer_value"))
+    );
+  }
+
+  @Test
+  public void filtered_distinct_count() {
+    assertAnalyzeEqual(
+        dsl.distinctCount(DSL.ref("integer_value", INTEGER))
+            .condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
+        AstDSL.filteredDistinctCount("count", qualifiedName("integer_value"), function(
+            ">", qualifiedName("integer_value"), intLiteral(1)))
+    );
+  }
+
   protected Expression analyze(UnresolvedExpression unresolvedExpression) {
     return expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
   }
@@ -96,8 +96,8 @@ public class ExprValueUtilsTest {
       Lists.newArrayList(Iterables.concat(numberValues, nonNumberValues));

   private static List<Function<ExprValue, Object>> numberValueExtractor = Arrays.asList(
-      ExprValue::byteValue,
-      ExprValue::shortValue,
+      ExprValueUtils::getByteValue,
+      ExprValueUtils::getShortValue,
       ExprValueUtils::getIntegerValue,
       ExprValueUtils::getLongValue,
       ExprValueUtils::getFloatValue,
@@ -116,6 +116,13 @@ public class AggregationTest extends ExpressionTestBase {
               "timestamp_value",
               "2040-01-01 07:00:00")));

+  protected static List<ExprValue> tuples_with_duplicates =
+      Arrays.asList(
+          ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, "double_value", 4d)),
+          ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, "double_value", 3d)),
+          ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2, "double_value", 2d)),
+          ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3, "double_value", 1d)));
+
   protected static List<ExprValue> tuples_with_null_and_missing =
       Arrays.asList(
           ExprValueUtils.tupleValue(
@@ -129,6 +129,21 @@ public void filtered_count() {
     assertEquals(3, result.value());
   }

+  @Test
+  public void distinct_count() {
+    ExprValue result = aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER)),
+        tuples_with_duplicates);
+    assertEquals(3, result.value());
+  }
+
+  @Test
+  public void filtered_distinct_count() {
+    ExprValue result = aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER))
+        .condition(dsl.greater(DSL.ref("double_value", DOUBLE), DSL.literal(1d))),
+        tuples_with_duplicates);
+    assertEquals(2, result.value());
+  }
+
   @Test
   public void count_with_missing() {
     ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)),
@@ -166,6 +181,9 @@ public void valueOf() {
   public void test_to_string() {
     Aggregator countAggregator = dsl.count(DSL.ref("integer_value", INTEGER));
     assertEquals("count(integer_value)", countAggregator.toString());
+
+    countAggregator = dsl.distinctCount(DSL.ref("integer_value", INTEGER));
+    assertEquals("count(distinct integer_value)", countAggregator.toString());
   }

   @Test
26 changes: 26 additions & 0 deletions docs/user/dql/aggregations.rst
@@ -357,6 +357,19 @@ Example::
     | 2.8613807855648994 |
     +--------------------+

+DISTINCT COUNT Aggregation
+--------------------------
+
+To get the count of distinct values of a field, add the keyword ``DISTINCT`` before the field in the count aggregation. Example::
+
+    os> SELECT COUNT(DISTINCT gender), COUNT(gender) FROM accounts;
+    fetched rows / total rows = 1/1
+    +--------------------------+-----------------+
+    | COUNT(DISTINCT gender)   | COUNT(gender)   |
+    |--------------------------+-----------------|
+    | 2                        | 4               |
+    +--------------------------+-----------------+
+
 HAVING Clause
 =============
@@ -456,3 +469,16 @@ The ``FILTER`` clause can be used in aggregation functions without GROUP BY as w
     | 4            | 1          |
     +--------------+------------+

+Distinct count aggregate with FILTER
+------------------------------------
+
+The ``FILTER`` clause can also be used with a distinct count to filter rows before counting the distinct values of a field. For example::
+
+    os> SELECT COUNT(DISTINCT firstname) FILTER(WHERE age > 30) AS distinct_count FROM accounts
+    fetched rows / total rows = 1/1
+    +------------------+
+    | distinct_count   |
+    |------------------|
+    | 3                |
+    +------------------+
+
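Review note: the documented order of operations (filter rows, then count distinct values) can be sketched in memory as follows. This only illustrates the intended semantics under stated assumptions and is not how the plugin evaluates the query. The `Account` type, the method name, and the sample rows are all hypothetical, and the rows are not the contents of the `accounts` index.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

public class FilteredDistinctCountSketch {
  /** Hypothetical row type holding only the two columns the query touches. */
  public static final class Account {
    final String firstname;
    final int age;

    public Account(String firstname, int age) {
      this.firstname = firstname;
      this.age = age;
    }
  }

  /** COUNT(DISTINCT firstname) FILTER(WHERE age > minAge), evaluated in memory. */
  public static long distinctFirstnamesOlderThan(List<Account> rows, int minAge) {
    return rows.stream()
        .filter(row -> row.age > minAge)   // FILTER(WHERE age > minAge) is applied first
        .map(row -> row.firstname)         // project the aggregated field
        .filter(Objects::nonNull)          // standard SQL COUNT ignores NULLs
        .distinct()                        // DISTINCT
        .count();                          // COUNT
  }

  public static void main(String[] args) {
    // Made-up sample rows for illustration only.
    List<Account> rows = Arrays.asList(
        new Account("Ann", 25),
        new Account("Bob", 41),
        new Account("Bob", 35),
        new Account("Cy", 52));
    System.out.println(distinctFirstnamesOlderThan(rows, 30));  // prints 2
  }
}
```

Reversing the first and fourth steps would change the answer whenever a value appears on both sides of the filter, which is why the documentation states that filtering happens before the distinct count.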
15 changes: 15 additions & 0 deletions docs/user/ppl/cmd/stats.rst
@@ -302,3 +302,18 @@ PPL query::
     | 36         | 32         | M        |
     +------------+------------+----------+

+Example 7: Calculate the distinct count of a field
+==================================================
+
+To get the count of distinct values of a field, you can use the ``DISTINCT_COUNT`` (or ``DC``) function instead of ``COUNT``. The example calculates both the count and the distinct count of the gender field across all accounts.
+
+PPL query::
+
+    os> source=accounts | stats count(gender), distinct_count(gender);
+    fetched rows / total rows = 1/1
+    +-----------------+--------------------------+
+    | count(gender)   | distinct_count(gender)   |
+    |-----------------+--------------------------|
+    | 4               | 2                        |
+    +-----------------+--------------------------+
+
@@ -77,6 +77,19 @@ public void testStatsCountAll() throws IOException {
     verifyDataRows(response, rows(1000));
   }

+  @Test
+  public void testStatsDistinctCount() throws IOException {
+    JSONObject response =
+        executeQuery(String.format("source=%s | stats distinct_count(gender)", TEST_INDEX_ACCOUNT));
+    verifySchema(response, schema("distinct_count(gender)", null, "integer"));
+    verifyDataRows(response, rows(2));
+
+    response =
+        executeQuery(String.format("source=%s | stats dc(age)", TEST_INDEX_ACCOUNT));
+    verifySchema(response, schema("dc(age)", null, "integer"));
+    verifyDataRows(response, rows(21));
+  }
+
   @Test
   public void testStatsMin() throws IOException {
     JSONObject response = executeQuery(String.format(
@@ -30,7 +30,15 @@ protected void init() throws Exception {
   }

   @Test
-  void filteredAggregateWithSubquery() throws IOException {
+  void filteredAggregatePushedDown() throws IOException {
+    JSONObject response = executeQuery(
+        "SELECT COUNT(*) FILTER(WHERE age > 35) FROM " + TEST_INDEX_BANK);
+    verifySchema(response, schema("COUNT(*)", null, "integer"));
+    verifyDataRows(response, rows(3));
+  }
+
+  @Test
+  void filteredAggregateNotPushedDown() throws IOException {
     JSONObject response = executeQuery(
         "SELECT COUNT(*) FILTER(WHERE age > 35) FROM (SELECT * FROM " + TEST_INDEX_BANK
             + ") AS a");