Skip to content

Commit d611d1f

Browse files
olakyJoshRosendongjoon-hyun
authored andcommitted
[SPARK-39259][SQL][3.2] Evaluate timestamps consistently in subqueries
### What changes were proposed in this pull request? Apply the optimizer rule ComputeCurrentTime consistently across subqueries. This is a backport of #36654. ### Why are the changes needed? At the moment timestamp functions like now() can return different values within a query if subqueries are involved ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? A new unit test was added Closes #36753 from olaky/SPARK-39259-spark_3_2. Lead-authored-by: Ole Sasse <ole.sasse@databricks.com> Co-authored-by: Josh Rosen <joshrosen@databricks.com> Co-authored-by: Dongjoon Hyun <dongjoon@apache.org> Signed-off-by: Max Gekk <max.gekk@gmail.com>
1 parent d9477dd commit d611d1f

File tree

3 files changed

+103
-44
lines changed

3 files changed

+103
-44
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20-
import scala.collection.mutable
20+
import java.time.{Instant, LocalDateTime}
2121

2222
import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.expressions.aggregate._
2525
import org.apache.spark.sql.catalyst.plans.logical._
2626
import org.apache.spark.sql.catalyst.rules._
2727
import org.apache.spark.sql.catalyst.trees.TreePattern._
28-
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ}
28+
import org.apache.spark.sql.catalyst.trees.TreePatternBits
29+
import org.apache.spark.sql.catalyst.util.DateTimeUtils
30+
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros}
2931
import org.apache.spark.sql.connector.catalog.CatalogManager
3032
import org.apache.spark.sql.types._
3133
import org.apache.spark.util.Utils
@@ -76,29 +78,30 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
7678
*/
7779
object ComputeCurrentTime extends Rule[LogicalPlan] {
7880
def apply(plan: LogicalPlan): LogicalPlan = {
79-
val currentDates = mutable.Map.empty[String, Literal]
80-
val timeExpr = CurrentTimestamp()
81-
val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long]
82-
val currentTime = Literal.create(timestamp, timeExpr.dataType)
81+
val instant = Instant.now()
82+
val currentTimestampMicros = instantToMicros(instant)
83+
val currentTime = Literal.create(currentTimestampMicros, TimestampType)
8384
val timezone = Literal.create(conf.sessionLocalTimeZone, StringType)
84-
val localTimestamps = mutable.Map.empty[String, Literal]
8585

86-
plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
87-
case currentDate @ CurrentDate(Some(timeZoneId)) =>
88-
currentDates.getOrElseUpdate(timeZoneId, {
89-
Literal.create(currentDate.eval().asInstanceOf[Int], DateType)
90-
})
91-
case CurrentTimestamp() | Now() => currentTime
92-
case CurrentTimeZone() => timezone
93-
case localTimestamp @ LocalTimestamp(Some(timeZoneId)) =>
94-
localTimestamps.getOrElseUpdate(timeZoneId, {
95-
Literal.create(localTimestamp.eval().asInstanceOf[Long], TimestampNTZType)
96-
})
86+
def transformCondition(treePatternbits: TreePatternBits): Boolean = {
87+
treePatternbits.containsPattern(CURRENT_LIKE)
88+
}
89+
90+
plan.transformDownWithSubqueriesAndPruning(transformCondition) {
91+
case subQuery =>
92+
subQuery.transformAllExpressionsWithPruning(transformCondition) {
93+
case cd: CurrentDate =>
94+
Literal.create(DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
95+
case CurrentTimestamp() | Now() => currentTime
96+
case CurrentTimeZone() => timezone
97+
case localTimestamp: LocalTimestamp =>
98+
val asDateTime = LocalDateTime.ofInstant(instant, localTimestamp.zoneId)
99+
Literal.create(localDateTimeToMicros(asDateTime), TimestampNTZType)
100+
}
97101
}
98102
}
99103
}
100104

101-
102105
/**
103106
* Replaces the expression of CurrentDatabase with the current database name.
104107
* Replaces the expression of CurrentCatalog with the current catalog name.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,20 +473,33 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
473473
* When the partial function does not apply to a given node, it is left unchanged.
474474
*/
475475
def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
476+
transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f)
477+
}
478+
479+
/**
480+
* This method is the top-down (pre-order) counterpart of transformUpWithSubqueries.
481+
* Returns a copy of this node where the given partial function has been recursively applied
482+
* first to this node, then this node's subqueries and finally this node's children.
483+
* When the partial function does not apply to a given node, it is left unchanged.
484+
*/
485+
def transformDownWithSubqueriesAndPruning(
486+
cond: TreePatternBits => Boolean,
487+
ruleId: RuleId = UnknownRuleId)
488+
(f: PartialFunction[PlanType, PlanType]): PlanType = {
476489
val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] {
477490
override def isDefinedAt(x: PlanType): Boolean = true
478491

479492
override def apply(plan: PlanType): PlanType = {
480493
val transformed = f.applyOrElse[PlanType, PlanType](plan, identity)
481494
transformed transformExpressionsDown {
482495
case planExpression: PlanExpression[PlanType] =>
483-
val newPlan = planExpression.plan.transformDownWithSubqueries(f)
496+
val newPlan = planExpression.plan.transformDownWithSubqueriesAndPruning(cond, ruleId)(f)
484497
planExpression.withNewPlan(newPlan)
485498
}
486499
}
487500
}
488501

489-
transformDown(g)
502+
transformDownWithPruning(cond, ruleId)(g)
490503
}
491504

492505
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
1919

2020
import java.time.{LocalDateTime, ZoneId}
2121

