Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -216,5 +216,11 @@ public static Object clpDecode(String logtypeFieldName, String dictVarsFieldName
String defaultValue) {
throw new UnsupportedOperationException("Placeholder scalar function, should not reach here");
}

@ScalarFunction(names = {"arrayToMV", "array_to_mv"},
isPlaceholder = true)
public static String arrayToMV(Object multiValue) {
throw new UnsupportedOperationException("Placeholder scalar function, should not reach here");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ public enum TransformFunctionType {
// date type conversion functions
CAST("cast"),

// object type
ARRAY_TO_MV("arrayToMV",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this is confusing name. the data type is already MV running an ARRAY_TO_MV is a bit weird.
should we named it USE_AS_MV? and we can say that MV columns are by default USE_AS_ARRAY

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

USE_AS_MV might be confusing, since we repurposed MV as ARRAY in v2, ARRAY_TO_MV might be more explicit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah we need to figure out a proper way b/c out put of a select is an array, but the table config / schema will still call this as MV column as convension. both have some confusion, but as long as the document is proper we should be good.

ReturnTypes.cascade(opBinding -> positionalComponentReturnType(opBinding, 0), SqlTypeTransforms.FORCE_NULLABLE),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's first create a component return type registry on @ScalarFunction so we dont have to modify the transform function side.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually ignored previous comment, did a bit of research and it seems like registering TransformFunctionType without an actual impl is better than having to parse scalar function annotation, which doesn't really allow anything other than primitives

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will have a different pr for this.

OperandTypes.family(SqlTypeFamily.ARRAY), "array_to_mv"),

// string functions
JSONEXTRACTSCALAR("jsonExtractScalar",
ReturnTypes.cascade(opBinding -> positionalReturnTypeInferenceFromStringLiteral(opBinding, 2,
Expand Down Expand Up @@ -280,6 +285,13 @@ private static RelDataType positionalReturnTypeInferenceFromStringLiteral(SqlOpe
return opBinding.getTypeFactory().createSqlType(defaultSqlType);
}

private static RelDataType positionalComponentReturnType(SqlOperatorBinding opBinding, int pos) {
if (opBinding.getOperandCount() > pos) {
return opBinding.getOperandType(pos).getComponentType();
}
throw new IllegalArgumentException("Invalid number of arguments for function " + opBinding.getOperator().getName());
}

private static RelDataType dateTimeConverterReturnTypeInference(SqlOperatorBinding opBinding) {
int outputFormatPos = 2;
if (opBinding.getOperandCount() > outputFormatPos
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ public static TransformFunction get(ExpressionContext expression, Map<String, Co
return new IdentifierTransformFunction(columnName, columnContextMap.get(columnName));
case LITERAL:
return queryContext.getOrComputeSharedValue(LiteralTransformFunction.class, expression.getLiteral(),
LiteralTransformFunction::new);
LiteralTransformFunction::new);
default:
throw new IllegalStateException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ public class QueryGenerator {
private final List<PredicateGenerator> _multistageSingleValuePredicateGenerators =
Arrays.asList(new SingleValueComparisonPredicateGenerator(), new SingleValueInPredicateGenerator(),
new SingleValueBetweenPredicateGenerator());
// TODO: add MultiValueBetweenPredicateGenerator back once the BETWEEEN AND operator is supported in multistage engine
private final List<PredicateGenerator> _multiValuePredicateGenerators =
Arrays.asList(new MultiValueComparisonPredicateGenerator(), new MultiValueInPredicateGenerator(),
new MultiValueBetweenPredicateGenerator());
Arrays.asList(new MultiValueComparisonPredicateGenerator(), new MultiValueInPredicateGenerator());

private final String _pinotTableName;
private final String _h2TableName;
Expand Down Expand Up @@ -351,10 +351,12 @@ private PredicateQueryFragment generatePredicate() {
if (!_columnToValueList.get(columnName).isEmpty()) {
if (!_multiValueColumnMaxNumElements.containsKey(columnName)) {
// Single-value column.
predicates.add(pickRandom(getSingleValuePredicateGenerators()).generatePredicate(columnName));
predicates.add(pickRandom(getSingleValuePredicateGenerators()).generatePredicate(columnName,
_useMultistageEngine));
} else if (!_skipMultiValuePredicates) {
// Multi-value column.
predicates.add(pickRandom(_multiValuePredicateGenerators).generatePredicate(columnName));
predicates.add(
pickRandom(_multiValuePredicateGenerators).generatePredicate(columnName, _useMultistageEngine));
}
}
}
Expand Down Expand Up @@ -407,10 +409,11 @@ private interface PredicateGenerator {
/**
* Generate a predicate query fragment on a column.
*
* @param columnName column name.
* @param columnName column name.
* @param useMultistageEngine
* @return generated predicate query fragment.
*/
QueryFragment generatePredicate(String columnName);
QueryFragment generatePredicate(String columnName, boolean useMultistageEngine);
}

/**
Expand Down Expand Up @@ -485,16 +488,47 @@ public AggregationQuery(List<String> aggregateColumnsAndFunctions, PredicateQuer

@Override
public String generatePinotQuery() {
List<String> pinotAggregateColumnAndFunctions =
(_useMultistageEngine && !_skipMultiValuePredicates) ? generatePinotMultistageQuery()
: _aggregateColumnsAndFunctions;
if (_groupColumns.isEmpty()) {
return joinWithSpaces("SELECT", StringUtils.join(_aggregateColumnsAndFunctions, ", "), "FROM", _pinotTableName,
_predicate.generatePinotQuery());
return joinWithSpaces("SELECT", StringUtils.join(pinotAggregateColumnAndFunctions, ", "), "FROM",
_pinotTableName, _predicate.generatePinotQuery());
} else {
return joinWithSpaces("SELECT", StringUtils.join(_aggregateColumnsAndFunctions, ", "), "FROM", _pinotTableName,
_predicate.generatePinotQuery(), "GROUP BY", StringUtils.join(_groupColumns, ", "),
return joinWithSpaces("SELECT", StringUtils.join(pinotAggregateColumnAndFunctions, ", "), "FROM",
_pinotTableName, _predicate.generatePinotQuery(), "GROUP BY", StringUtils.join(_groupColumns, ", "),
_havingPredicate.generatePinotQuery(), _limit.generatePinotQuery());
}
}

public List<String> generatePinotMultistageQuery() {
List<String> pinotAggregateColumnAndFunctions = new ArrayList<>();
for (String aggregateColumnAndFunction : _aggregateColumnsAndFunctions) {
String pinotAggregateFunction = aggregateColumnAndFunction;
String pinotAggregateColumnAndFunction = pinotAggregateFunction;
if (!pinotAggregateFunction.equals("COUNT(*)")) {
pinotAggregateFunction = pinotAggregateFunction.replace("(", "(`").replace(")", "`)");
}
if (!pinotAggregateFunction.contains("(")) {
pinotAggregateFunction = String.format("`%s`", pinotAggregateFunction);
}
if (AGGREGATION_FUNCTIONS.contains(pinotAggregateFunction.substring(0, 3))) {
// For multistage query, we need to explicit hoist the data type to avoid overflow.
String aggFunctionName = pinotAggregateFunction.substring(0, 3);
String replacedPinotAggregationFunction =
pinotAggregateFunction.replace(aggFunctionName + "(", aggFunctionName + "(CAST(");
if ("SUM".equalsIgnoreCase(aggFunctionName)) {
pinotAggregateColumnAndFunction = replacedPinotAggregationFunction.replace(")", " AS BIGINT))");
}
if ("AVG".equalsIgnoreCase(aggFunctionName)) {
pinotAggregateColumnAndFunction = replacedPinotAggregationFunction.replace(")", " AS DOUBLE))");
}
}
pinotAggregateColumnAndFunctions.add(pinotAggregateColumnAndFunction);
}
return pinotAggregateColumnAndFunctions;
}

@Override
public String generateH2Query() {
List<String> h2AggregateColumnAndFunctions = new ArrayList<>();
Expand Down Expand Up @@ -923,7 +957,7 @@ private String generateRandomValue(boolean generateInt) {
private class SingleValueComparisonPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
String columnValue = pickRandom(_columnToValueList.get(columnName));
String comparisonOperator = pickRandom(COMPARISON_OPERATORS);
return new StringQueryFragment(joinWithSpaces(columnName, comparisonOperator, columnValue),
Expand All @@ -937,7 +971,7 @@ public QueryFragment generatePredicate(String columnName) {
private class SingleValueInPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
List<String> columnValues = _columnToValueList.get(columnName);

int numValues = Math.min(RANDOM.nextInt(MAX_NUM_IN_CLAUSE_VALUES) + 1, columnValues.size());
Expand All @@ -964,7 +998,7 @@ public QueryFragment generatePredicate(String columnName) {
private class SingleValueBetweenPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
List<String> columnValues = _columnToValueList.get(columnName);
String leftValue = pickRandom(columnValues);
String rightValue = pickRandom(columnValues);
Expand All @@ -981,7 +1015,7 @@ private class SingleValueRegexPredicateGenerator implements PredicateGenerator {
Random _random = new Random();

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
List<String> columnValues = _columnToValueList.get(columnName);
String value = pickRandom(columnValues);
// do regex only for string type
Expand All @@ -1008,7 +1042,7 @@ public QueryFragment generatePredicate(String columnName) {
private class MultiValueComparisonPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
String columnValue = pickRandom(_columnToValueList.get(columnName));
String comparisonOperator = pickRandom(COMPARISON_OPERATORS);

Expand All @@ -1024,7 +1058,8 @@ public QueryFragment generatePredicate(String columnName) {
joinWithSpaces(String.format("%s[%d]", columnName, i), comparisonOperator, columnValue));
}

return new StringQueryFragment(joinWithSpaces(columnName, comparisonOperator, columnValue),
return new StringQueryFragment(
joinWithSpaces(generateMultiValueColumn(columnName, useMultistageEngine), comparisonOperator, columnValue),
generateH2QueryConditionPredicate(h2ComparisonClauses));
}
}
Expand All @@ -1036,7 +1071,7 @@ public QueryFragment generatePredicate(String columnName) {
private class MultiValueInPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
List<String> columnValues = _columnToValueList.get(columnName);

int numValues = Math.min(RANDOM.nextInt(MAX_NUM_IN_CLAUSE_VALUES) + 1, columnValues.size());
Expand All @@ -1052,7 +1087,8 @@ public QueryFragment generatePredicate(String columnName) {
h2InClauses.add(String.format("%s[%d] IN (%s)", columnName, i, inValues));
}

return new StringQueryFragment(String.format("%s IN (%s)", columnName, inValues),
return new StringQueryFragment(
String.format("%s IN (%s)", generateMultiValueColumn(columnName, useMultistageEngine), inValues),
generateH2QueryConditionPredicate(h2InClauses));
}
}
Expand All @@ -1063,7 +1099,7 @@ public QueryFragment generatePredicate(String columnName) {
private class MultiValueBetweenPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
List<String> columnValues = _columnToValueList.get(columnName);
String leftValue = pickRandom(columnValues);
String rightValue = pickRandom(columnValues);
Expand All @@ -1074,13 +1110,24 @@ public QueryFragment generatePredicate(String columnName) {
h2ComparisonClauses.add(String.format("%s[%d] BETWEEN %s AND %s", columnName, i, leftValue, rightValue));
}

return new StringQueryFragment(String.format("%s BETWEEN %s AND %s", columnName, leftValue, rightValue),
return new StringQueryFragment(
String.format("%s BETWEEN %s AND %s", generateMultiValueColumn(columnName, useMultistageEngine), leftValue,
rightValue),
generateH2QueryConditionPredicate(h2ComparisonClauses));
}
}

private String generateMultiValueColumn(String columnName, boolean useMultistageEngine) {
if (useMultistageEngine) {
return String.format("ARRAY_TO_MV(%s)", columnName);
}
return columnName;
}

private static String generateH2QueryConditionPredicate(List<String> conditionList) {
return generateH2QueryConditionPredicate(conditionList, " OR ");
}

private static String generateH2QueryConditionPredicate(List<String> conditionList, String separator) {
return String.format("( %s )", StringUtils.join(conditionList, separator));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,25 +101,18 @@ protected void cleanupTestTableDataManager(String tableNameWithType) {
}, 600_000L, "Failed to delete table data managers");
}

/**
* Test features supported in V2 Multi-stage Engine.
* - Some V1 features will not be supported.
* - Some V1 features will be added as V2 engine feature development progresses.
* @throws Exception
*/
public void testHardcodedQueriesMultiStage()
throws Exception {
testHardcodedQueriesCommon();
}

/**
* Test hard-coded queries.
* @throws Exception
*/
public void testHardcodedQueries()
throws Exception {
testHardcodedQueriesCommon();
testHardCodedQueriesV1();
if (useMultiStageQueryEngine()) {
testHardcodedQueriesV2();
} else {
testHardCodedQueriesV1();
}
}

/**
Expand Down Expand Up @@ -282,6 +275,29 @@ private void testHardcodedQueriesCommon()
testQuery(query, h2Query);
}

private void testHardcodedQueriesV2()
throws Exception {
String query;
String h2Query;

query =
"SELECT DistanceGroup FROM mytable WHERE \"Month\" BETWEEN 1 AND 1 AND arrayToMV(DivAirportSeqIDs) IN "
+ "(1078102, 1142303, 1530402, 1172102, 1291503) OR SecurityDelay IN (1, 0, 14, -9999) LIMIT 10";
h2Query =
"SELECT DistanceGroup FROM mytable WHERE `Month` BETWEEN 1 AND 1 AND (DivAirportSeqIDs[1] IN (1078102, "
+ "1142303, 1530402, 1172102, 1291503) OR DivAirportSeqIDs[2] IN (1078102, 1142303, 1530402, 1172102, "
+ "1291503) OR DivAirportSeqIDs[3] IN (1078102, 1142303, 1530402, 1172102, 1291503) OR "
+ "DivAirportSeqIDs[4] IN (1078102, 1142303, 1530402, 1172102, 1291503) OR DivAirportSeqIDs[5] IN "
+ "(1078102, 1142303, 1530402, 1172102, 1291503)) OR SecurityDelay IN (1, 0, 14, -9999) LIMIT 10000";
testQuery(query, h2Query);

query = "SELECT MIN(ArrDelayMinutes), AVG(CAST(DestCityMarketID AS DOUBLE)) FROM mytable WHERE DivArrDelay < 196";
h2Query =
"SELECT MIN(CAST(`ArrDelayMinutes` AS DOUBLE)), AVG(CAST(`DestCityMarketID` AS DOUBLE)) FROM mytable WHERE "
+ "`DivArrDelay` < 196";
testQuery(query, h2Query);
}

private void testHardCodedQueriesV1()
throws Exception {
String query;
Expand All @@ -295,17 +311,6 @@ private void testHardCodedQueriesV1()
"SELECT CAST(CAST(ArrTime AS varchar) AS LONG) FROM mytable WHERE DaysSinceEpoch <> 16312 AND Carrier = 'DL' "
+ "ORDER BY ArrTime DESC";
testQuery(query);
// TODO: move to common when multistage support MV columns
query =
"SELECT DistanceGroup FROM mytable WHERE \"Month\" BETWEEN 1 AND 1 AND DivAirportSeqIDs IN (1078102, 1142303,"
+ " 1530402, 1172102, 1291503) OR SecurityDelay IN (1, 0, 14, -9999) LIMIT 10";
h2Query =
"SELECT DistanceGroup FROM mytable WHERE `Month` BETWEEN 1 AND 1 AND (DivAirportSeqIDs[1] IN (1078102, "
+ "1142303, 1530402, 1172102, 1291503) OR DivAirportSeqIDs[2] IN (1078102, 1142303, 1530402, 1172102, "
+ "1291503) OR DivAirportSeqIDs[3] IN (1078102, 1142303, 1530402, 1172102, 1291503) OR "
+ "DivAirportSeqIDs[4] IN (1078102, 1142303, 1530402, 1172102, 1291503) OR DivAirportSeqIDs[5] IN "
+ "(1078102, 1142303, 1530402, 1172102, 1291503)) OR SecurityDelay IN (1, 0, 14, -9999) LIMIT 10000";
testQuery(query, h2Query);

// Non-Standard SQL syntax:
// IN_ID_SET
Expand Down Expand Up @@ -472,8 +477,13 @@ protected void testGeneratedQueries(boolean withMultiValues, boolean useMultista
for (int i = 0; i < numQueriesToGenerate; i++) {
QueryGenerator.Query query = queryGenerator.generateQuery();
if (useMultistageEngine) {
// multistage engine follows standard SQL thus should use H2 query string for testing.
testQuery(query.generateH2Query().replace("`", "\""), query.generateH2Query());
if (withMultiValues) {
// For multistage query with MV columns, we need to use Pinot query string for testing.
testQuery(query.generatePinotQuery().replace("`", "\""), query.generateH2Query());
} else {
// multistage engine follows standard SQL thus should use H2 query string for testing.
testQuery(query.generateH2Query().replace("`", "\""), query.generateH2Query());
}
} else {
testQuery(query.generatePinotQuery(), query.generateH2Query());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.util.TestUtils;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -95,17 +96,17 @@ protected boolean useMultiStageQueryEngine() {

@Test
@Override
public void testHardcodedQueriesMultiStage()
public void testHardcodedQueries()
throws Exception {
super.testHardcodedQueriesMultiStage();
super.testHardcodedQueries();
}

@Test
@Override
public void testGeneratedQueries()
throws Exception {
// test multistage engine, currently we don't support MV columns.
super.testGeneratedQueries(false, true);
super.testGeneratedQueries(true, true);
}

@Test
Expand Down Expand Up @@ -485,6 +486,25 @@ public void testLiteralOnlyFunc()
assertEquals(results.get(10).asText(), "hello!");
}

@Test
public void testMultiValueColumnGroupBy()
throws Exception {
String pinotQuery = "SELECT count(*), arrayToMV(RandomAirports) FROM mytable "
+ "GROUP BY arrayToMV(RandomAirports)";
JsonNode jsonNode = postQuery(pinotQuery);
Assert.assertEquals(jsonNode.get("resultTable").get("rows").size(), 154);
}

@Test
public void testMultiValueColumnGroupByOrderBy()
throws Exception {
String pinotQuery = "SELECT count(*), arrayToMV(RandomAirports) FROM mytable "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will it work if I run

SELECT count(*), arrayToMV(RandomAirports) 
FROM mytable 
WHERE Dest IN (SELECT Dest FROM myTable GROUP BY Dest HAVING count(*) > 10)

(later when we implemented the scalar function wrapper)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need a group by arrayToMV(RandomAirports) ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works at leaf stage not intermediate stage.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok sounds good

+ "GROUP BY arrayToMV(RandomAirports) "
+ "ORDER BY arrayToMV(RandomAirports) DESC";
JsonNode jsonNode = postQuery(pinotQuery);
Assert.assertEquals(jsonNode.get("resultTable").get("rows").size(), 154);
}

@AfterClass
public void tearDown()
throws Exception {
Expand Down
Loading