Skip to content

Commit 094d58e

Browse files
committed
new package and expression
1 parent 0e7f281 commit 094d58e

File tree

10 files changed

+260
-77
lines changed

10 files changed

+260
-77
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ class Analyzer(
227227
ResolveRandomSeed ::
228228
TypeCoercion.typeCoercionRules(conf) ++
229229
extendedResolutionRules : _*),
230+
Batch("PostgreSQl dialect", Once, PostgreSQLDialect.postgreSQLDialectRules(conf): _*),
230231
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
231232
Batch("Nondeterministic", Once,
232233
PullOutNondeterministic),
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.internal.Logging
21+
import org.apache.spark.sql.catalyst.expressions.Cast
22+
import org.apache.spark.sql.catalyst.expressions.postgreSQL.PostgreCastStringToBoolean
23+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
24+
import org.apache.spark.sql.catalyst.rules.Rule
25+
import org.apache.spark.sql.internal.SQLConf
26+
import org.apache.spark.sql.types.{BooleanType, StringType}
27+
28+
/**
 * Analyzer rules that emulate PostgreSQL behaviour while the PostgreSQL dialect is selected.
 */
object PostgreSQLDialect {
  /**
   * Returns the PostgreSQL-dialect rules to hand to the analyzer.
   *
   * The rule is always returned. Whether it rewrites anything is decided every time the rule
   * runs (see [[postgreCastStringToBoolean]]), so a change to the dialect setting made after
   * this list has been built is still honoured.
   *
   * @param conf configuration consulted by the returned rules
   * @return the list of dialect rules (currently a single rule)
   */
  def postgreSQLDialectRules(conf: SQLConf): List[Rule[LogicalPlan]] =
    postgreCastStringToBoolean(conf) :: Nil

