Skip to content

Commit fe1f456

Browse files
Ngone51 authored and cloud-fan committed
[SPARK-29837][SQL] PostgreSQL dialect: cast to boolean
### What changes were proposed in this pull request? Make SparkSQL's `cast to boolean` behavior consistent with PostgreSQL when spark.sql.dialect is configured as PostgreSQL. ### Why are the changes needed? By default, SparkSQL and PostgreSQL differ considerably in their cast behavior between types. We should make SparkSQL's cast behavior consistent with PostgreSQL when `spark.sql.dialect` is configured as PostgreSQL. ### Does this PR introduce any user-facing change? Yes. If a user switches to the PostgreSQL dialect now, they will * get an exception if they input an invalid string, e.g. "erut", whereas they got `null` before; * get an exception if they input `TimestampType`, `DateType`, `LongType`, `ShortType`, `ByteType`, `DecimalType`, `DoubleType`, `FloatType` values, whereas they got a `true` or `false` result before. Here is the evidence from PostgreSQL for those unsupported types: timestamp: ``` postgres=# select cast(cast('2019-11-11' as timestamp) as boolean); ERROR: cannot cast type timestamp without time zone to boolean ``` date: ``` postgres=# select cast(cast('2019-11-11' as date) as boolean); ERROR: cannot cast type date to boolean ``` bigint: ``` postgres=# select cast(cast('20191111' as bigint) as boolean); ERROR: cannot cast type bigint to boolean ``` smallint: ``` postgres=# select cast(cast(2019 as smallint) as boolean); ERROR: cannot cast type smallint to boolean ``` bytea: ``` postgres=# select cast(cast('2019' as bytea) as boolean); ERROR: cannot cast type bytea to boolean ``` decimal: ``` postgres=# select cast(cast('2019' as decimal) as boolean); ERROR: cannot cast type numeric to boolean ``` float: ``` postgres=# select cast(cast('2019' as float) as boolean); ERROR: cannot cast type double precision to boolean ``` ### How was this patch tested? Added unit tests and tested manually. Closes #26463 from Ngone51/dev-postgre-cast2bool. Authored-by: wuyi <ngone_5451@163.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 39596b9 commit fe1f456

File tree

7 files changed

+175
-146
lines changed

7 files changed

+175
-146
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,27 @@ package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.internal.Logging
2121
import org.apache.spark.sql.catalyst.expressions.Cast
22-
import org.apache.spark.sql.catalyst.expressions.postgreSQL.PostgreCastStringToBoolean
22+
import org.apache.spark.sql.catalyst.expressions.postgreSQL.PostgreCastToBoolean
2323
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2424
import org.apache.spark.sql.catalyst.rules.Rule
2525
import org.apache.spark.sql.internal.SQLConf
2626
import org.apache.spark.sql.types.{BooleanType, StringType}
2727

