Commit bf8a4c4

beliefer authored and dongjoon-hyun committed
[SPARK-39961][SQL] DS V2 push-down translate Cast if the cast is safe
### What changes were proposed in this pull request?

Currently, DS V2 push-down translates `Cast` only when ANSI mode is enabled. In fact, if the cast is safe (e.g. casting a number to a string, or an int to a long), we can translate it too. This PR calls `Cast.canUpCast` so that `Cast` can be translated to a V2 `Cast` safely.

Note: the optimizer rule `SimplifyCasts` already removes some safe casts (e.g. int to long), so the `Cast` may not always be visible here.

### Why are the changes needed?

This widens the range of `Cast` expressions that DS V2 push-down supports.

### Does this PR introduce _any_ user-facing change?

Yes. `Cast` can be pushed down to the data source in more cases.

### How was this patch tested?

Test cases updated.

Closes #37388 from beliefer/SPARK-39961.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
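A minimal sketch of the safety check this change relies on, runnable from a Spark shell (`Cast.canUpCast` is the real helper the builder now calls; the specific type pairs below are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

// Loss-free "up-casts" can now be translated even with ANSI mode off:
Cast.canUpCast(IntegerType, LongType)    // true: widening int -> long
Cast.canUpCast(IntegerType, StringType)  // true: number -> string is safe
// Narrowing casts can overflow, so they remain gated on ANSI mode:
Cast.canUpCast(LongType, IntegerType)    // false
```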
1 parent 0de98fd commit bf8a4c4

2 files changed: +20 −28 lines

sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 2 additions & 1 deletion
@@ -88,7 +88,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
     } else {
       None
     }
-    case Cast(child, dataType, _, true) =>
+    case Cast(child, dataType, _, ansiEnabled)
+        if ansiEnabled || Cast.canUpCast(child.dataType, dataType) =>
       generateExpression(child).map(v => new V2Cast(v, dataType))
     case Abs(child, true) => generateExpressionWithName("ABS", Seq(child))
     case Coalesce(children) => generateExpressionWithName("COALESCE", children)
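For illustration, a hedged sketch of the user-visible effect against the H2 tables the test suite below registers (the query mirrors the updated test; the plan fragment is indicative, not a verbatim guarantee):

```scala
// With ANSI mode off, the CAST used to block push-down of this filter.
// After this change, BONUS -> string is an up-cast, so it is pushed down:
val df = spark.sql(
  "SELECT * FROM h2.test.employee WHERE CAST(bonus AS string) LIKE '%30%'")
df.explain()
// expected to contain something like:
// PushedFilters: [BONUS IS NOT NULL, CAST(BONUS AS string) LIKE '%30%']
```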

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala

Lines changed: 18 additions & 27 deletions
@@ -1109,7 +1109,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         "CAST(BONUS AS string) LIKE '%30%', CAST(DEPT AS byte) > 1, " +
         "CAST(DEPT AS short) > 1, CAST(BONUS AS decimal(20,2)) > 1200.00]"
     } else {
-      "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL],"
+      "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL, CAST(BONUS AS string) LIKE '%30%']"
     }
     checkPushedInfo(df6, expectedPlanFragment6)
     checkAnswer(df6, Seq(Row(2, "david", 10000, 1300, true)))
@@ -1199,18 +1199,16 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkPushedInfo(df1, "PushedFilters: [CHAR_LENGTH(NAME) > 2],")
     checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2)))

-    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
-      val df2 = sql(
-        """
-          |SELECT *
-          |FROM h2.test.people
-          |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2
+    val df2 = sql(
+      """
+        |SELECT *
+        |FROM h2.test.people
+        |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2
       """.stripMargin)
-      checkFiltersRemoved(df2)
-      checkPushedInfo(df2,
-        "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],")
-      checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
-    }
+    checkFiltersRemoved(df2)
+    checkPushedInfo(df2,
+      "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],")
+    checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
   } finally {
     JdbcDialects.unregisterDialect(testH2Dialect)
     JdbcDialects.registerDialect(H2Dialect)
@@ -2262,24 +2260,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   }

   test("scan with aggregate push-down: partial push-down AVG with overflow") {
-    def createDataFrame: DataFrame = spark.read
-      .option("partitionColumn", "id")
-      .option("lowerBound", "0")
-      .option("upperBound", "2")
-      .option("numPartitions", "2")
-      .table("h2.test.item")
-      .agg(avg($"PRICE").as("avg"))
-
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
-        val df = createDataFrame
+        val df = spark.read
+          .option("partitionColumn", "id")
+          .option("lowerBound", "0")
+          .option("upperBound", "2")
+          .option("numPartitions", "2")
+          .table("h2.test.item")
+          .agg(avg($"PRICE").as("avg"))
         checkAggregateRemoved(df, false)
-        df.queryExecution.optimizedPlan.collect {
-          case _: DataSourceV2ScanRelation =>
-            val expected_plan_fragment =
-              "PushedAggregates: [COUNT(PRICE), SUM(PRICE)]"
-            checkKeywordsExistsInExplain(df, expected_plan_fragment)
-        }
+        checkPushedInfo(df, "PushedAggregates: [COUNT(PRICE), SUM(PRICE)]")
         if (ansiEnabled) {
           val e = intercept[SparkException] {
             df.collect()
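A side note on the AVG test above: in the partial push-down case, only `COUNT(PRICE)` and `SUM(PRICE)` go to the source, and Spark finishes the average itself. A minimal sketch with made-up per-partition partials:

```scala
// Hypothetical (sum, count) partials returned by the two JDBC partitions:
val partials = Seq((1200.0, 3L), (800.0, 2L))
// Spark merges the partials and computes the final average:
val (sum, count) = partials.reduce((a, b) => (a._1 + b._1, a._2 + b._2))
val avg = sum / count  // 400.0 -- AVG itself is never pushed to the source
```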