22+
import scala.collection.JavaConverters.mapAsScalaMap
23+
import scala.concurrent.duration._
24+
2225
import org.apache.spark.sql.catalyst.dsl.plans._
23-
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Literal, LocalTimestamp}
26+
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
2427
import org.apache.spark.sql.catalyst.plans.PlanTest
25-
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
28+
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
2629
import org.apache.spark.sql.catalyst.rules.RuleExecutor
2730
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2831
import org.apache.spark.sql.internal.SQLConf
@@ -41,11 +44,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
4144
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
4245
val max = (System.currentTimeMillis() + 1) * 1000
4346

44-
val lits = new scala.collection.mutable.ArrayBuffer[Long]
45-
plan.transformAllExpressions { case e: Literal =>
46-
lits += e.value.asInstanceOf[Long]
47-
e
48-
}
47+
val lits = literals[Long](plan)
4948
assert(lits.size == 2)
5049
assert(lits(0) >= min && lits(0) <= max)
5150
assert(lits(1) >= min && lits(1) <= max)
@@ -59,11 +58,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
5958
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
6059
val max = DateTimeUtils.currentDate(ZoneId.systemDefault())
6160

62-
val lits = new scala.collection.mutable.ArrayBuffer[Int]
63-
plan.transformAllExpressions { case e: Literal =>
64-
lits += e.value.asInstanceOf[Int]
65-
e
66-
}
61+
val lits = literals[Int](plan)
6762
assert(lits.size == 2)
6863
assert(lits(0) >= min && lits(0) <= max)
6964
assert(lits(1) >= min && lits(1) <= max)
@@ -73,13 +68,9 @@ class ComputeCurrentTimeSuite extends PlanTest {
7368
test("SPARK-33469: Add current_timezone function") {
7469
val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation())
7570
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
76-
val lits = new scala.collection.mutable.ArrayBuffer[String]
77-
plan.transformAllExpressions { case e: Literal =>
78-
lits += e.value.asInstanceOf[UTF8String].toString
79-
e
80-
}
71+
val lits = literals[UTF8String](plan)
8172
assert(lits.size == 1)
82-
assert(lits.head == SQLConf.get.sessionLocalTimeZone)
73+
assert(lits.head == UTF8String.fromString(SQLConf.get.sessionLocalTimeZone))
8374
}
8475

8576
test("analyzer should replace localtimestamp with literals") {
@@ -92,14 +83,66 @@ class ComputeCurrentTimeSuite extends PlanTest {
9283
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
9384
val max = DateTimeUtils.localDateTimeToMicros(LocalDateTime.now(zoneId))
9485

95-
val lits = new scala.collection.mutable.ArrayBuffer[Long]
96-
plan.transformAllExpressions { case e: Literal =>
97-
lits += e.value.asInstanceOf[Long]
98-
e
99-
}
86+
val lits = literals[Long](plan)
10087
assert(lits.size == 2)
10188
assert(lits(0) >= min && lits(0) <= max)
10289
assert(lits(1) >= min && lits(1) <= max)
10390
assert(lits(0) == lits(1))
10491
}
92+
93+
test("analyzer should use equal timestamps across subqueries") {
94+
val timestampInSubQuery = Project(Seq(Alias(LocalTimestamp(), "timestamp1")()), LocalRelation())
95+
val listSubQuery = ListQuery(timestampInSubQuery)
96+
val valueSearchedInSubQuery = Seq(Alias(LocalTimestamp(), "timestamp2")())
97+
val inFilterWithSubQuery = InSubquery(valueSearchedInSubQuery, listSubQuery)
98+
val input = Project(Nil, Filter(inFilterWithSubQuery, LocalRelation()))
99+
100+
val plan = Optimize.execute(input.analyze).asInstanceOf[Project]
101+
102+
val lits = literals[Long](plan)
103+
assert(lits.size == 3) // transformDownWithSubqueries covers the inner timestamp twice
104+
assert(lits.toSet.size == 1)
105+
}
106+
107+
test("analyzer should use consistent timestamps for different timezones") {
108+
val localTimestamps = mapAsScalaMap(ZoneId.SHORT_IDS)
109+
.map { case (zoneId, _) => Alias(LocalTimestamp(Some(zoneId)), zoneId)() }.toSeq
110+
val input = Project(localTimestamps, LocalRelation())
111+
112+
val plan = Optimize.execute(input).asInstanceOf[Project]
113+
114+
val lits = literals[Long](plan)
115+
assert(lits.size === localTimestamps.size)
116+
// there are timezones with a 30 or 45 minute offset
117+
val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet
118+
assert(offsetsFromQuarterHour.size == 1)
119+
}
120+
121+
test("analyzer should use consistent timestamps for different timestamp functions") {
122+
val differentTimestamps = Seq(
123+
Alias(CurrentTimestamp(), "currentTimestamp")(),
124+
Alias(Now(), "now")(),
125+
Alias(LocalTimestamp(Some("PLT")), "localTimestampWithTimezone")()
126+
)
127+
val input = Project(differentTimestamps, LocalRelation())
128+
129+
val plan = Optimize.execute(input).asInstanceOf[Project]
130+
131+
val lits = literals[Long](plan)
132+
assert(lits.size === differentTimestamps.size)
133+
// there are timezones with a 30 or 45 minute offset
134+
val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet
135+
assert(offsetsFromQuarterHour.size == 1)
136+
}
137+
138+
private def literals[T](plan: LogicalPlan): scala.collection.mutable.ArrayBuffer[T] = {
139+
val literals = new scala.collection.mutable.ArrayBuffer[T]
140+
plan.transformWithSubqueries { case subQuery =>
141+
subQuery.transformAllExpressions { case expression: Literal =>
142+
literals += expression.value.asInstanceOf[T]
143+
expression
144+
}
145+
}
146+
literals
147+
}
105148
}

0 commit comments

Comments
 (0)