Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.OrderByExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.request.context.QueryContext;

Expand Down Expand Up @@ -54,7 +55,7 @@ public abstract class IndexedTable extends BaseTable {
*
* @param dataSchema Data schema of the table
* @param queryContext Query context
* @param resultSize Number of records to keep in the final result after calling {@link #finish(boolean)}
* @param resultSize Number of records to keep in the final result after calling {@link #finish(boolean, boolean)}
* @param trimSize Number of records to keep when trimming the table
* @param trimThreshold Trim the table when the number of records exceeds the threshold
* @param lookupMap Map from keys to records
Expand Down Expand Up @@ -144,7 +145,7 @@ protected void resize() {
}

@Override
public void finish(boolean sort) {
public void finish(boolean sort, boolean storeFinalResult) {
if (_hasOrderBy) {
long startTimeNs = System.nanoTime();
_topRecords = _tableResizer.getTopRecords(_lookupMap, _resultSize, sort);
Expand All @@ -154,6 +155,21 @@ public void finish(boolean sort) {
} else {
_topRecords = _lookupMap.values();
}
// TODO: Directly return final result in _tableResizer.getTopRecords to avoid extracting final result multiple times
if (storeFinalResult) {
ColumnDataType[] columnDataTypes = _dataSchema.getColumnDataTypes();
int numAggregationFunctions = _aggregationFunctions.length;
for (int i = 0; i < numAggregationFunctions; i++) {
columnDataTypes[i + _numKeyColumns] = _aggregationFunctions[i].getFinalResultColumnType();
}
for (Record record : _topRecords) {
Object[] values = record.getValues();
for (int i = 0; i < numAggregationFunctions; i++) {
int colId = i + _numKeyColumns;
values[colId] = _aggregationFunctions[i].extractFinalResult(values[colId]);
}
}
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,16 @@ public interface Table {
*/
Iterator<Record> iterator();

/**
 * Convenience overload of {@link #finish(boolean, boolean)} that never asks for the final aggregation result to be
 * stored.
 *
 * @param sort sort the final results if true
 */
default void finish(boolean sort) {
  final boolean storeFinalResult = false;
  finish(sort, storeFinalResult);
}

/**
 * Finishes any pre-exit processing on the table.
 *
 * @param sort sort the final results if true
 * @param storeFinalResult whether to store the final aggregation result
 */
void finish(boolean sort);
void finish(boolean sort, boolean storeFinalResult);

/**
* Returns the data schema of the table
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
*/
package org.apache.pinot.core.operator.blocks.results;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.List;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
Expand All @@ -27,14 +29,15 @@
import org.apache.pinot.core.common.datatable.DataTableFactory;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.spi.utils.ByteArray;
import org.apache.pinot.spi.utils.NullValueUtils;
import org.roaringbitmap.RoaringBitmap;


/**
* Results block for aggregation queries.
*/
@SuppressWarnings("rawtypes")
@SuppressWarnings({"rawtypes", "unchecked"})
public class AggregationResultsBlock extends BaseResultsBlock {
private final AggregationFunction[] _aggregationFunctions;
private final List<Object> _results;
Expand All @@ -55,14 +58,17 @@ public List<Object> getResults() {
@Override
public DataTable getDataTable(QueryContext queryContext)
throws Exception {
boolean returnFinalResult = queryContext.isServerReturnFinalResult();

// Extract result column name and type from each aggregation function
int numColumns = _aggregationFunctions.length;
String[] columnNames = new String[numColumns];
ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns];
for (int i = 0; i < numColumns; i++) {
AggregationFunction aggregationFunction = _aggregationFunctions[i];
columnNames[i] = aggregationFunction.getColumnName();
columnDataTypes[i] = aggregationFunction.getIntermediateResultColumnType();
columnDataTypes[i] = returnFinalResult ? aggregationFunction.getFinalResultColumnType()
: aggregationFunction.getIntermediateResultColumnType();
}

// Build the data table.
Expand All @@ -76,11 +82,20 @@ public DataTable getDataTable(QueryContext queryContext)
dataTableBuilder.startRow();
for (int i = 0; i < numColumns; i++) {
Object result = _results.get(i);
if (result == null && columnDataTypes[i] != ColumnDataType.OBJECT) {
result = NullValueUtils.getDefaultNullValue(columnDataTypes[i].toDataType());
nullBitmaps[i].add(0);
if (!returnFinalResult) {
if (result == null && columnDataTypes[i] != ColumnDataType.OBJECT) {
result = NullValueUtils.getDefaultNullValue(columnDataTypes[i].toDataType());
nullBitmaps[i].add(0);
}
setIntermediateResult(dataTableBuilder, columnDataTypes, i, result);
} else {
result = _aggregationFunctions[i].extractFinalResult(result);
if (result == null) {
result = NullValueUtils.getDefaultNullValue(columnDataTypes[i].toDataType());
nullBitmaps[i].add(0);
}
setFinalResult(dataTableBuilder, columnDataTypes, i, result);
}
setResult(dataTableBuilder, columnNames, columnDataTypes, i, result);
}
dataTableBuilder.finishRow();
for (RoaringBitmap nullBitmap : nullBitmaps) {
Expand All @@ -89,7 +104,13 @@ public DataTable getDataTable(QueryContext queryContext)
} else {
dataTableBuilder.startRow();
for (int i = 0; i < numColumns; i++) {
setResult(dataTableBuilder, columnNames, columnDataTypes, i, _results.get(i));
Object result = _results.get(i);
if (!returnFinalResult) {
setIntermediateResult(dataTableBuilder, columnDataTypes, i, result);
} else {
result = _aggregationFunctions[i].extractFinalResult(result);
setFinalResult(dataTableBuilder, columnDataTypes, i, result);
}
}
dataTableBuilder.finishRow();
}
Expand All @@ -99,23 +120,56 @@ public DataTable getDataTable(QueryContext queryContext)
return dataTable;
}

/**
 * Writes one aggregation result into column {@code index} of {@code dataTableBuilder}, choosing the setter by the
 * column's data type. Only LONG, DOUBLE and OBJECT columns are handled; any other type is rejected with an exception.
 *
 * NOTE(review): this span appears to show two variants of the same method ({@code setResult} and
 * {@code setIntermediateResult}) interleaved line by line; confirm which lines are live before relying on the
 * exception type or on the numeric casts.
 *
 * @param dataTableBuilder builder receiving the value via {@code setColumn}
 * @param columnDataTypes data type of each column; the entry at {@code index} selects the branch
 * @param index position of the column to write, also used to look up its data type
 * @param result value to write; unboxed for LONG and DOUBLE, handed over unchanged for OBJECT
 * @throws IOException declared by this method; presumably raised by {@code DataTableBuilder#setColumn} (verify)
 */
private void setResult(DataTableBuilder dataTableBuilder, String[] columnNames, ColumnDataType[] columnDataTypes,
int index, Object result)
private void setIntermediateResult(DataTableBuilder dataTableBuilder, ColumnDataType[] columnDataTypes, int index,
Object result)
throws IOException {
ColumnDataType columnDataType = columnDataTypes[index];
switch (columnDataType) {
case LONG:
// Accepts any Number subtype and converts it to long.
dataTableBuilder.setColumn(index, ((Number) result).longValue());
// Direct unboxing cast: result must be a Long at runtime, otherwise this fails with ClassCastException.
dataTableBuilder.setColumn(index, (long) result);
break;
case DOUBLE:
// Requires a Double at runtime in both forms below; a null result fails on unboxing.
dataTableBuilder.setColumn(index, ((Double) result).doubleValue());
dataTableBuilder.setColumn(index, (double) result);
break;
case OBJECT:
// Handed to the builder unchanged.
dataTableBuilder.setColumn(index, result);
break;
default:
// Any column type other than LONG, DOUBLE or OBJECT is rejected here.
throw new UnsupportedOperationException(
"Unsupported aggregation column data type: " + columnDataType + " for column: " + columnNames[index]);
throw new IllegalStateException("Illegal column data type in intermediate result: " + columnDataType);
}
}

/**
 * Writes one final aggregation result into column {@code index} of {@code dataTableBuilder}, choosing the setter by
 * the column's data type.
 *
 * <p>Handled types are INT, LONG, FLOAT, DOUBLE, BIG_DECIMAL, STRING, BYTES and DOUBLE_ARRAY. The numeric cases use
 * direct unboxing casts, so {@code result} must already be of the exact boxed type for the column.
 *
 * @param dataTableBuilder builder receiving the value via {@code setColumn}
 * @param columnDataTypes data type of each column; the entry at {@code index} selects the branch
 * @param index position of the column to write, also used to look up its data type
 * @param result final result to write; must be non-null
 * @throws IOException declared by this method; presumably raised by {@code DataTableBuilder#setColumn} (verify)
 * @throws IllegalStateException if the column data type is not one of the handled types
 */
private void setFinalResult(DataTableBuilder dataTableBuilder, ColumnDataType[] columnDataTypes, int index,
    Object result)
    throws IOException {
  ColumnDataType columnDataType = columnDataTypes[index];
  switch (columnDataType) {
    case INT:
      dataTableBuilder.setColumn(index, (int) result);
      break;
    case LONG:
      dataTableBuilder.setColumn(index, (long) result);
      break;
    case FLOAT:
      dataTableBuilder.setColumn(index, (float) result);
      break;
    case DOUBLE:
      dataTableBuilder.setColumn(index, (double) result);
      break;
    case BIG_DECIMAL:
      dataTableBuilder.setColumn(index, (BigDecimal) result);
      break;
    case STRING:
      dataTableBuilder.setColumn(index, result.toString());
      break;
    case BYTES:
      dataTableBuilder.setColumn(index, (ByteArray) result);
      break;
    case DOUBLE_ARRAY: {
      DoubleArrayList doubleArrayList = (DoubleArrayList) result;
      double[] doubles = doubleArrayList.elements();
      // elements() exposes the backing array, whose length can exceed size(); serializing it as-is would append
      // unused capacity slots to the result. Copy only when the backing array is not already exactly sized.
      if (doubles.length != doubleArrayList.size()) {
        doubles = doubleArrayList.toDoubleArray();
      }
      dataTableBuilder.setColumn(index, doubles);
      break;
    }
    default:
      throw new IllegalStateException("Illegal column data type in final result: " + columnDataType);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.pinot.core.operator.blocks.results;

import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.Collection;
Expand Down Expand Up @@ -193,14 +194,11 @@ private void setDataTableColumn(ColumnDataType storedColumnDataType, DataTableBu
dataTableBuilder.setColumn(columnIndex, (BigDecimal) value);
break;
case STRING:
dataTableBuilder.setColumn(columnIndex, (String) value);
dataTableBuilder.setColumn(columnIndex, value.toString());
break;
case BYTES:
dataTableBuilder.setColumn(columnIndex, (ByteArray) value);
break;
case OBJECT:
dataTableBuilder.setColumn(columnIndex, value);
break;
case INT_ARRAY:
dataTableBuilder.setColumn(columnIndex, (int[]) value);
break;
Expand All @@ -211,11 +209,18 @@ private void setDataTableColumn(ColumnDataType storedColumnDataType, DataTableBu
dataTableBuilder.setColumn(columnIndex, (float[]) value);
break;
case DOUBLE_ARRAY:
dataTableBuilder.setColumn(columnIndex, (double[]) value);
if (value instanceof DoubleArrayList) {
dataTableBuilder.setColumn(columnIndex, ((DoubleArrayList) value).elements());
} else {
dataTableBuilder.setColumn(columnIndex, (double[]) value);
}
break;
case STRING_ARRAY:
dataTableBuilder.setColumn(columnIndex, (String[]) value);
break;
case OBJECT:
dataTableBuilder.setColumn(columnIndex, value);
break;
default:
throw new IllegalStateException("Unsupported stored type: " + storedColumnDataType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ public GroupByOrderByCombineOperator(List<Operator> operators, QueryContext quer
int minTrimSize = queryContext.getMinServerGroupTrimSize();
if (minTrimSize > 0) {
int limit = queryContext.getLimit();
if (queryContext.getOrderByExpressions() != null || queryContext.getHavingFilter() != null) {
if ((!queryContext.isServerReturnFinalResult() && queryContext.getOrderByExpressions() != null)
|| queryContext.getHavingFilter() != null) {
_trimSize = GroupByUtils.getTableCapacity(limit, minTrimSize);
} else {
// TODO: Keeping only 'LIMIT' groups can cause inaccurate result because the groups are randomly selected
Expand Down Expand Up @@ -252,7 +253,11 @@ protected BaseResultsBlock mergeResults()
}

IndexedTable indexedTable = _indexedTable;
indexedTable.finish(false);
if (!_queryContext.isServerReturnFinalResult()) {
indexedTable.finish(false);
} else {
indexedTable.finish(true, true);
}
GroupByResultsBlock mergedBlock = new GroupByResultsBlock(indexedTable);
mergedBlock.setNumGroupsLimitReached(_numGroupsLimitReached);
mergedBlock.setNumResizes(indexedTable.getNumResizes());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.common.utils.DataTable;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.operator.blocks.TransformBlock;
import org.apache.pinot.segment.spi.AggregationFunctionType;
Expand Down Expand Up @@ -115,4 +117,48 @@ public static Map<ExpressionContext, BlockValSet> getBlockValSetMap(
BlockValSet blockValSet = transformBlock.getBlockValueSet(aggregationFunctionColumnPair.toColumnName());
return Collections.singletonMap(expression, blockValSet);
}

/**
 * Fetches the intermediate aggregation result stored at the given cell of the {@link DataTable}.
 *
 * @param dataTable data table to read from
 * @param columnDataType data type of the column; LONG, DOUBLE and OBJECT are accepted
 * @param rowId row of the cell
 * @param colId column of the cell
 * @return the cell value read with the getter matching {@code columnDataType}
 * @throws IllegalStateException if {@code columnDataType} is any other type
 */
public static Object getIntermediateResult(DataTable dataTable, ColumnDataType columnDataType, int rowId, int colId) {
  final Object intermediateResult;
  switch (columnDataType) {
    case LONG:
      intermediateResult = dataTable.getLong(rowId, colId);
      break;
    case DOUBLE:
      intermediateResult = dataTable.getDouble(rowId, colId);
      break;
    case OBJECT:
      intermediateResult = dataTable.getObject(rowId, colId);
      break;
    default:
      throw new IllegalStateException("Illegal column data type in intermediate result: " + columnDataType);
  }
  return intermediateResult;
}

/**
 * Fetches the converted final aggregation result stored at the given cell of the {@link DataTable}. The value is
 * meant to match what {@link AggregationFunction#extractFinalResult(Object)} followed by
 * {@link ColumnDataType#convert(Object)} would produce.
 *
 * @param dataTable data table to read from
 * @param columnDataType data type of the column; INT, LONG, FLOAT, DOUBLE, BIG_DECIMAL, STRING, BYTES and
 *                       DOUBLE_ARRAY are accepted
 * @param rowId row of the cell
 * @param colId column of the cell
 * @return the cell value read with the getter matching {@code columnDataType}
 * @throws IllegalStateException if {@code columnDataType} is any other type
 */
public static Object getConvertedFinalResult(DataTable dataTable, ColumnDataType columnDataType, int rowId,
    int colId) {
  final Object finalResult;
  switch (columnDataType) {
    case INT:
      finalResult = dataTable.getInt(rowId, colId);
      break;
    case LONG:
      finalResult = dataTable.getLong(rowId, colId);
      break;
    case FLOAT:
      finalResult = dataTable.getFloat(rowId, colId);
      break;
    case DOUBLE:
      finalResult = dataTable.getDouble(rowId, colId);
      break;
    case BIG_DECIMAL:
      finalResult = dataTable.getBigDecimal(rowId, colId);
      break;
    case STRING:
      finalResult = dataTable.getString(rowId, colId);
      break;
    case BYTES:
      // The stored bytes wrapper is unwrapped by calling getBytes() on it.
      finalResult = dataTable.getBytes(rowId, colId).getBytes();
      break;
    case DOUBLE_ARRAY:
      finalResult = dataTable.getDoubleArray(rowId, colId);
      break;
    default:
      throw new IllegalStateException("Illegal column data type in final result: " + columnDataType);
  }
  return finalResult;
}
}
Loading