2828
object PostgreSQLDialect {
2929
val postgreSQLDialectRules: List[Rule[LogicalPlan]] =
30-
CastStringToBoolean ::
30+
CastToBoolean ::
3131
Nil
3232

33-
object CastStringToBoolean extends Rule[LogicalPlan] with Logging {
33+
object CastToBoolean extends Rule[LogicalPlan] with Logging {
3434
override def apply(plan: LogicalPlan): LogicalPlan = {
3535
// The SQL configuration `spark.sql.dialect` can be changed in runtime.
3636
// To make sure the configuration is effective, we have to check it during rule execution.
3737
val conf = SQLConf.get
3838
if (conf.usePostgreSQLDialect) {
3939
plan.transformExpressions {
40-
case Cast(child, dataType, _)
41-
if dataType == BooleanType && child.dataType == StringType =>
42-
PostgreCastStringToBoolean(child)
40+
case Cast(child, dataType, timeZoneId)
41+
if child.dataType != BooleanType && dataType == BooleanType =>
42+
PostgreCastToBoolean(child, timeZoneId)
4343
}
4444
} else {
4545
plan

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
274274
private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType)
275275

276276
// [[func]] assumes the input is no longer null because eval already does the null check.
277-
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
277+
@inline protected def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
278278

279279
private lazy val dateFormatter = DateFormatter(zoneId)
280280
private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId)
@@ -377,7 +377,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
377377
}
378378

379379
// UDFToBoolean
380-
private[this] def castToBoolean(from: DataType): Any => Any = from match {
380+
protected[this] def castToBoolean(from: DataType): Any => Any = from match {
381381
case StringType =>
382382
buildCast[UTF8String](_, s => {
383383
if (StringUtils.isTrueString(s)) {
@@ -782,7 +782,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
782782
}
783783
}
784784

785-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
785+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
786786
val eval = child.genCode(ctx)
787787
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
788788

@@ -792,7 +792,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
792792

793793
// The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`
794794
// in parameter list, because the returned code will be put in null safe evaluation region.
795-
private[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block
795+
protected type CastFunction = (ExprValue, ExprValue, ExprValue) => Block
796796

797797
private[this] def nullSafeCastFunction(
798798
from: DataType,
@@ -1234,7 +1234,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
12341234
private[this] def timestampToDoubleCode(ts: ExprValue): Block =
12351235
code"$ts / (double)$MICROS_PER_SECOND"
12361236

1237-
private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
1237+
protected[this] def castToBooleanCode(from: DataType): CastFunction = from match {
12381238
case StringType =>
12391239
val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
12401240
(c, evPrim, evNull) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala

Lines changed: 0 additions & 80 deletions
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.catalyst.expressions.postgreSQL
18+
19+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
20+
import org.apache.spark.sql.catalyst.expressions.{CastBase, Expression, TimeZoneAwareExpression}
21+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
22+
import org.apache.spark.sql.catalyst.util.postgreSQL.StringUtils
23+
import org.apache.spark.sql.types._
24+
import org.apache.spark.unsafe.types.UTF8String
25+
26+
/**
 * Cast of the child expression to [[BooleanType]] used under the PostgreSQL dialect.
 *
 * Type checking accepts only string, integer and null inputs; any other input type is
 * reported as a type-check failure. String inputs are trimmed and lower-cased before being
 * matched by the PostgreSQL `StringUtils` helpers, and a string that is neither a true-string
 * nor a false-string raises an `IllegalArgumentException` instead of producing null.
 *
 * @param child the expression whose value is cast to boolean
 * @param timeZoneId optional time zone id, replaced through `withTimeZone`
 */
case class PostgreCastToBoolean(child: Expression, timeZoneId: Option[String])
  extends CastBase {

  // ANSI mode has no meaning for this dialect-specific cast, so any access fails loudly.
  override protected def ansiEnabled =
    throw new UnsupportedOperationException("PostgreSQL dialect doesn't support ansi mode")

  /** Returns a copy of this expression carrying the given time zone id. */
  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
    copy(timeZoneId = Option(timeZoneId))

  /**
   * Accepts `StringType`, `IntegerType` and `NullType` children; every other child type
   * yields a `TypeCheckFailure` whose message names the rejected type.
   */
  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case StringType | IntegerType | NullType =>
      TypeCheckResult.TypeCheckSuccess
    case _ =>
      TypeCheckResult.TypeCheckFailure(s"cannot cast type ${child.dataType} to boolean")
  }

  /**
   * Interpreted cast function for the given source type.
   *
   * Only `StringType` and `IntegerType` are matched here.
   * NOTE(review): other types are presumably filtered out by `checkInputDataTypes` (or handled
   * by the base class before this method is reached) -- confirm against `CastBase`.
   *
   * @throws IllegalArgumentException from the returned function when a string input is
   *                                  neither a true-string nor a false-string
   */
  override def castToBoolean(from: DataType): Any => Any = from match {
    case StringType =>
      buildCast[UTF8String](_, str => {
        // Matching is done on the trimmed, lower-cased form of the input.
        val s = str.trim().toLowerCase()
        if (StringUtils.isTrueString(s)) {
          true
        } else if (StringUtils.isFalseString(s)) {
          false
        } else {
          throw new IllegalArgumentException(s"invalid input syntax for type boolean: $s")
        }
      })
    case IntegerType =>
      super.castToBoolean(from)
  }

  /**
   * Code-generating counterpart of `castToBoolean` for the given source type.
   *
   * The emitted Java mirrors the interpreted path: the input is trimmed and lower-cased,
   * matched with the `StringUtils` helpers, and an `IllegalArgumentException` is thrown for
   * anything else.
   */
  override def castToBooleanCode(from: DataType): CastFunction = from match {
    case StringType =>
      val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
      (c, evPrim, evNull) =>
        // The offending value is concatenated at runtime. Interpolating `$c` inside the Java
        // string literal would splice the generated variable *name* into the message rather
        // than its value. The normalized form is used so that the message matches the one
        // produced by the interpreted path.
        code"""
          if ($stringUtils.isTrueString($c.trim().toLowerCase())) {
            $evPrim = true;
          } else if ($stringUtils.isFalseString($c.trim().toLowerCase())) {
            $evPrim = false;
          } else {
            throw new IllegalArgumentException(
              "invalid input syntax for type boolean: " + $c.trim().toLowerCase());
          }
        """

    case IntegerType =>
      super.castToBooleanCode(from)
  }

  override def dataType: DataType = BooleanType

  // The result can be null exactly when the child can be null.
  override def nullable: Boolean = child.nullable

  override def toString: String = s"PostgreCastToBoolean($child as ${dataType.simpleString})"

  override def sql: String = s"CAST(${child.sql} AS ${dataType.sql})"
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,44 +16,58 @@
1616
*/
1717
package org.apache.spark.sql.catalyst.expressions.postgreSQL
1818

19+
import java.sql.{Date, Timestamp}
20+
1921
import org.apache.spark.SparkFunSuite
22+
import org.apache.spark.sql.AnalysisException
2023
import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal}
2124

2225
class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
23-
private def checkPostgreCastStringToBoolean(v: Any, expected: Any): Unit = {
24-
checkEvaluation(PostgreCastStringToBoolean(Literal(v)), expected)
26+
private def checkPostgreCastToBoolean(v: Any, expected: Any): Unit = {
27+
checkEvaluation(PostgreCastToBoolean(Literal(v), None), expected)
2528
}
2629

2730
test("cast string to boolean") {
28-
checkPostgreCastStringToBoolean("true", true)
29-
checkPostgreCastStringToBoolean("tru", true)
30-
checkPostgreCastStringToBoolean("tr", true)
31-
checkPostgreCastStringToBoolean("t", true)
32-
checkPostgreCastStringToBoolean("tRUe", true)
33-
checkPostgreCastStringToBoolean(" tRue ", true)
34-
checkPostgreCastStringToBoolean(" tRu ", true)
35-
checkPostgreCastStringToBoolean("yes", true)
36-
checkPostgreCastStringToBoolean("ye", true)
37-
checkPostgreCastStringToBoolean("y", true)
38-
checkPostgreCastStringToBoolean("1", true)
39-
checkPostgreCastStringToBoolean("on", true)
31+
checkPostgreCastToBoolean("true", true)
32+
checkPostgreCastToBoolean("tru", true)
33+
checkPostgreCastToBoolean("tr", true)
34+
checkPostgreCastToBoolean("t", true)
35+
checkPostgreCastToBoolean("tRUe", true)
36+
checkPostgreCastToBoolean(" tRue ", true)
37+
checkPostgreCastToBoolean(" tRu ", true)
38+
checkPostgreCastToBoolean("yes", true)
39+
checkPostgreCastToBoolean("ye", true)
40+
checkPostgreCastToBoolean("y", true)
41+
checkPostgreCastToBoolean("1", true)
42+
checkPostgreCastToBoolean("on", true)
43+
44+
checkPostgreCastToBoolean("false", false)
45+
checkPostgreCastToBoolean("fals", false)
46+
checkPostgreCastToBoolean("fal", false)
47+
checkPostgreCastToBoolean("fa", false)
48+
checkPostgreCastToBoolean("f", false)
49+
checkPostgreCastToBoolean(" fAlse ", false)
50+
checkPostgreCastToBoolean(" fAls ", false)
51+
checkPostgreCastToBoolean(" FAlsE ", false)
52+
checkPostgreCastToBoolean("no", false)
53+
checkPostgreCastToBoolean("n", false)
54+
checkPostgreCastToBoolean("0", false)
55+
checkPostgreCastToBoolean("off", false)
56+
checkPostgreCastToBoolean("of", false)
4057

41-
checkPostgreCastStringToBoolean("false", false)
42-
checkPostgreCastStringToBoolean("fals", false)
43-
checkPostgreCastStringToBoolean("fal", false)
44-
checkPostgreCastStringToBoolean("fa", false)
45-
checkPostgreCastStringToBoolean("f", false)
46-
checkPostgreCastStringToBoolean(" fAlse ", false)
47-
checkPostgreCastStringToBoolean(" fAls ", false)
48-
checkPostgreCastStringToBoolean(" FAlsE ", false)
49-
checkPostgreCastStringToBoolean("no", false)
50-
checkPostgreCastStringToBoolean("n", false)
51-
checkPostgreCastStringToBoolean("0", false)
52-
checkPostgreCastStringToBoolean("off", false)
53-
checkPostgreCastStringToBoolean("of", false)
58+
intercept[IllegalArgumentException](PostgreCastToBoolean(Literal("o"), None).eval())
59+
intercept[IllegalArgumentException](PostgreCastToBoolean(Literal("abc"), None).eval())
60+
intercept[IllegalArgumentException](PostgreCastToBoolean(Literal(""), None).eval())
61+
}
5462

55-
checkPostgreCastStringToBoolean("o", null)
56-
checkPostgreCastStringToBoolean("abc", null)
57-
checkPostgreCastStringToBoolean("", null)
63+
test("unsupported data types to cast to boolean") {
64+
assert(PostgreCastToBoolean(Literal(new Timestamp(1)), None).checkInputDataTypes().isFailure)
65+
assert(PostgreCastToBoolean(Literal(new Date(1)), None).checkInputDataTypes().isFailure)
66+
assert(PostgreCastToBoolean(Literal(1.toLong), None).checkInputDataTypes().isFailure)
67+
assert(PostgreCastToBoolean(Literal(1.toShort), None).checkInputDataTypes().isFailure)
68+
assert(PostgreCastToBoolean(Literal(1.toByte), None).checkInputDataTypes().isFailure)
69+
assert(PostgreCastToBoolean(Literal(BigDecimal(1.0)), None).checkInputDataTypes().isFailure)
70+
assert(PostgreCastToBoolean(Literal(1.toDouble), None).checkInputDataTypes().isFailure)
71+
assert(PostgreCastToBoolean(Literal(1.toFloat), None).checkInputDataTypes().isFailure)
5872
}
5973
}

0 commit comments

Comments
 (0)