  /**
   * Replaces a `Cast` from string to boolean with [[PostgreCastStringToBoolean]] while
   * `conf.usePostgreSQLDialect` is true; otherwise the plan is returned unchanged.
   *
   * @param conf configuration whose dialect flag is read on every invocation
   */
  case class postgreCastStringToBoolean(conf: SQLConf) extends Rule[LogicalPlan] with Logging {
    override def apply(plan: LogicalPlan): LogicalPlan = {
      if (conf.usePostgreSQLDialect) {
        plan.transformExpressions {
          // Check `child.resolved` before touching `child.dataType`: unresolved Catalyst
          // expressions may throw when asked for their data type, which would hide the
          // regular analysis error for the unresolved reference.
          case Cast(child, BooleanType, _) if child.resolved && child.dataType == StringType =>
            PostgreCastStringToBoolean(child)
        }
      } else {
        plan
      }
    }
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,8 +677,7 @@ object TypeCoercion {
677677
case d: Divide if d.dataType == DoubleType => d
678678
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
679679
case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) =>
680-
val preferIntegralDivision =
681-
conf.getConf(SQLConf.DIALECT) == SQLConf.Dialect.POSTGRESQL.toString
680+
val preferIntegralDivision = conf.usePostgreSQLDialect
682681
(left.dataType, right.dataType) match {
683682
case (_: IntegralType, _: IntegralType) if preferIntegralDivision =>
684683
IntegralDivide(left, right)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -391,11 +391,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
391391
// UDFToBoolean
392392
private[this] def castToBoolean(from: DataType): Any => Any = from match {
393393
case StringType =>
394-
val dialect = SQLConf.get.getConf(SQLConf.DIALECT)
395394
buildCast[UTF8String](_, s => {
396-
if (StringUtils.isTrueString(s, dialect)) {
395+
if (StringUtils.isTrueString(s)) {
397396
true
398-
} else if (StringUtils.isFalseString(s, dialect)) {
397+
} else if (StringUtils.isFalseString(s)) {
399398
false
400399
} else {
401400
null
@@ -1251,12 +1250,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
12511250
private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
12521251
case StringType =>
12531252
val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
1254-
val dialect = SQLConf.get.getConf(SQLConf.DIALECT)
12551253
(c, evPrim, evNull) =>
12561254
code"""
1257-
if ($stringUtils.isTrueString($c, "$dialect")) {
1255+
if ($stringUtils.isTrueString($c)) {
12581256
$evPrim = true;
1259-
} else if ($stringUtils.isFalseString($c, "$dialect")) {
1257+
} else if ($stringUtils.isFalseString($c)) {
12601258
$evPrim = false;
12611259
} else {
12621260
$evNull = true;
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.catalyst.expressions.postgreSQL
18+
19+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
20+
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, UnaryExpression}
21+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, JavaCode}
22+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
23+
import org.apache.spark.sql.catalyst.util.StringUtils
24+
import org.apache.spark.sql.types.{BooleanType, DataType, StringType}
25+
import org.apache.spark.unsafe.types.UTF8String
26+
27+
/**
 * Casts a string to a boolean using the PostgreSQL true/false spellings recognised by
 * `StringUtils.postgreIsTrueString` and `StringUtils.postgreIsFalseString`.
 *
 * The result is null when the input is null or matches neither set of spellings. Because a
 * null input always yields a null result, the expression is marked [[NullIntolerant]].
 *
 * @param child the string-typed expression to convert
 */
case class PostgreCastStringToBoolean(child: Expression)
  extends UnaryExpression with NullIntolerant {

  /** Accepts only a child whose data type is exactly `StringType`. */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (child.dataType == StringType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(
        s"The expression ${getClass.getSimpleName} only accepts string input data type")
    }
  }

  /**
   * Interpreted evaluation for a non-null input.
   *
   * @param input the child's value, expected to be a `UTF8String`
   * @return `true`, `false`, or `null` when the text is not a recognised boolean spelling
   */
  override def nullSafeEval(input: Any): Any = {
    val s = input.asInstanceOf[UTF8String]
    if (StringUtils.postgreIsTrueString(s)) {
      true
    } else if (StringUtils.postgreIsFalseString(s)) {
      false
    } else {
      null
    }
  }

  /**
   * Generated-code evaluation mirroring `nullSafeEval`: the result starts out with the child's
   * null flag and is additionally set to null when the text matches neither spelling set.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Name of the StringUtils object's class without the trailing `$`, used to qualify the
    // helper calls from the generated Java code.
    val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
    val eval = child.genCode(ctx)
    val javaType = JavaCode.javaType(dataType)
    val castCode =
      code"""
        boolean ${ev.isNull} = ${eval.isNull};
        $javaType ${ev.value} = false;
        if (!${eval.isNull}) {
          if ($stringUtils.postgreIsTrueString(${eval.value})) {
            ${ev.value} = true;
          } else if ($stringUtils.postgreIsFalseString(${eval.value})) {
            ${ev.value} = false;
          } else {
            ${ev.isNull} = true;
          }
        }
      """
    ev.copy(code = eval.code + castCode)
  }

  override def dataType: DataType = BooleanType

  // Always nullable: unrecognised text produces null even for a non-null input.
  override def nullable: Boolean = true

  override def toString: String = s"postgreCastStringToBoolean($child as ${dataType.simpleString})"

  // Rendered as an ordinary CAST in SQL text.
  override def sql: String = s"CAST(${child.sql} AS ${dataType.sql})"
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,17 @@ object StringUtils extends Logging {
7575
Set("f", "false", "n", "no", "0").map(UTF8String.fromString)
7676
private[this] val falseStringsOfPostgreSQL =
7777
Set("false", "fals", "fal", "fa", "f", "no", "n", "off", "of", "0").map(UTF8String.fromString)
78+
7879
// scalastyle:off caselocale
79-
def isTrueString(s: UTF8String, dialect: String): Boolean = {
80-
SQLConf.Dialect.withName(dialect) match {
81-
case SQLConf.Dialect.SPARK =>
82-
trueStrings.contains(s.toLowerCase)
83-
case SQLConf.Dialect.POSTGRESQL =>
84-
trueStringsOfPostgreSQL.contains(s.toLowerCase.trim())
85-
}
86-
}
80+
def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
8781

88-
def isFalseString(s: UTF8String, dialect: String): Boolean = {
89-
SQLConf.Dialect.withName(dialect) match {
90-
case SQLConf.Dialect.SPARK =>
91-
falseStrings.contains(s.toLowerCase)
92-
case SQLConf.Dialect.POSTGRESQL =>
93-
falseStringsOfPostgreSQL.contains(s.toLowerCase.trim())
94-
}
95-
}
82+
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
83+
84+
def postgreIsTrueString(s: UTF8String): Boolean =
85+
trueStringsOfPostgreSQL.contains(s.toLowerCase.trim())
86+
87+
def postgreIsFalseString(s: UTF8String): Boolean =
88+
falseStringsOfPostgreSQL.contains(s.toLowerCase.trim())
9689
// scalastyle:on caselocale
9790

9891
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2441,6 +2441,8 @@ class SQLConf extends Serializable with Logging {
24412441

24422442
def ansiEnabled: Boolean = getConf(ANSI_ENABLED)
24432443

2444+
def usePostgreSQLDialect: Boolean = getConf(DIALECT) == Dialect.POSTGRESQL.toString()
2445+
24442446
def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED)
24452447

24462448
def serializerNestedSchemaPruningEnabled: Boolean =

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 18 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -819,59 +819,24 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
819819
}
820820

821821
test("cast string to boolean with Spark dialect") {
822-
withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.SPARK.toString) {
823-
checkCast("t", true)
824-
checkCast("true", true)
825-
checkCast("tRUe", true)
826-
checkCast("y", true)
827-
checkCast("yes", true)
828-
checkCast("1", true)
829-
830-
checkCast("f", false)
831-
checkCast("false", false)
832-
checkCast("FAlsE", false)
833-
checkCast("n", false)
834-
checkCast("no", false)
835-
checkCast("0", false)
836-
837-
checkEvaluation(cast("abc", BooleanType), null)
838-
checkEvaluation(cast("", BooleanType), null)
839-
}
840-
}
841-
842-
test("cast string to boolean with PostgreSQL dialect") {
843-
withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.POSTGRESQL.toString) {
844-
checkCast("true", true)
845-
checkCast("tru", true)
846-
checkCast("tr", true)
847-
checkCast("t", true)
848-
checkCast("tRUe", true)
849-
checkCast(" tRue ", true)
850-
checkCast(" tRu ", true)
851-
checkCast("yes", true)
852-
checkCast("ye", true)
853-
checkCast("y", true)
854-
checkCast("1", true)
855-
checkCast("on", true)
856-
857-
checkCast("false", false)
858-
checkCast("fals", false)
859-
checkCast("fal", false)
860-
checkCast("fa", false)
861-
checkCast("f", false)
862-
checkCast(" fAlse ", false)
863-
checkCast(" fAls ", false)
864-
checkCast(" FAlsE ", false)
865-
checkCast("no", false)
866-
checkCast("n", false)
867-
checkCast("0", false)
868-
checkCast("off", false)
869-
checkCast("of", false)
870-
871-
checkEvaluation(cast("o", BooleanType), null)
872-
checkEvaluation(cast("abc", BooleanType), null)
873-
checkEvaluation(cast("", BooleanType), null)
874-
}
822+
checkCast("t", true)
823+
checkCast("true", true)
824+
checkCast("tRUe", true)
825+
checkCast("y", true)
826+
checkCast("yes", true)
827+
checkCast("1", true)
828+
829+
checkCast("f", false)
830+
checkCast("false", false)
831+
checkCast("FAlsE", false)
832+
checkCast("n", false)
833+
checkCast("no", false)
834+
checkCast("0", false)
835+
836+
checkEvaluation(cast("abc", BooleanType), null)
837+
checkEvaluation(cast("tru", BooleanType), null)
838+
checkEvaluation(cast("fla", BooleanType), null)
839+
checkEvaluation(cast("", BooleanType), null)
875840
}
876841

877842
test("SPARK-16729 type checking for casting to date type") {
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.catalyst.expressions.postgreSQL
18+
19+
import org.apache.spark.SparkFunSuite
20+
import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal}
21+
22+
/**
 * Expression-level checks for [[PostgreCastStringToBoolean]] evaluated over string literals.
 */
class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
  // Wraps the value in a literal, applies the PostgreSQL cast and compares the outcome.
  private def assertCast(literal: Any, result: Any): Unit =
    checkEvaluation(PostgreCastStringToBoolean(Literal(literal)), result)

  test("cast string to boolean with PostgreSQL dialect") {
    // Spellings (including prefixes, mixed case and surrounding blanks) that must yield true.
    val acceptedAsTrue = Seq(
      "true", "tru", "tr", "t", "tRUe", " tRue ", " tRu ", "yes", "ye", "y", "1", "on")
    // Spellings that must yield false.
    val acceptedAsFalse = Seq(
      "false", "fals", "fal", "fa", "f", " fAlse ", " fAls ", " FAlsE ",
      "no", "n", "0", "off", "of")
    // Inputs that are not recognised as either value and must yield null.
    val rejected = Seq("o", "abc", "")

    acceptedAsTrue.foreach(assertCast(_, true))
    acceptedAsFalse.foreach(assertCast(_, false))
    rejected.foreach(assertCast(_, null))
  }
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql
18+
19+
import org.apache.spark.SparkConf
20+
import org.apache.spark.sql.internal.SQLConf
21+
import org.apache.spark.sql.test.SharedSparkSession
22+
23+
/**
 * End-to-end SQL checks that run with the dialect configuration set to PostgreSQL.
 */
class PostgreSQLDialectQuerySuite extends QueryTest with SharedSparkSession {

  // Start from the inherited configuration and select the PostgreSQL dialect.
  override def sparkConf: SparkConf = {
    val inherited = super.sparkConf
    inherited.set(SQLConf.DIALECT.key, SQLConf.Dialect.POSTGRESQL.toString)
  }

  test("cast string to boolean") {
    // Runs one cast query per literal and expects a single row holding `expected`.
    def expectCast(literals: Seq[String], expected: Any): Unit =
      literals.foreach { text =>
        checkAnswer(sql(s"select cast('$text' as boolean)"), Row(expected))
      }

    val trueLiterals =
      Seq("true", "tru", "tr", "t", " tRue ", " tRu ", "yes", "ye", "y", "1", "on")
    val falseLiterals =
      Seq("false", "fals", "fal", "fa", "f", " fAlse ", " fAls ", "no", "n", "0", "off", "of")
    val unrecognised = Seq("o", "abc", "")

    expectCast(trueLiterals, true)
    expectCast(falseLiterals, false)
    expectCast(unrecognised, null)
  }
}

0 commit comments

Comments
 (0)