Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ public abstract class BaseTransformFunction implements TransformFunction {
new TransformResultMetadata(DataType.JSON, true, false);
protected static final TransformResultMetadata BYTES_SV_NO_DICTIONARY_METADATA =
new TransformResultMetadata(DataType.BYTES, true, false);
protected static final TransformResultMetadata LONG_MV_NO_DICTIONARY_METADATA =
new TransformResultMetadata(DataType.LONG, false, false);
protected static final TransformResultMetadata DOUBLE_MV_NO_DICTIONARY_METADATA =
new TransformResultMetadata(DataType.DOUBLE, false, false);
protected static final TransformResultMetadata INT_MV_NO_DICTIONARY_METADATA =
new TransformResultMetadata(DataType.INT, false, false);
protected static final TransformResultMetadata FLOAT_MV_NO_DICTIONARY_METADATA =
new TransformResultMetadata(DataType.FLOAT, false, false);

protected int[] _intValuesSV;
protected long[] _longValuesSV;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,40 +47,58 @@ public void init(List<TransformFunction> arguments, Map<String, DataSource> data

_transformFunction = arguments.get(0);
TransformFunction castFormatTransformFunction = arguments.get(1);
boolean isSVCol = _transformFunction.getResultMetadata().isSingleValue();

if (castFormatTransformFunction instanceof LiteralTransformFunction) {
String targetType = ((LiteralTransformFunction) castFormatTransformFunction).getLiteral().toUpperCase();
switch (targetType) {
case "INT":
case "INTEGER":
_resultMetadata = INT_SV_NO_DICTIONARY_METADATA;
_resultMetadata = isSVCol ? INT_SV_NO_DICTIONARY_METADATA : INT_MV_NO_DICTIONARY_METADATA;
break;
case "LONG":
_resultMetadata = LONG_SV_NO_DICTIONARY_METADATA;
_resultMetadata = isSVCol ? LONG_SV_NO_DICTIONARY_METADATA : LONG_MV_NO_DICTIONARY_METADATA;
break;
case "FLOAT":
_resultMetadata = FLOAT_SV_NO_DICTIONARY_METADATA;
_resultMetadata = isSVCol ? FLOAT_SV_NO_DICTIONARY_METADATA : FLOAT_MV_NO_DICTIONARY_METADATA;
break;
case "DOUBLE":
_resultMetadata = DOUBLE_SV_NO_DICTIONARY_METADATA;
_resultMetadata = isSVCol ? DOUBLE_SV_NO_DICTIONARY_METADATA : DOUBLE_MV_NO_DICTIONARY_METADATA;
break;
case "DECIMAL":
case "BIGDECIMAL":
case "BIG_DECIMAL":
if (!isSVCol) {
// TODO: MV cast to BIG_DECIMAL type
throw new IllegalArgumentException(
"Cast is not supported on multi-value column to target type: " + targetType);
}
_resultMetadata = BIG_DECIMAL_SV_NO_DICTIONARY_METADATA;
break;
case "BOOL":
case "BOOLEAN":
if (!isSVCol) {
throw new IllegalArgumentException(
"Cast is not supported on multi-value column to target type: " + targetType);
}
_resultMetadata = BOOLEAN_SV_NO_DICTIONARY_METADATA;
break;
case "TIMESTAMP":
if (!isSVCol) {
throw new IllegalArgumentException(
"Cast is not supported on multi-value column to target type: " + targetType);
}
_resultMetadata = TIMESTAMP_SV_NO_DICTIONARY_METADATA;
break;
case "STRING":
case "VARCHAR":
_resultMetadata = STRING_SV_NO_DICTIONARY_METADATA;
_resultMetadata = isSVCol ? STRING_SV_NO_DICTIONARY_METADATA : STRING_MV_NO_DICTIONARY_METADATA;
break;
case "JSON":
if (!isSVCol) {
throw new IllegalArgumentException(
"Cast is not supported on multi-value column to target type: " + targetType);
}
_resultMetadata = JSON_SV_NO_DICTIONARY_METADATA;
break;
default:
Expand All @@ -96,6 +114,55 @@ public TransformResultMetadata getResultMetadata() {
return _resultMetadata;
}

@Override
public double[][] transformToDoubleValuesMV(ProjectionBlock projectionBlock) {
DataType resultStoredType = _resultMetadata.getDataType().getStoredType();
if (resultStoredType == DataType.DOUBLE) {
return _transformFunction.transformToDoubleValuesMV(projectionBlock);
} else {
return super.transformToDoubleValuesMV(projectionBlock);
}
}

@Override
public String[][] transformToStringValuesMV(ProjectionBlock projectionBlock) {
if (_resultMetadata.getDataType().getStoredType() == DataType.STRING) {
return _transformFunction.transformToStringValuesMV(projectionBlock);
} else {
return super.transformToStringValuesMV(projectionBlock);
}
}

@Override
public int[][] transformToIntValuesMV(ProjectionBlock projectionBlock) {
DataType resultStoredType = _resultMetadata.getDataType().getStoredType();
if (resultStoredType == DataType.INT) {
return _transformFunction.transformToIntValuesMV(projectionBlock);
} else {
return super.transformToIntValuesMV(projectionBlock);
}
}

@Override
public float[][] transformToFloatValuesMV(ProjectionBlock projectionBlock) {
DataType resultStoredType = _resultMetadata.getDataType().getStoredType();
if (resultStoredType == DataType.FLOAT) {
return _transformFunction.transformToFloatValuesMV(projectionBlock);
} else {
return super.transformToFloatValuesMV(projectionBlock);
}
}

@Override
public long[][] transformToLongValuesMV(ProjectionBlock projectionBlock) {
DataType resultStoredType = _resultMetadata.getDataType().getStoredType();
if (resultStoredType == DataType.LONG) {
return _transformFunction.transformToLongValuesMV(projectionBlock);
} else {
return super.transformToLongValuesMV(projectionBlock);
}
}

@Override
public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
if (_resultMetadata.getDataType().getStoredType() == DataType.INT) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public abstract class BaseTransformFunctionTest {
protected static final String DOUBLE_MV_COLUMN = "doubleMV";
protected static final String STRING_MV_COLUMN = "stringMV";
protected static final String STRING_ALPHANUM_MV_COLUMN = "stringAlphaNumMV";
protected static final String STRING_LONG_MV_COLUMN = "stringLongMV";
protected static final String TIME_COLUMN = "timeColumn";
protected static final String TIMESTAMP_COLUMN = "timestampColumn";
protected static final String JSON_COLUMN = "json";
Expand All @@ -99,6 +100,7 @@ public abstract class BaseTransformFunctionTest {
protected final double[][] _doubleMVValues = new double[NUM_ROWS][];
protected final String[][] _stringMVValues = new String[NUM_ROWS][];
protected final String[][] _stringAlphaNumericMVValues = new String[NUM_ROWS][];
protected final String[][] _stringLongFormatMVValues = new String[NUM_ROWS][];
protected final long[] _timeValues = new long[NUM_ROWS];
protected final String[] _jsonValues = new String[NUM_ROWS];

Expand Down Expand Up @@ -129,6 +131,7 @@ public void setUp()
_doubleMVValues[i] = new double[numValues];
_stringMVValues[i] = new String[numValues];
_stringAlphaNumericMVValues[i] = new String[numValues];
_stringLongFormatMVValues[i] = new String[numValues];

for (int j = 0; j < numValues; j++) {
_intMVValues[i][j] = 1 + RANDOM.nextInt(MAX_MULTI_VALUE);
Expand All @@ -137,6 +140,7 @@ public void setUp()
_doubleMVValues[i][j] = 1 + RANDOM.nextDouble();
_stringMVValues[i][j] = df.format(_intSVValues[i] * RANDOM.nextDouble());
_stringAlphaNumericMVValues[i][j] = RandomStringUtils.randomAlphanumeric(26);
_stringLongFormatMVValues[i][j] = df.format(_intSVValues[i] * RANDOM.nextLong());
}

// Time in the past year
Expand All @@ -160,6 +164,7 @@ public void setUp()
map.put(DOUBLE_MV_COLUMN, ArrayUtils.toObject(_doubleMVValues[i]));
map.put(STRING_MV_COLUMN, _stringMVValues[i]);
map.put(STRING_ALPHANUM_MV_COLUMN, _stringAlphaNumericMVValues[i]);
map.put(STRING_LONG_MV_COLUMN, _stringLongFormatMVValues[i]);
map.put(TIMESTAMP_COLUMN, _timeValues[i]);
map.put(TIME_COLUMN, _timeValues[i]);
_jsonValues[i] = JsonUtils.objectToJsonNode(map).toString();
Expand All @@ -184,6 +189,7 @@ public void setUp()
.addMultiValueDimension(DOUBLE_MV_COLUMN, FieldSpec.DataType.DOUBLE)
.addMultiValueDimension(STRING_MV_COLUMN, FieldSpec.DataType.STRING)
.addMultiValueDimension(STRING_ALPHANUM_MV_COLUMN, FieldSpec.DataType.STRING)
.addMultiValueDimension(STRING_LONG_MV_COLUMN, FieldSpec.DataType.STRING)
.addDateTime(TIMESTAMP_COLUMN, FieldSpec.DataType.TIMESTAMP, "1:MILLISECONDS:EPOCH", "1:MILLISECONDS")
.addTime(new TimeGranularitySpec(FieldSpec.DataType.LONG, TimeUnit.MILLISECONDS, TIME_COLUMN), null).build();
TableConfig tableConfig =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
package org.apache.pinot.core.operator.transform.function;

import java.math.BigDecimal;
import java.util.Arrays;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.RequestContextUtils;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.utils.ArrayCopyUtils;
import org.testng.Assert;
import org.testng.annotations.Test;

Expand All @@ -30,6 +33,71 @@

public class CastTransformFunctionTest extends BaseTransformFunctionTest {

@Test
public void testCastTransformFunctionMV() {
ExpressionContext expression =
RequestContextUtils.getExpression(String.format("CAST(%s AS LONG)", STRING_LONG_MV_COLUMN));
TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
Assert.assertTrue(transformFunction instanceof CastTransformFunction);
assertEquals(transformFunction.getName(), CastTransformFunction.FUNCTION_NAME);
long[][] expectedLongValues = new long[NUM_ROWS][];
ArrayCopyUtils.copy(_stringLongFormatMVValues, expectedLongValues, NUM_ROWS);
testCastTransformFunctionMV(transformFunction, expectedLongValues);

expression = RequestContextUtils.getExpression(
String.format("CAST(CAST(CAST(%s AS LONG) as DOUBLE) as INT)", STRING_LONG_MV_COLUMN));
transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
Assert.assertTrue(transformFunction instanceof CastTransformFunction);
assertEquals(transformFunction.getName(), CastTransformFunction.FUNCTION_NAME);
long[][] innerLongValues = new long[NUM_ROWS][];
ArrayCopyUtils.copy(_stringLongFormatMVValues, innerLongValues, NUM_ROWS);
double[][] innerDoubleValues = new double[NUM_ROWS][];
ArrayCopyUtils.copy(innerLongValues, innerDoubleValues, NUM_ROWS);
int[][] expectedIntValues = new int[NUM_ROWS][];
ArrayCopyUtils.copy(innerDoubleValues, expectedIntValues, NUM_ROWS);
testCastTransformFunctionMV(transformFunction, expectedIntValues);

expression =
RequestContextUtils.getExpression(String.format("CAST(CAST(%s AS INT) as FLOAT)", FLOAT_MV_COLUMN));
transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
Assert.assertTrue(transformFunction instanceof CastTransformFunction);
assertEquals(transformFunction.getName(), CastTransformFunction.FUNCTION_NAME);
int[][] innerLayerInt = new int[NUM_ROWS][];
ArrayCopyUtils.copy(_floatMVValues, innerLayerInt, NUM_ROWS);
float[][] expectedFloatValues = new float[NUM_ROWS][];
ArrayCopyUtils.copy(innerLayerInt, expectedFloatValues, NUM_ROWS);
testCastTransformFunctionMV(transformFunction, expectedFloatValues);

expression = RequestContextUtils.getExpression(
String.format("CAST(CAST(CAST(%s AS FLOAT) as INT) as STRING)", INT_MV_COLUMN));
transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
Assert.assertTrue(transformFunction instanceof CastTransformFunction);
assertEquals(transformFunction.getName(), CastTransformFunction.FUNCTION_NAME);
float[][] innerFloatValues = new float[NUM_ROWS][];
ArrayCopyUtils.copy(_intMVValues, innerFloatValues, NUM_ROWS);
innerLayerInt = new int[NUM_ROWS][];
ArrayCopyUtils.copy(innerFloatValues, innerLayerInt, NUM_ROWS);
String[][] expectedStringValues = new String[NUM_ROWS][];
ArrayCopyUtils.copy(innerLayerInt, expectedStringValues, NUM_ROWS);
testCastTransformFunctionMV(transformFunction, expectedStringValues);

expression = RequestContextUtils.getExpression(String.format("arrayMax(cAst(%s AS INT))", DOUBLE_MV_COLUMN));
transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
FieldSpec.DataType resultDataType = transformFunction.getResultMetadata().getDataType();
Assert.assertEquals(resultDataType, FieldSpec.DataType.INT);

// checks that arraySum triggers transformToDoubleMV in cast function which correctly cast to INT
expression = RequestContextUtils.getExpression(String.format("arraySum(cAst(%s AS INT))", DOUBLE_MV_COLUMN));
transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
int[][] afterCast = new int[NUM_ROWS][];
ArrayCopyUtils.copy(_doubleMVValues, afterCast, NUM_ROWS);
double[] expectedArraySums = new double[NUM_ROWS];
for (int i = 0; i < NUM_ROWS; i++) {
expectedArraySums[i] = Arrays.stream(afterCast[i]).sum();
}
testTransformFunction(transformFunction, expectedArraySums);
}

@Test
public void testCastTransformFunction() {
ExpressionContext expression =
Expand Down Expand Up @@ -130,4 +198,94 @@ public void testCastTransformFunction() {
testTransformFunction(transformFunction, expectedBigDecimalValues);
assertEquals(expectedBigDecimalValues, bigDecimalScalarValues);
}

private void testCastTransformFunctionMV(TransformFunction transformFunction, int[][] expectedValues) {
int[][] intMVValues = transformFunction.transformToIntValuesMV(_projectionBlock);
long[][] longMVValues = transformFunction.transformToLongValuesMV(_projectionBlock);
float[][] floatMVValues = transformFunction.transformToFloatValuesMV(_projectionBlock);
double[][] doubleMVValues = transformFunction.transformToDoubleValuesMV(_projectionBlock);
String[][] stringMVValues = transformFunction.transformToStringValuesMV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
int rowLen = expectedValues[i].length;
for (int j = 0; j < rowLen; j++) {
Assert.assertEquals(intMVValues[i][j], expectedValues[i][j]);
Assert.assertEquals(longMVValues[i][j], (long) expectedValues[i][j]);
Assert.assertEquals(floatMVValues[i][j], (float) expectedValues[i][j]);
Assert.assertEquals(doubleMVValues[i][j], (double) expectedValues[i][j]);
Assert.assertEquals(stringMVValues[i][j], Integer.toString(expectedValues[i][j]));
}
}
}

private void testCastTransformFunctionMV(TransformFunction transformFunction, long[][] expectedValues) {
int[][] intMVValues = transformFunction.transformToIntValuesMV(_projectionBlock);
long[][] longMVValues = transformFunction.transformToLongValuesMV(_projectionBlock);
float[][] floatMVValues = transformFunction.transformToFloatValuesMV(_projectionBlock);
double[][] doubleMVValues = transformFunction.transformToDoubleValuesMV(_projectionBlock);
String[][] stringMVValues = transformFunction.transformToStringValuesMV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
int rowLen = expectedValues[i].length;
for (int j = 0; j < rowLen; j++) {
Assert.assertEquals(intMVValues[i][j], (int) expectedValues[i][j]);
Assert.assertEquals(longMVValues[i][j], expectedValues[i][j]);
Assert.assertEquals(floatMVValues[i][j], (float) expectedValues[i][j]);
Assert.assertEquals(doubleMVValues[i][j], (double) expectedValues[i][j]);
Assert.assertEquals(stringMVValues[i][j], Long.toString(expectedValues[i][j]));
}
}
}

private void testCastTransformFunctionMV(TransformFunction transformFunction, float[][] expectedValues) {
int[][] intMVValues = transformFunction.transformToIntValuesMV(_projectionBlock);
long[][] longMVValues = transformFunction.transformToLongValuesMV(_projectionBlock);
float[][] floatMVValues = transformFunction.transformToFloatValuesMV(_projectionBlock);
double[][] doubleMVValues = transformFunction.transformToDoubleValuesMV(_projectionBlock);
String[][] stringMVValues = transformFunction.transformToStringValuesMV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
int rowLen = expectedValues[i].length;
for (int j = 0; j < rowLen; j++) {
Assert.assertEquals(intMVValues[i][j], (int) expectedValues[i][j]);
Assert.assertEquals(longMVValues[i][j], (long) expectedValues[i][j]);
Assert.assertEquals(floatMVValues[i][j], expectedValues[i][j]);
Assert.assertEquals(doubleMVValues[i][j], (double) expectedValues[i][j]);
Assert.assertEquals(stringMVValues[i][j], Float.toString(expectedValues[i][j]));
}
}
}

private void testCastTransformFunctionMV(TransformFunction transformFunction, double[][] expectedValues) {
int[][] intMVValues = transformFunction.transformToIntValuesMV(_projectionBlock);
long[][] longMVValues = transformFunction.transformToLongValuesMV(_projectionBlock);
float[][] floatMVValues = transformFunction.transformToFloatValuesMV(_projectionBlock);
double[][] doubleMVValues = transformFunction.transformToDoubleValuesMV(_projectionBlock);
String[][] stringMVValues = transformFunction.transformToStringValuesMV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
int rowLen = expectedValues[i].length;
for (int j = 0; j < rowLen; j++) {
Assert.assertEquals(intMVValues[i][j], (int) expectedValues[i][j]);
Assert.assertEquals(longMVValues[i][j], (long) expectedValues[i][j]);
Assert.assertEquals(floatMVValues[i][j], (float) expectedValues[i][j]);
Assert.assertEquals(doubleMVValues[i][j], expectedValues[i][j]);
Assert.assertEquals(stringMVValues[i][j], Double.toString(expectedValues[i][j]));
}
}
}

private void testCastTransformFunctionMV(TransformFunction transformFunction, String[][] expectedValues) {
int[][] intMVValues = transformFunction.transformToIntValuesMV(_projectionBlock);
long[][] longMVValues = transformFunction.transformToLongValuesMV(_projectionBlock);
float[][] floatMVValues = transformFunction.transformToFloatValuesMV(_projectionBlock);
double[][] doubleMVValues = transformFunction.transformToDoubleValuesMV(_projectionBlock);
String[][] stringMVValues = transformFunction.transformToStringValuesMV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
int rowLen = expectedValues[i].length;
for (int j = 0; j < rowLen; j++) {
Assert.assertEquals(intMVValues[i][j], Integer.parseInt(expectedValues[i][j]));
Assert.assertEquals(longMVValues[i][j], Long.parseLong(expectedValues[i][j]));
Assert.assertEquals(floatMVValues[i][j], Float.parseFloat(expectedValues[i][j]));
Assert.assertEquals(doubleMVValues[i][j], Double.parseDouble(expectedValues[i][j]));
Assert.assertEquals(stringMVValues[i][j], expectedValues[i][j]);
}
}
}
}
Loading