Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LASTWITHTIME aggregate function support #7315 #7584

Merged
merged 5 commits into from
Oct 26, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public void testGetAggregationFunctionType() {
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("SuM"), AggregationFunctionType.SUM);
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("AvG"), AggregationFunctionType.AVG);
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("MoDe"), AggregationFunctionType.MODE);
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("LaStWiThTiMe"),
AggregationFunctionType.LASTWITHTIME);
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("MiNmAxRaNgE"),
AggregationFunctionType.MINMAXRANGE);
Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("DiStInCtCoUnT"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,14 @@
import org.apache.pinot.core.query.utils.idset.IdSet;
import org.apache.pinot.core.query.utils.idset.IdSets;
import org.apache.pinot.segment.local.customobject.AvgPair;
import org.apache.pinot.segment.local.customobject.DoubleValueTimePair;
import org.apache.pinot.segment.local.customobject.FloatValueTimePair;
import org.apache.pinot.segment.local.customobject.IntValueTimePair;
import org.apache.pinot.segment.local.customobject.LongValueTimePair;
import org.apache.pinot.segment.local.customobject.MinMaxRangePair;
import org.apache.pinot.segment.local.customobject.QuantileDigest;
import org.apache.pinot.segment.local.customobject.StringValueTimePair;
import org.apache.pinot.segment.local.customobject.ValueTimePair;
import org.apache.pinot.segment.local.utils.GeometrySerializer;
import org.apache.pinot.spi.utils.BigDecimalUtils;
import org.apache.pinot.spi.utils.ByteArray;
Expand Down Expand Up @@ -109,7 +115,12 @@ public enum ObjectType {
Int2LongMap(23),
Long2LongMap(24),
Float2LongMap(25),
Double2LongMap(26);
Double2LongMap(26),
IntValueTimePair(27),
LongValueTimePair(28),
FloatValueTimePair(29),
DoubleValueTimePair(30),
StringValueTimePair(31);
private final int _value;

ObjectType(int value) {
Expand Down Expand Up @@ -178,6 +189,16 @@ public static ObjectType getObjectType(Object value) {
return ObjectType.IdSet;
} else if (value instanceof List) {
return ObjectType.List;
} else if (value instanceof IntValueTimePair) {
return ObjectType.IntValueTimePair;
} else if (value instanceof LongValueTimePair) {
return ObjectType.LongValueTimePair;
} else if (value instanceof FloatValueTimePair) {
return ObjectType.FloatValueTimePair;
} else if (value instanceof DoubleValueTimePair) {
return ObjectType.DoubleValueTimePair;
} else if (value instanceof StringValueTimePair) {
return ObjectType.StringValueTimePair;
} else {
throw new IllegalArgumentException("Unsupported type of value: " + value.getClass().getSimpleName());
}
Expand Down Expand Up @@ -330,6 +351,101 @@ public MinMaxRangePair deserialize(ByteBuffer byteBuffer) {
}
};

public static final ObjectSerDe<? extends ValueTimePair<Integer>> INT_VAL_TIME_PAIR_SER_DE
= new ObjectSerDe<IntValueTimePair>() {

@Override
public byte[] serialize(IntValueTimePair intValueTimePair) {
return intValueTimePair.toBytes();
}

@Override
public IntValueTimePair deserialize(byte[] bytes) {
return IntValueTimePair.fromBytes(bytes);
}

@Override
public IntValueTimePair deserialize(ByteBuffer byteBuffer) {
return IntValueTimePair.fromByteBuffer(byteBuffer);
}
};

public static final ObjectSerDe<? extends ValueTimePair<Long>> LONG_VAL_TIME_PAIR_SER_DE
= new ObjectSerDe<LongValueTimePair>() {

@Override
public byte[] serialize(LongValueTimePair longValueTimePair) {
return longValueTimePair.toBytes();
}

@Override
public LongValueTimePair deserialize(byte[] bytes) {
return LongValueTimePair.fromBytes(bytes);
}

@Override
public LongValueTimePair deserialize(ByteBuffer byteBuffer) {
return LongValueTimePair.fromByteBuffer(byteBuffer);
}
};

public static final ObjectSerDe<? extends ValueTimePair<Float>> FLOAT_VAL_TIME_PAIR_SER_DE
= new ObjectSerDe<FloatValueTimePair>() {

@Override
public byte[] serialize(FloatValueTimePair floatValueTimePair) {
return floatValueTimePair.toBytes();
}

@Override
public FloatValueTimePair deserialize(byte[] bytes) {
return FloatValueTimePair.fromBytes(bytes);
}

@Override
public FloatValueTimePair deserialize(ByteBuffer byteBuffer) {
return FloatValueTimePair.fromByteBuffer(byteBuffer);
}
};

public static final ObjectSerDe<? extends ValueTimePair<Double>> DOUBLE_VAL_TIME_PAIR_SER_DE
= new ObjectSerDe<DoubleValueTimePair>() {

@Override
public byte[] serialize(DoubleValueTimePair doubleValueTimePair) {
return doubleValueTimePair.toBytes();
}

@Override
public DoubleValueTimePair deserialize(byte[] bytes) {
return DoubleValueTimePair.fromBytes(bytes);
}

@Override
public DoubleValueTimePair deserialize(ByteBuffer byteBuffer) {
return DoubleValueTimePair.fromByteBuffer(byteBuffer);
}
};

public static final ObjectSerDe<? extends ValueTimePair<String>> STRING_VAL_TIME_PAIR_SER_DE
= new ObjectSerDe<StringValueTimePair>() {

@Override
public byte[] serialize(StringValueTimePair stringValueTimePair) {
return stringValueTimePair.toBytes();
}

@Override
public StringValueTimePair deserialize(byte[] bytes) {
return StringValueTimePair.fromBytes(bytes);
}

@Override
public StringValueTimePair deserialize(ByteBuffer byteBuffer) {
return StringValueTimePair.fromByteBuffer(byteBuffer);
}
};

public static final ObjectSerDe<HyperLogLog> HYPER_LOG_LOG_SER_DE = new ObjectSerDe<HyperLogLog>() {

@Override
Expand Down Expand Up @@ -1047,7 +1163,12 @@ public Double2LongOpenHashMap deserialize(ByteBuffer byteBuffer) {
INT_2_LONG_MAP_SER_DE,
LONG_2_LONG_MAP_SER_DE,
FLOAT_2_LONG_MAP_SER_DE,
DOUBLE_2_LONG_MAP_SER_DE
DOUBLE_2_LONG_MAP_SER_DE,
INT_VAL_TIME_PAIR_SER_DE,
LONG_VAL_TIME_PAIR_SER_DE,
FLOAT_VAL_TIME_PAIR_SER_DE,
DOUBLE_VAL_TIME_PAIR_SER_DE,
STRING_VAL_TIME_PAIR_SER_DE
};
//@formatter:on

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.exception.BadQueryRequestException;
Expand Down Expand Up @@ -156,6 +158,45 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
return new AvgAggregationFunction(firstArgument);
case MODE:
return new ModeAggregationFunction(arguments);
case LASTWITHTIME:
if (arguments.size() > 1) {
ExpressionContext timeCol = arguments.get(1);
String dataType = arguments.get(2).getIdentifier();
DataSchema.ColumnDataType columnDataType = DataSchema.ColumnDataType.valueOf(dataType.toUpperCase());
switch (columnDataType) {
case BOOLEAN:
case INT:
return new LastIntValueWithTimeAggregationFunction(
firstArgument,
timeCol,
ObjectSerDeUtils.INT_VAL_TIME_PAIR_SER_DE,
columnDataType == DataSchema.ColumnDataType.BOOLEAN);
case LONG:
return new LastLongValueWithTimeAggregationFunction(
firstArgument,
timeCol,
ObjectSerDeUtils.LONG_VAL_TIME_PAIR_SER_DE);
case FLOAT:
return new LastFloatValueWithTimeAggregationFunction(
firstArgument,
timeCol,
ObjectSerDeUtils.FLOAT_VAL_TIME_PAIR_SER_DE);
case DOUBLE:
return new LastDoubleValueWithTimeAggregationFunction(
firstArgument,
timeCol,
ObjectSerDeUtils.DOUBLE_VAL_TIME_PAIR_SER_DE);
case STRING:
return new LastStringValueWithTimeAggregationFunction(
firstArgument,
timeCol,
ObjectSerDeUtils.STRING_VAL_TIME_PAIR_SER_DE);
default:
throw new IllegalArgumentException("Unsupported Value Type for LastWithTime Function:" + dataType);
}
} else {
throw new IllegalArgumentException("Two arguments are required for LastWithTime Function.");
}
case MINMAXRANGE:
return new MinMaxRangeAggregationFunction(firstArgument);
case DISTINCTCOUNT:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.core.query.aggregation.function;

import java.util.Arrays;
import java.util.List;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.segment.local.customobject.DoubleValueTimePair;
import org.apache.pinot.segment.local.customobject.ValueTimePair;

/**
* This function is used for LastWithTime calculations for data column with double type.
* <p>The function can be used as LastWithTime(dataExpression, timeExpression, 'double')
* <p>Following arguments are supported:
* <ul>
* <li>dataExpression: expression that contains the double data column to be calculated last on</li>
* <li>timeExpression: expression that contains the column to be used to decide which data is last, can be any
* Numeric column</li>
* </ul>
*/
@SuppressWarnings({"rawtypes", "unchecked"})
public class LastDoubleValueWithTimeAggregationFunction extends LastWithTimeAggregationFunction<Double> {

private final static ValueTimePair<Double> DEFAULT_VALUE_TIME_PAIR
= new DoubleValueTimePair(Double.NaN, Long.MIN_VALUE);

public LastDoubleValueWithTimeAggregationFunction(
ExpressionContext dataCol,
ExpressionContext timeCol,
ObjectSerDeUtils.ObjectSerDe<? extends ValueTimePair<Double>> objectSerDe) {
super(dataCol, timeCol, objectSerDe);
}

@Override
public List<ExpressionContext> getInputExpressions() {
return Arrays.asList(_expression, _timeCol, ExpressionContext.forLiteral("Long"));
}

@Override
public ValueTimePair<Double> constructValueTimePair(Double value, long time) {
return new DoubleValueTimePair(value, time);
}

@Override
public ValueTimePair<Double> getDefaultValueTimePair() {
return DEFAULT_VALUE_TIME_PAIR;
}

@Override
public void updateResultWithRawData(int length, AggregationResultHolder aggregationResultHolder,
BlockValSet blockValSet, BlockValSet timeValSet) {
ValueTimePair<Double> defaultValueTimePair = getDefaultValueTimePair();
Double lastData = defaultValueTimePair.getValue();
long lastTime = defaultValueTimePair.getTime();
double [] doubleValues = blockValSet.getDoubleValuesSV();
long[] timeValues = timeValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
double data = doubleValues[i];
long time = timeValues[i];
if (time >= lastTime) {
lastTime = time;
lastData = data;
}
}
setAggregationResult(aggregationResultHolder, lastData, lastTime);
}

@Override
public void updateGroupResultWithRawDataSv(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
BlockValSet blockValSet, BlockValSet timeValSet) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
long[] timeValues = timeValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
double data = doubleValues[i];
long time = timeValues[i];
setGroupByResult(groupKeyArray[i], groupByResultHolder, data, time);
}
}

@Override
public void updateGroupResultWithRawDataMv(int length,
int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
BlockValSet blockValSet,
BlockValSet timeValSet) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
long[] timeValues = timeValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
double value = doubleValues[i];
long time = timeValues[i];
for (int groupKey : groupKeysArray[i]) {
setGroupByResult(groupKey, groupByResultHolder, value, time);
}
}
}

@Override
public String getResultColumnName() {
return getType().getName().toLowerCase() + "(" + _expression + "," + _timeCol + ", Double)";
}

@Override
public ColumnDataType getFinalResultColumnType() {
return ColumnDataType.DOUBLE;
}
}
Loading