
Commit 4a0f0ff

olaky authored and MaxGekk committed
[SPARK-39259][SQL][3.3] Evaluate timestamps consistently in subqueries
### What changes were proposed in this pull request?

Apply the optimizer rule ComputeCurrentTime consistently across subqueries. This is a backport of #36654.

### Why are the changes needed?

At the moment, timestamp functions like now() can return different values within a query if subqueries are involved.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

A new unit test was added.

Closes #36752 from olaky/SPARK-39259-spark_3_3.

Authored-by: Ole Sasse <ole.sasse@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
1 parent 8f599ba commit 4a0f0ff
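
To make the symptom concrete, here is a minimal sketch, not part of the commit, of the kind of query the fix targets (the temp view name `t` and the local-mode session are illustrative assumptions). With this change, both calls to now() are rewritten to one and the same timestamp literal during optimization, including the one inside the subquery:

```scala
import org.apache.spark.sql.SparkSession

object NowInSubqueryDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("now-demo").getOrCreate()
    spark.range(1).createOrReplaceTempView("t")

    // Before SPARK-39259, the now() inside the subquery could be evaluated
    // separately from the outer now(); with the fix both collapse to the
    // same literal when ComputeCurrentTime runs.
    spark.sql(
      """SELECT now() AS outer_now,
        |       (SELECT now() FROM t) AS subquery_now
        |FROM t""".stripMargin
    ).show(truncate = false)

    spark.stop()
  }
}
```

Before the fix, the two columns could differ by however long the subquery's separate evaluation was delayed; afterwards they are equal by construction.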

3 files changed: 95 additions, 46 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala

Lines changed: 22 additions & 19 deletions
```diff
@@ -17,14 +17,16 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import scala.collection.mutable
+import java.time.{Instant, LocalDateTime}
 
 import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ}
+import org.apache.spark.sql.catalyst.trees.TreePatternBits
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -73,29 +75,30 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
  */
 object ComputeCurrentTime extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
-    val currentDates = mutable.Map.empty[String, Literal]
-    val timeExpr = CurrentTimestamp()
-    val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long]
-    val currentTime = Literal.create(timestamp, timeExpr.dataType)
+    val instant = Instant.now()
+    val currentTimestampMicros = instantToMicros(instant)
+    val currentTime = Literal.create(currentTimestampMicros, TimestampType)
     val timezone = Literal.create(conf.sessionLocalTimeZone, StringType)
-    val localTimestamps = mutable.Map.empty[String, Literal]
 
-    plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
-      case currentDate @ CurrentDate(Some(timeZoneId)) =>
-        currentDates.getOrElseUpdate(timeZoneId, {
-          Literal.create(currentDate.eval().asInstanceOf[Int], DateType)
-        })
-      case CurrentTimestamp() | Now() => currentTime
-      case CurrentTimeZone() => timezone
-      case localTimestamp @ LocalTimestamp(Some(timeZoneId)) =>
-        localTimestamps.getOrElseUpdate(timeZoneId, {
-          Literal.create(localTimestamp.eval().asInstanceOf[Long], TimestampNTZType)
-        })
+    def transformCondition(treePatternbits: TreePatternBits): Boolean = {
+      treePatternbits.containsPattern(CURRENT_LIKE)
+    }
+
+    plan.transformDownWithSubqueries(transformCondition) {
+      case subQuery =>
+        subQuery.transformAllExpressionsWithPruning(transformCondition) {
+          case cd: CurrentDate =>
+            Literal.create(DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
+          case CurrentTimestamp() | Now() => currentTime
+          case CurrentTimeZone() => timezone
+          case localTimestamp: LocalTimestamp =>
+            val asDateTime = LocalDateTime.ofInstant(instant, localTimestamp.zoneId)
+            Literal.create(localDateTimeToMicros(asDateTime), TimestampNTZType)
+        }
     }
   }
 }
 
-
 /**
  * Replaces the expression of CurrentDatabase with the current database name.
  * Replaces the expression of CurrentCatalog with the current catalog name.
```
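
The core of the rewritten rule is that the clock is sampled exactly once (Instant.now()), and every CURRENT_LIKE literal, current_date, current_timestamp/now and localtimestamp alike, is derived from that single instant. A minimal sketch of that derivation using the same DateTimeUtils helpers the diff imports (the zone id is an arbitrary example; spark-catalyst must be on the classpath):

```scala
import java.time.{Instant, LocalDateTime, ZoneId}

import org.apache.spark.sql.catalyst.util.DateTimeUtils

object SingleInstantSketch {
  def main(args: Array[String]): Unit = {
    // Sample the clock once, as ComputeCurrentTime now does.
    val instant = Instant.now()
    val micros = DateTimeUtils.instantToMicros(instant)
    val zone = ZoneId.of("America/Los_Angeles")

    // All three kinds of literal fall out of the one sampled instant.
    val currentTimestamp = micros
    val currentDate = DateTimeUtils.microsToDays(micros, zone)
    val localTimestamp =
      DateTimeUtils.localDateTimeToMicros(LocalDateTime.ofInstant(instant, zone))

    println(s"timestamp=$currentTimestamp date=$currentDate localtimestamp=$localTimestamp")
  }
}
```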

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 7 additions & 4 deletions
```diff
@@ -454,7 +454,7 @@
    * to rewrite the whole plan, include its subqueries, in one go.
    */
   def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType =
-    transformDownWithSubqueries(f)
+    transformDownWithSubqueries(AlwaysProcess.fn, UnknownRuleId)(f)
 
   /**
    * Returns a copy of this node where the given partial function has been recursively applied
@@ -479,21 +479,24 @@
    * first to this node, then this node's subqueries and finally this node's children.
    * When the partial function does not apply to a given node, it is left unchanged.
    */
-  def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
+  def transformDownWithSubqueries(
+      cond: TreePatternBits => Boolean = AlwaysProcess.fn, ruleId: RuleId = UnknownRuleId)
+      (f: PartialFunction[PlanType, PlanType])
+    : PlanType = {
     val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] {
       override def isDefinedAt(x: PlanType): Boolean = true
 
       override def apply(plan: PlanType): PlanType = {
         val transformed = f.applyOrElse[PlanType, PlanType](plan, identity)
         transformed transformExpressionsDown {
           case planExpression: PlanExpression[PlanType] =>
-            val newPlan = planExpression.plan.transformDownWithSubqueries(f)
+            val newPlan = planExpression.plan.transformDownWithSubqueries(cond, ruleId)(f)
             planExpression.withNewPlan(newPlan)
         }
       }
     }
 
-    transformDown(g)
+    transformDownWithPruning(cond, ruleId)(g)
   }
 
   /**
```
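
From a caller's perspective, the defaults (AlwaysProcess.fn, UnknownRuleId) preserve the old no-argument behavior, while an explicit condition lets whole subtrees, including subquery plans, be skipped when their pattern bits rule out a match. A sketch of a call site in the style of ComputeCurrentTime (the identity rewrite here is a placeholder for a real rule body):

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
import org.apache.spark.sql.catalyst.trees.TreePatternBits

object CurrentLikeRewriteSketch {
  // Subtrees (and subquery plans) whose pattern bits say they cannot
  // contain CURRENT_LIKE expressions are pruned from the traversal.
  def rewriteCurrentLike(plan: LogicalPlan): LogicalPlan = {
    val cond = (bits: TreePatternBits) => bits.containsPattern(CURRENT_LIKE)
    plan.transformDownWithSubqueries(cond) {
      case p => p // a real rule would rewrite matching nodes here
    }
  }
}
```

The curried shape keeps existing `transformDownWithSubqueries { ... }` call sites source-compatible while threading the pruning condition through recursive subquery traversal.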

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala

Lines changed: 66 additions & 23 deletions
```diff
@@ -19,10 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import java.time.{LocalDateTime, ZoneId}
 
+import scala.collection.JavaConverters.mapAsScalaMap
+import scala.concurrent.duration._
+
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Literal, LocalTimestamp}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
@@ -41,11 +44,7 @@
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = (System.currentTimeMillis() + 1) * 1000
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Long]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Long]
-      e
-    }
+    val lits = literals[Long](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
@@ -59,11 +58,7 @@
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = DateTimeUtils.currentDate(ZoneId.systemDefault())
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Int]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Int]
-      e
-    }
+    val lits = literals[Int](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
@@ -73,13 +68,9 @@
   test("SPARK-33469: Add current_timezone function") {
     val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation())
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
-    val lits = new scala.collection.mutable.ArrayBuffer[String]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[UTF8String].toString
-      e
-    }
+    val lits = literals[UTF8String](plan)
     assert(lits.size == 1)
-    assert(lits.head == SQLConf.get.sessionLocalTimeZone)
+    assert(lits.head == UTF8String.fromString(SQLConf.get.sessionLocalTimeZone))
   }
 
   test("analyzer should replace localtimestamp with literals") {
@@ -92,14 +83,66 @@
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = DateTimeUtils.localDateTimeToMicros(LocalDateTime.now(zoneId))
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Long]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Long]
-      e
-    }
+    val lits = literals[Long](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
     assert(lits(0) == lits(1))
   }
+
+  test("analyzer should use equal timestamps across subqueries") {
+    val timestampInSubQuery = Project(Seq(Alias(LocalTimestamp(), "timestamp1")()), LocalRelation())
+    val listSubQuery = ListQuery(timestampInSubQuery)
+    val valueSearchedInSubQuery = Seq(Alias(LocalTimestamp(), "timestamp2")())
+    val inFilterWithSubQuery = InSubquery(valueSearchedInSubQuery, listSubQuery)
+    val input = Project(Nil, Filter(inFilterWithSubQuery, LocalRelation()))
+
+    val plan = Optimize.execute(input.analyze).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size == 3) // transformDownWithSubqueries covers the inner timestamp twice
+    assert(lits.toSet.size == 1)
+  }
+
+  test("analyzer should use consistent timestamps for different timezones") {
+    val localTimestamps = mapAsScalaMap(ZoneId.SHORT_IDS)
+      .map { case (zoneId, _) => Alias(LocalTimestamp(Some(zoneId)), zoneId)() }.toSeq
+    val input = Project(localTimestamps, LocalRelation())
+
+    val plan = Optimize.execute(input).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size === localTimestamps.size)
+    // there are timezones with a 30 or 45 minute offset
+    val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet
+    assert(offsetsFromQuarterHour.size == 1)
+  }
+
+  test("analyzer should use consistent timestamps for different timestamp functions") {
+    val differentTimestamps = Seq(
+      Alias(CurrentTimestamp(), "currentTimestamp")(),
+      Alias(Now(), "now")(),
+      Alias(LocalTimestamp(Some("PLT")), "localTimestampWithTimezone")()
+    )
+    val input = Project(differentTimestamps, LocalRelation())
+
+    val plan = Optimize.execute(input).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size === differentTimestamps.size)
+    // there are timezones with a 30 or 45 minute offset
+    val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet
+    assert(offsetsFromQuarterHour.size == 1)
+  }
+
+  private def literals[T](plan: LogicalPlan): Seq[T] = {
+    val literals = new scala.collection.mutable.ArrayBuffer[T]
+    plan.transformWithSubqueries { case subQuery =>
+      subQuery.transformAllExpressions { case expression: Literal =>
+        literals += expression.value.asInstanceOf[T]
+        expression
+      }
+    }
+    literals.asInstanceOf[Seq[T]]
+  }
 }
```
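
One note on the modulo assertion in the two new timezone tests: real-world UTC offsets are whole multiples of 15 minutes (which covers the 30- and 45-minute zones the comment mentions), so local timestamps derived from one shared instant all leave the same remainder modulo 15 minutes' worth of microseconds. A self-contained sketch of that invariant in plain java.time, independent of Catalyst:

```scala
import java.time.{Instant, LocalDateTime, ZoneId}
import java.time.temporal.ChronoUnit

import scala.collection.JavaConverters.mapAsScalaMap
import scala.concurrent.duration._

object QuarterHourInvariant {
  def main(args: Array[String]): Unit = {
    val instant = Instant.now()
    val epoch = LocalDateTime.of(1970, 1, 1, 0, 0)

    // One local-time value per short zone id, all derived from the same instant.
    val micros = mapAsScalaMap(ZoneId.SHORT_IDS).keys.toSeq.map { id =>
      val local = LocalDateTime.ofInstant(instant, ZoneId.of(id, ZoneId.SHORT_IDS))
      ChronoUnit.MICROS.between(epoch, local)
    }

    // Offsets are multiples of 15 minutes, so every remainder must agree.
    assert(micros.map(_ % 15.minutes.toMicros).toSet.size == 1)
  }
}
```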
