Skip to content

Commit

Permalink
Fix bug in UDTF auto-casting where arguments involving binary ops wou…
Browse files Browse the repository at this point in the history
…ld not typecheck properly. (#7305)

Signed-off-by: Misiu Godfrey <misiu.godfrey@kraken.mapd.com>
  • Loading branch information
brenocfg authored and misiugodfrey committed Aug 28, 2023
1 parent 886a67d commit d392b7b
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,24 @@ EXTENSION_NOINLINE_HOST int32_t row_copier2__cpu__(const Column<double>& input_c
return result;
}

EXTENSION_NOINLINE_HOST int32_t
row_copier_columnlist__cpu__(TableFunctionManager& mgr,
const ColumnList<double>& cols,
Column<double>& output_col) {
int32_t output_row_count = 0;
for (int i = 0; i < cols.numCols(); ++i) {
output_row_count += cols[i].size();
}
mgr.set_output_row_size(output_row_count);
int idx = 0;
for (int i = 0; i < cols.numCols(); ++i) {
for (int j = 0; j < cols[i].size(); ++j) {
output_col[idx++] = cols[i][j];
}
}
return output_row_count;
}

#endif // #ifndef __CUDACC__

EXTENSION_NOINLINE int32_t row_copier_text(const Column<TextEncodingDict>& input_col,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,7 @@ ct_to_polygon__cpu_(TableFunctionManager& mgr,
/*
UDTF: row_copier(Column<double>, RowMultiplier) -> Column<double>
UDTF: row_copier_text(Column<TextEncodingDict>, RowMultiplier) -> Column<TextEncodingDict> | input_id=args<0>
UDTF: row_copier_columnlist__cpu__(TableFunctionManager, ColumnList<double> cols) -> Column<double>
UDTF: row_copier2__cpu__(Column<double>, int) -> Column<double>, Column<double>
*/
// clang-format on
Expand All @@ -1420,6 +1421,11 @@ EXTENSION_NOINLINE_HOST int32_t row_copier2__cpu__(const Column<double>& input_c
Column<double>& output_col,
Column<double>& output_col2);

EXTENSION_NOINLINE_HOST int32_t
row_copier_columnlist__cpu__(TableFunctionManager& mgr,
const ColumnList<double>& cols,
Column<double>& output_col);

#endif // #ifndef __CUDACC__

EXTENSION_NOINLINE int32_t row_copier_text(const Column<TextEncodingDict>& input_col,
Expand Down
35 changes: 35 additions & 0 deletions Tests/TableFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3473,6 +3473,41 @@ TEST_F(TableFunctions, CalciteAutoCasting) {
dt);
ASSERT_EQ(result_qe724->rowCount(), size_t(2));
}

// tests for QE-788
// Calcite will consider result of some binary ops involving literals (such as FLOAT +
// literal) as a wider type (such as DOUBLE), even though they are still passed to the
// backend as FLOAT. The casting algorithm should cast operands of such ops to the
// wider type explicitly, so that they typecheck.
{
const auto result_qe788_column = run_multiple_agg(
"SELECT out0 from table(row_copier(CURSOR(select f + 1.0 from tf_test)));", dt);
const auto result_qe788_column_perm = run_multiple_agg(
"SELECT out0 from table(row_copier(CURSOR(select 1.0 + f from tf_test)));", dt);
const auto expected_result_qe788_column = run_multiple_agg(
"SELECT out0 from table(row_copier(CURSOR(select f + CAST(1.0 as DOUBLE) from "
"tf_test)));",
dt);
assert_equal<double>(result_qe788_column, expected_result_qe788_column);
assert_equal<double>(result_qe788_column_perm, expected_result_qe788_column);
}
{
const auto result_qe788_columnlist = run_multiple_agg(
"SELECT out0 from table(row_copier_columnlist(CURSOR(select f + 1.0, f + 2.0 "
"from tf_test)));",
dt);
const auto result_qe788_columnlist_perm = run_multiple_agg(
"SELECT out0 from table(row_copier_columnlist(CURSOR(select 1.0 + f, 2.0 + f "
"from tf_test)));",
dt);
const auto expected_result_qe788_columnlist = run_multiple_agg(
"SELECT out0 from table(row_copier_columnlist(cols => CURSOR(select f + "
"CAST(1.0 as DOUBLE), f + CAST(2.0 as DOUBLE) from tf_test)));",
dt);
assert_equal<double>(result_qe788_columnlist, expected_result_qe788_columnlist);
assert_equal<double>(result_qe788_columnlist_perm,
expected_result_qe788_columnlist);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
package com.mapd.calcite.parser;

import static com.mapd.parser.server.ExtensionFunction.*;

import static org.apache.calcite.runtime.Resources.BaseMessage;
import static org.apache.calcite.runtime.Resources.ExInst;

import com.mapd.calcite.parser.HeavyDBSqlOperatorTable.ExtTableFunction;

import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.runtime.CalciteException;
import org.apache.calcite.runtime.Resources;
import org.apache.calcite.sql.SqlBasicCall;
Expand All @@ -25,7 +22,6 @@
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlNameMatchers;
import org.apache.calcite.sql.validate.SqlValidator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import org.apache.calcite.rel.type.RelDataTypeFactoryImpl;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlBinaryOperator;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlKind;
Expand Down Expand Up @@ -128,6 +130,7 @@ private int calculateScoreForCursorOperand(SqlNode cursorOperand,
RelDataType formalRelType = toRelDataType(extType, factory);
RelDataType actualRelType = factory.createTypeWithNullability(
validator.deriveType(scope, selectOperand), true);
RelDataType widerType = getWiderTypeForTwo(formalRelType, actualRelType, false);

if (formalRelType.getSqlTypeName() == SqlTypeName.COLUMN_LIST) {
ExtArgumentType colListSubtype = getValueType(extType);
Expand All @@ -145,8 +148,7 @@ private int calculateScoreForCursorOperand(SqlNode cursorOperand,
}
}

RelDataType widerType =
getWiderTypeForTwo(actualRelType, formalSubtype, false);
widerType = getWiderTypeForTwo(actualRelType, formalSubtype, false);
if (!SqlTypeUtil.sameNamedType(actualRelType, formalSubtype)) {
if (widerType == null || widerType == actualRelType) {
// no common type, or actual type is wider than formal
Expand Down Expand Up @@ -174,13 +176,17 @@ private int calculateScoreForCursorOperand(SqlNode cursorOperand,
} else {
score += getScoreForTypes(widerType, actualRelType, true);
}
} else {
// Calcite considers the result of some binary operations as a wider type,
// even though they're not passed to the backend as such (FLOAT + literal
// == DOUBLE, for instance). We penalize these so that literal operands
// are casted regardless. (See QE-788)
score += shouldCoerceBinOpOperand(curOperand, widerType, scope) ? 100 : 0;
}
colListSize++;
}
iActual += colListSize - 1;
} else if (actualRelType != formalRelType) {
RelDataType widerType =
getWiderTypeForTwo(formalRelType, actualRelType, false);
if (widerType == null) {
// no common wider type
return -1;
Expand All @@ -197,6 +203,12 @@ private int calculateScoreForCursorOperand(SqlNode cursorOperand,
} else {
score += getScoreForTypes(widerType, actualRelType, true);
}
} else {
// Calcite considers the result of some binary operations as a wider type,
// even though they're not passed to the backend as such (FLOAT + literal
// == DOUBLE, for instance). We penalize these so that literal operands
// are casted regardless. (See QE-788)
score += shouldCoerceBinOpOperand(selectOperand, widerType, scope) ? 100 : 0;
}
}

Expand Down Expand Up @@ -306,6 +318,7 @@ private void coerceCursorType(SqlValidatorScope scope,
ExtArgumentType extType = formalFieldTypes.get(iFormal);
RelDataType formalRelType = toRelDataType(extType, factory);
RelDataType actualRelType = validator.deriveType(scope, selectOperand);
RelDataType widerType = getWiderTypeForTwo(formalRelType, actualRelType, false);

if (isColumnArrayType(extType) || isColumnListArrayType(extType)) {
// Arrays can't be casted so don't bother trying
Expand All @@ -316,8 +329,7 @@ private void coerceCursorType(SqlValidatorScope scope,
if (formalRelType.getSqlTypeName() == SqlTypeName.COLUMN_LIST) {
ExtArgumentType colListSubtype = getValueType(extType);
RelDataType formalSubtype = toRelDataType(colListSubtype, factory);
RelDataType widerType =
getWiderTypeForTwo(actualRelType, formalSubtype, false);
widerType = getWiderTypeForTwo(actualRelType, formalSubtype, false);

int colListSize = 0;
int numFormalArgumentsLeft = (formalFieldTypes.size() - 1) - iFormal;
Expand All @@ -336,15 +348,16 @@ private void coerceCursorType(SqlValidatorScope scope,
iActual + colListSize,
widerType);
}
} else if (shouldCoerceBinOpOperand(curOperand, widerType, scope)) {
coerceBinOpOperand((SqlBasicCall) curOperand, widerType, scope);
}

updateValidatedType(
newValidatedTypeList, selectNode, iActual + colListSize);
colListSize++;
}
iActual += colListSize - 1;
} else if (actualRelType != formalRelType) {
RelDataType widerType =
getWiderTypeForTwo(formalRelType, actualRelType, false);
if (!SqlTypeUtil.isTimestamp(widerType)
&& SqlTypeUtil.sameNamedType(actualRelType, formalRelType)) {
updateValidatedType(newValidatedTypeList, selectNode, iActual);
Expand All @@ -355,6 +368,9 @@ private void coerceCursorType(SqlValidatorScope scope,
}
updateValidatedType(newValidatedTypeList, selectNode, iActual);
} else {
if (shouldCoerceBinOpOperand(selectOperand, widerType, scope)) {
coerceBinOpOperand((SqlBasicCall) selectOperand, widerType, scope);
}
// keep old validated type for argument that was not coerced
updateValidatedType(newValidatedTypeList, selectNode, iActual);
}
Expand Down Expand Up @@ -421,7 +437,7 @@ private void updateValidatedType(
RelDataType newType = validator.getValidatedNodeType(operand);
if (operand instanceof SqlCall) {
SqlCall asCall = (SqlCall) operand;
if (asCall.getOperator().kind == SqlKind.AS) {
if (asCall.getOperator().getKind() == SqlKind.AS) {
newType = validator.getValidatedNodeType(asCall.operand(0));
}
}
Expand Down Expand Up @@ -467,4 +483,53 @@ else if (SqlTypeUtil.inCharFamily(targetType)) {
}
}
}

/**
* Coerces operands of a binary operator to the given @targetType.
*/
private void coerceBinOpOperand(
SqlBasicCall binOp, RelDataType targetType, SqlValidatorScope scope) {
if (binOp.getKind() == SqlKind.AS) {
binOp = binOp.operand(0);
}
coerceOperandType(scope, binOp, 0, targetType);
coerceOperandType(scope, binOp, 1, targetType);
}

/**
* Determines if a binary operator's operands need to be coerced explicitly. Calcite
* considers the output of some binary operators involving literals as a wider type. As
* a consequence, some operations such as FLOAT + <float_literal> are type-inferenced to
* DOUBLE, and will typecheck against UDTFs that accept DOUBLE columns. However, these
* parameters are still sent to the backend as FLOATs. This method identifies these
* occurrences, so that they can be casted explicitly.
*/
private boolean shouldCoerceBinOpOperand(
SqlNode op, RelDataType targetType, SqlValidatorScope scope) {
if (op instanceof SqlBasicCall) {
SqlBasicCall asCall = (SqlBasicCall) op;
if (asCall.getOperator().getKind() == SqlKind.AS) {
SqlNode op2 = asCall.operand(0);
if (op2 instanceof SqlBasicCall) {
asCall = (SqlBasicCall) op2;
} else {
return false;
}
}
if (asCall.getOperator() instanceof SqlBinaryOperator) {
SqlNode lhs = asCall.operand(0);
SqlNode rhs = asCall.operand(1);
RelDataType lhsType = validator.deriveType(scope, lhs);
RelDataType rhsType = validator.deriveType(scope, rhs);
// if neither operand is already the wider type, and at least one is a literal,
// depending on precedence the result might falsely typecheck as the wider type
if (lhsType != targetType && rhsType != targetType
&& (lhs.getKind() == SqlKind.LITERAL
|| rhs.getKind() == SqlKind.LITERAL)) {
return true;
}
}
}
return false;
}
}

0 comments on commit d392b7b

Please sign in to comment.