Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
* Special rule for Pinot, this rule is fixed to always insert an exchange or sort exchange below the WINDOW node.
* TODO:
* 1. Add support for more than one window group
* 2. Add support for functions other than:
*    a. Aggregation functions (AVG, COUNT, MAX, MIN, SUM, BOOL_AND, BOOL_OR)
*    b. Ranking functions (ROW_NUMBER, RANK, DENSE_RANK)
* 3. Add support for custom frames
*/
public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
Expand All @@ -62,7 +64,8 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
// Supported window functions
// OTHER_FUNCTION supported are: BOOL_AND, BOOL_OR
private static final Set<SqlKind> SUPPORTED_WINDOW_FUNCTION_KIND = ImmutableSet.of(SqlKind.SUM, SqlKind.SUM0,
SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.OTHER_FUNCTION);
SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.RANK, SqlKind.DENSE_RANK,
SqlKind.OTHER_FUNCTION);

public PinotWindowExchangeNodeInsertRule(RelBuilderFactory factory) {
super(operand(LogicalWindow.class, any()), factory, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ public class WindowNode extends AbstractPlanNode {

/**
* Enum to denote the type of window frame
* ROWS - ROWS type window frame
* RANGE - RANGE type window frame
*/
public enum WindowFrameType {
  // ROWS type window frame
  ROWS,
  // RANGE type window frame
  RANGE
}

Expand Down Expand Up @@ -95,7 +95,7 @@ public WindowNode(int planFragmentId, List<Window.Group> windowGroups, List<RexL
_lowerBound = Integer.MIN_VALUE;
// Upper bound can only be unbounded following or current row for now
_upperBound = windowGroup.upperBound.isUnbounded() ? Integer.MAX_VALUE : 0;
_windowFrameType = windowGroup.isRows ? WindowFrameType.ROW : WindowFrameType.RANGE;
_windowFrameType = windowGroup.isRows ? WindowFrameType.ROWS : WindowFrameType.RANGE;

// TODO: Constants are used to store constants needed such as the frame literals. For now just save this, need to
// extract the constant values into bounds as a part of frame support.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ protected Object[][] provideQueries() {
new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2, a.col1), MIN(a.col3) OVER (ORDER BY a.col2, "
+ "a.col1) FROM a"},
new Object[]{"SELECT a.col1, ROW_NUMBER() OVER(PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT RANK() OVER(PARTITION BY a.col2 ORDER BY a.col2) FROM a"},
new Object[]{"SELECT col1, total, rank FROM (SELECT a.col1 as col1, count(*) as total, "
+ "RANK() OVER(ORDER BY count(*) DESC) AS rank FROM a GROUP BY a.col1) WHERE rank < 5"},
new Object[]{"SELECT RANK() OVER(PARTITION BY a.col2 ORDER BY a.col1) FROM a"},
new Object[]{"SELECT DENSE_RANK() OVER(ORDER BY a.col1) FROM a"},
new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2), MIN(a.col3) OVER (ORDER BY a.col2) FROM a"},
new Object[]{"SELECT /*+ skipLeafStageGroupByAggregation */ a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0"
+ " AND a.col2 = 'a' GROUP BY a.col1"},
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
*
* The window functions supported today are:
* Aggregation: SUM/COUNT/MIN/MAX/AVG/BOOL_OR/BOOL_AND aggregations [RANGE window type only]
* Ranking: ROW_NUMBER ranking functions [ROWS window type only]
* Ranking: ROW_NUMBER [ROWS window type only], RANK, DENSE_RANK [RANGE window type only] ranking functions
* Value: [none]
*
* Unlike the AggregateOperator which will output one row per group, the WindowAggregateOperator
Expand All @@ -82,6 +82,8 @@ public class WindowAggregateOperator extends MultiStageOperator {

// List of window functions which can only be applied as ROWS window frame type
private static final Set<String> ROWS_ONLY_FUNCTION_NAMES = ImmutableSet.of("ROW_NUMBER");
// List of ranking window functions whose output depends on the ordering of input rows and not on the actual values
private static final Set<String> RANKING_FUNCTION_NAMES = ImmutableSet.of("RANK", "DENSE_RANK");

private final MultiStageOperator _inputOperator;
private final List<RexExpression> _groupSet;
Expand Down Expand Up @@ -191,7 +193,7 @@ private void validateAggregationCalls(String functionName,
}

if (ROWS_ONLY_FUNCTION_NAMES.contains(functionName)) {
Preconditions.checkState(_windowFrame.getWindowFrameType() == WindowNode.WindowFrameType.ROW
Preconditions.checkState(_windowFrame.getWindowFrameType() == WindowNode.WindowFrameType.ROWS
&& _windowFrame.isUpperBoundCurrentRow(),
String.format("%s must be of ROW frame type and have CURRENT ROW as the upper bound", functionName));
} else {
Expand Down Expand Up @@ -225,12 +227,13 @@ private TransferableBlock produceWindowAggregatedBlock() {
List<Object[]> rows = new ArrayList<>(_numRows);
if (_windowFrame.getWindowFrameType() == WindowNode.WindowFrameType.RANGE) {
// All aggregation window functions only support RANGE type today (SUM/AVG/MIN/MAX/COUNT/BOOL_AND/BOOL_OR)
// RANK and DENSE_RANK ranking window functions also only support RANGE type today
for (Map.Entry<Key, List<Object[]>> e : _partitionRows.entrySet()) {
Key partitionKey = e.getKey();
List<Object[]> rowList = e.getValue();
for (Object[] existingRow : rowList) {
Object[] row = new Object[existingRow.length + _aggCalls.size()];
Key orderKey = _isPartitionByOnly ? emptyOrderKey
Key orderKey = (_isPartitionByOnly && CollectionUtils.isEmpty(_orderSetInfo.getOrderSet())) ? emptyOrderKey
: AggregationUtils.extractRowKey(existingRow, _orderSetInfo.getOrderSet());
System.arraycopy(existingRow, 0, row, 0, existingRow.length);
for (int i = 0; i < _windowAccumulators.length; i++) {
Expand Down Expand Up @@ -298,7 +301,7 @@ private boolean consumeInputBlocks() {
_partitionRows.computeIfAbsent(key, k -> new ArrayList<>()).add(row);
// Only need to accumulate the aggregate function values for RANGE type. ROW type can be calculated as
// we output the rows since the aggregation value depends on the neighboring rows.
Key orderKey = _isPartitionByOnly ? emptyOrderKey
Key orderKey = (_isPartitionByOnly && CollectionUtils.isEmpty(_orderSetInfo.getOrderSet())) ? emptyOrderKey
: AggregationUtils.extractRowKey(row, _orderSetInfo.getOrderSet());
int aggCallsSize = _aggCalls.size();
for (int i = 0; i < aggCallsSize; i++) {
Expand Down Expand Up @@ -412,14 +415,46 @@ public Long merge(Object agg, @Nullable Object value) {
}
}

// Merger implementing RANK semantics: the first row of a partition gets rank 1, and the rank
// advances by the number of duplicate ORDER BY keys accumulated when the key changes.
private static class MergeRank implements AggregationUtils.Merger {

  @Override
  public Long init(Object value, DataSchema.ColumnDataType dataType) {
    // The first row always has rank 1, regardless of its value
    return 1L;
  }

  @Override
  public Long merge(Object previousRank, Object increment) {
    // RANK always increases by the number of duplicate entries seen for the given ORDER BY key
    long base = ((Number) previousRank).longValue();
    long step = ((Number) increment).longValue();
    return base + step;
  }
}

// Merger implementing DENSE_RANK semantics: the rank advances by exactly 1 when a new ORDER BY
// key is seen, irrespective of how many duplicate keys preceded it.
private static class MergeDenseRank implements AggregationUtils.Merger {

  @Override
  public Long init(Object value, DataSchema.ColumnDataType dataType) {
    // The first row always has dense rank 1, regardless of its value
    return 1L;
  }

  @Override
  public Long merge(Object previousRank, Object changeIndicator) {
    long currentRank = ((Number) previousRank).longValue();
    // A zero indicator means the ORDER BY key is unchanged, so the dense rank stays the same;
    // otherwise DENSE_RANK increases the rank by exactly 1, irrespective of the number of
    // duplicate ORDER BY keys seen
    if (((Number) changeIndicator).longValue() == 0L) {
      return currentRank;
    }
    return currentRank + 1L;
  }
}

private static class WindowAggregateAccumulator extends AggregationUtils.Accumulator {
private static final Map<String, Function<DataSchema.ColumnDataType, AggregationUtils.Merger>> WIN_AGG_MERGERS =
ImmutableMap.<String, Function<DataSchema.ColumnDataType, AggregationUtils.Merger>>builder()
.putAll(AggregationUtils.Accumulator.MERGERS)
.put("ROW_NUMBER", cdt -> new MergeRowNumber())
.put("RANK", cdt -> new MergeRank())
.put("DENSE_RANK", cdt -> new MergeDenseRank())
.build();

private final boolean _isPartitionByOnly;
private final boolean _isRankingWindowFunction;

// Fields needed only for RANGE frame type queries (ORDER BY)
private final Map<Key, OrderKeyResult> _orderByResults = new HashMap<>();
Expand All @@ -429,6 +464,7 @@ private static class WindowAggregateAccumulator extends AggregationUtils.Accumul
DataSchema inputSchema, OrderSetInfo orderSetInfo) {
super(aggCall, merger, functionName, inputSchema);
_isPartitionByOnly = CollectionUtils.isEmpty(orderSetInfo.getOrderSet()) || orderSetInfo.isPartitionByOnly();
_isRankingWindowFunction = RANKING_FUNCTION_NAMES.contains(functionName);
}

/**
Expand All @@ -452,7 +488,8 @@ public Object computeRowResultForCurrentRow(Key currentPartitionKey, Key previou
* RANGE key and not to the row ordering. This should only be called for RANGE type queries.
*/
public void accumulateRangeResults(Key key, Key orderKey, Object[] row) {
if (_isPartitionByOnly) {
// Ranking functions don't use the row value, thus cannot reuse the AggregationUtils accumulate function for them
if (_isPartitionByOnly && !_isRankingWindowFunction) {
accumulate(key, row);
return;
}
Expand All @@ -464,23 +501,29 @@ public void accumulateRangeResults(Key key, Key orderKey, Object[] row) {
: _orderByResults.get(key).getOrderByResults().get(previousOrderKeyIfPresent);
Object value = _inputRef == -1 ? _literal : row[_inputRef];

// The ranking functions do not depend on the actual value of the data, but are calculated based on the
// position of the data ordered by the ORDER BY key. Thus they need to be handled differently and require setting
// whether the rank has changed or not and if changed then by how much.
_orderByResults.putIfAbsent(key, new OrderKeyResult());
if (currentRes == null) {
value = _isRankingWindowFunction ? 0 : value;
_orderByResults.get(key).addOrderByResult(orderKey, _merger.init(value, _dataType));
} else {
Object mergedResult;
if (orderKey.equals(previousOrderKeyIfPresent)) {
value = _isRankingWindowFunction ? 0 : value;
mergedResult = _merger.merge(currentRes, value);
} else {
Object previousValue = _orderByResults.get(key).getOrderByResults().get(previousOrderKeyIfPresent);
value = _isRankingWindowFunction ? _orderByResults.get(key).getCountOfDuplicateOrderByKeys() : value;
mergedResult = _merger.merge(previousValue, value);
}
_orderByResults.get(key).addOrderByResult(orderKey, mergedResult);
}
}

public Object getRangeResultForKeys(Key key, Key orderKey) {
if (_isPartitionByOnly) {
if (_isPartitionByOnly && !_isRankingWindowFunction) {
return _results.get(key);
} else {
return _orderByResults.get(key).getOrderByResults().get(orderKey);
Expand All @@ -494,16 +537,21 @@ public Map<Key, OrderKeyResult> getRangeOrderByResults() {
static class OrderKeyResult {
final Map<Key, Object> _orderByResults;
Key _previousOrderByKey;
// Store the counts of duplicate ORDER BY keys seen for this PARTITION BY key for calculating RANK/DENSE_RANK
long _countOfDuplicateOrderByKeys;

OrderKeyResult() {
_orderByResults = new HashMap<>();
_previousOrderByKey = null;
_countOfDuplicateOrderByKeys = 0;
}

public void addOrderByResult(Key orderByKey, Object value) {
// We expect to get the rows in order based on the ORDER BY key so it is safe to blindly assign the
// current key as the previous key
_orderByResults.put(orderByKey, value);
_countOfDuplicateOrderByKeys = (_previousOrderByKey != null && _previousOrderByKey.equals(orderByKey))
? _countOfDuplicateOrderByKeys + 1 : 1;
_previousOrderByKey = orderByKey;
}

Expand All @@ -514,6 +562,10 @@ public Map<Key, Object> getOrderByResults() {
public Key getPreviousOrderByKey() {
return _previousOrderByKey;
}

public long getCountOfDuplicateOrderByKeys() {
return _countOfDuplicateOrderByKeys;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,12 @@ public void testShouldThrowOnUnknownAggFunction() {
}

@Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = ".*Unexpected aggregation "
+ "function name: RANK.*")
+ "function name: NTILE.*")
public void testShouldThrowOnUnknownRankAggFunction() {
// TODO: Remove this test when support is added for RANK functions
// TODO: Remove this test when support is added for NTILE function
// Given:
List<RexExpression> calls = ImmutableList.of(
new RexExpression.FunctionCall(SqlKind.RANK, FieldSpec.DataType.INT, "RANK", ImmutableList.of()));
new RexExpression.FunctionCall(SqlKind.RANK, FieldSpec.DataType.INT, "NTILE", ImmutableList.of()));
List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0));
DataSchema outSchema = new DataSchema(new String[]{"unknown"}, new DataSchema.ColumnDataType[]{DOUBLE});
DataSchema inSchema = new DataSchema(new String[]{"unknown"}, new DataSchema.ColumnDataType[]{DOUBLE});
Expand All @@ -406,6 +406,67 @@ public void testShouldThrowOnUnknownRankAggFunction() {
WindowNode.WindowFrameType.RANGE, Collections.emptyList(), outSchema, inSchema);
}

@Test
public void testRankDenseRankRankingFunctions() {
  // Given: a window with both RANK and DENSE_RANK over (PARTITION BY col0 ORDER BY col1)
  List<RexExpression> calls = ImmutableList.of(
      new RexExpression.FunctionCall(SqlKind.RANK, FieldSpec.DataType.INT, "RANK", ImmutableList.of()),
      new RexExpression.FunctionCall(SqlKind.DENSE_RANK, FieldSpec.DataType.INT, "DENSE_RANK", ImmutableList.of()));
  List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0));
  List<RexExpression> order = ImmutableList.of(new RexExpression.InputRef(1));

  DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new DataSchema.ColumnDataType[]{INT, STRING});
  // Input should be in sorted order on the order by key as SortExchange will handle pre-sorting the data
  Mockito.when(_input.nextBlock())
      .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{3, "and"}, new Object[]{2, "bar"},
          new Object[]{2, "foo"}, new Object[]{1, "foo"}))
      .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{1, "foo"}, new Object[]{2, "foo"},
          new Object[]{1, "numb"}, new Object[]{2, "the"}, new Object[]{3, "true"}))
      .thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());

  DataSchema outSchema = new DataSchema(new String[]{"group", "arg", "rank", "dense_rank"},
      new DataSchema.ColumnDataType[]{INT, STRING, LONG, LONG});

  // When:
  WindowAggregateOperator operator =
      new WindowAggregateOperator(OperatorTestUtil.getDefaultContext(), _input, group, order,
          Collections.emptyList(), Collections.emptyList(), calls, Integer.MIN_VALUE, 0,
          WindowNode.WindowFrameType.RANGE, Collections.emptyList(), outSchema, inSchema);

  TransferableBlock result = operator.getNextBlock();
  while (result.isNoOpBlock()) {
    result = operator.getNextBlock();
  }
  TransferableBlock eosBlock = operator.getNextBlock();
  List<Object[]> resultRows = result.getContainer();
  // Then: expected rank columns per partition key; rows within a partition are expected in
  // ORDER BY key order, with ties sharing the same RANK/DENSE_RANK values
  Map<Integer, List<Object[]>> expectedPartitionToRowsMap = new HashMap<>();
  expectedPartitionToRowsMap.put(1, Arrays.asList(new Object[]{1, "foo", 1L, 1L}, new Object[]{1, "foo", 1L, 1L},
      new Object[]{1, "numb", 3L, 2L}));
  expectedPartitionToRowsMap.put(2, Arrays.asList(new Object[]{2, "bar", 1L, 1L}, new Object[]{2, "foo", 2L, 2L},
      new Object[]{2, "foo", 2L, 2L}, new Object[]{2, "the", 4L, 3L}));
  expectedPartitionToRowsMap.put(3, Arrays.asList(new Object[]{3, "and", 1L, 1L}, new Object[]{3, "true", 2L, 2L}));

  Integer previousPartitionKey = null;
  Map<Integer, List<Object[]>> resultsPartitionToRowsMap = new HashMap<>();
  for (Object[] row : resultRows) {
    Integer currentPartitionKey = (Integer) row[0];
    if (!currentPartitionKey.equals(previousPartitionKey)) {
      // Partitions must be emitted contiguously: a key must not reappear after another key was seen
      Assert.assertFalse(resultsPartitionToRowsMap.containsKey(currentPartitionKey));
    }
    resultsPartitionToRowsMap.computeIfAbsent(currentPartitionKey, k -> new ArrayList<>()).add(row);
    previousPartitionKey = currentPartitionKey;
  }

  // Verify the exact set of partitions was produced (catches missing or unexpected partitions,
  // which a per-result-key iteration alone would silently miss) before comparing row contents
  Assert.assertEquals(resultsPartitionToRowsMap.keySet(), expectedPartitionToRowsMap.keySet());
  resultsPartitionToRowsMap.forEach((key, value) -> {
    List<Object[]> expectedRows = expectedPartitionToRowsMap.get(key);
    Assert.assertEquals(value.size(), expectedRows.size());
    for (int i = 0; i < value.size(); i++) {
      Assert.assertEquals(value.get(i), expectedRows.get(i));
    }
  });
  Assert.assertTrue(eosBlock.isEndOfStreamBlock(), "Second block is EOS (done processing)");
}

@Test
public void testRowNumberRankingFunction() {
// Given:
Expand All @@ -430,7 +491,7 @@ public void testRowNumberRankingFunction() {
WindowAggregateOperator operator =
new WindowAggregateOperator(OperatorTestUtil.getDefaultContext(), _input, group, order,
Collections.emptyList(), Collections.emptyList(), calls, Integer.MIN_VALUE, 0,
WindowNode.WindowFrameType.ROW, Collections.emptyList(), outSchema, inSchema);
WindowNode.WindowFrameType.ROWS, Collections.emptyList(), outSchema, inSchema);

TransferableBlock result = operator.getNextBlock();
while (result.isNoOpBlock()) {
Expand Down Expand Up @@ -561,7 +622,7 @@ public void testShouldThrowOnCustomFramesRows() {
WindowAggregateOperator operator =
new WindowAggregateOperator(OperatorTestUtil.getDefaultContext(), _input, group, Collections.emptyList(),
Collections.emptyList(), Collections.emptyList(), calls, Integer.MIN_VALUE, Integer.MAX_VALUE,
WindowNode.WindowFrameType.ROW, Collections.emptyList(), outSchema, inSchema);
WindowNode.WindowFrameType.ROWS, Collections.emptyList(), outSchema, inSchema);
}

@Test
Expand Down
Loading