diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java index 9720300ab29e14..468fd07284701a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java @@ -26,6 +26,7 @@ import org.apache.doris.catalog.DatabaseIf; import org.apache.doris.catalog.FunctionSet; import org.apache.doris.catalog.OlapTable; +import org.apache.doris.catalog.PrimitiveType; import org.apache.doris.catalog.Table; import org.apache.doris.catalog.TableIf; import org.apache.doris.catalog.TableIf.TableType; @@ -601,8 +602,8 @@ public void analyze(Analyzer analyzer) throws UserException { } } - whereClauseRewrite(); if (whereClause != null) { + whereClauseRewrite(); if (checkGroupingFn(whereClause)) { throw new AnalysisException("grouping operations are not allowed in WHERE."); } @@ -851,6 +852,9 @@ private void whereClauseRewrite() { } else { whereClause = new BoolLiteral(true); } + } else if (!whereClause.getType().isBoolean()) { + whereClause = new CastExpr(TypeDef.create(PrimitiveType.BOOLEAN), whereClause); + whereClause.setType(Type.BOOLEAN); } } @@ -1263,6 +1267,9 @@ private void analyzeAggregation(Analyzer analyzer) throws AnalysisException { havingClauseAfterAnalyzed = havingClause.substitute(aliasSMap, analyzer, false); } havingClauseAfterAnalyzed = rewriteQueryExprByMvColumnExpr(havingClauseAfterAnalyzed, analyzer); + if (!havingClauseAfterAnalyzed.getType().isBoolean()) { + havingClauseAfterAnalyzed = havingClauseAfterAnalyzed.castTo(Type.BOOLEAN); + } havingClauseAfterAnalyzed.checkReturnsBool("HAVING clause", true); if (groupingInfo != null) { groupingInfo.substituteGroupingFn(Arrays.asList(havingClauseAfterAnalyzed), analyzer); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/TableRef.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/TableRef.java index a99e4cf9597dca..fcfbd39b445a3f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/TableRef.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/TableRef.java @@ -22,6 +22,7 @@ import org.apache.doris.catalog.Env; import org.apache.doris.catalog.TableIf; +import org.apache.doris.catalog.Type; import org.apache.doris.catalog.external.HMSExternalTable; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.ErrorCode; @@ -651,6 +652,9 @@ public void analyzeJoin(Analyzer analyzer) throws AnalysisException { analyzer.setVisibleSemiJoinedTuple(semiJoinedTupleId); onClause.analyze(analyzer); analyzer.setVisibleSemiJoinedTuple(null); + if (!onClause.getType().isBoolean()) { + onClause = onClause.castTo(Type.BOOLEAN); + } onClause.checkReturnsBool("ON clause", true); if (onClause.contains(Expr.isAggregatePredicate())) { throw new AnalysisException( diff --git a/regression-test/suites/query_p0/cast/test_cast.groovy b/regression-test/suites/query_p0/cast/test_cast.groovy index bfe4a87989c3a7..59d86eb80e2f28 100644 --- a/regression-test/suites/query_p0/cast/test_cast.groovy +++ b/regression-test/suites/query_p0/cast/test_cast.groovy @@ -31,4 +31,60 @@ suite('test_cast') { sql "select cast(${datetime} as int), cast(${datetime} as bigint), cast(${datetime} as float), cast(${datetime} as double)" result([[869930357, 20200101123445l, ((float) 20200101123445l), ((double) 20200101123445l)]]) } -} \ No newline at end of file + + def tbl = "test_cast" + + sql """ DROP TABLE IF EXISTS ${tbl}""" + sql """ + CREATE TABLE IF NOT EXISTS ${tbl} ( + `k0` int + ) + DISTRIBUTED BY HASH(`k0`) BUCKETS 5 properties("replication_num" = "1") + """ + sql """ INSERT INTO ${tbl} VALUES (101);""" + + test { + sql "select * from ${tbl} where case when k0 = 101 then 1 else 0 end" + result([[101]]) + } + + test { + sql "select * from ${tbl} where case when k0 = 101 then 12 else 0 end" + result([[101]]) + } + + test { + sql "select * from ${tbl} where case when k0 = 101 then -12 else 0 end" + result([[101]]) + } + + test { + sql "select * from ${tbl} where case when k0 = 101 then 0 else 1 end" + result([]) + } + + test { + sql "select * from ${tbl} where case when k0 != 101 then 0 else 1 end" + result([[101]]) + } + + test { + sql "select * from ${tbl} where case when k0 = 101 then '1' else 0 end" + result([[101]]) + } + + test { + sql "select * from ${tbl} where case when k0 = 101 then '12' else 0 end" + result([]) + } + + test { + sql "select * from ${tbl} where case when k0 = 101 then 'false' else 0 end" + result([]) + } + + test { + sql "select * from ${tbl} where case when k0 = 101 then 'true' else 1 end" + result([[101]]) + } +}