@@ -19,10 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
19
19
20
20
import java .time .{LocalDateTime , ZoneId }
21
21
22
+ import scala .collection .JavaConverters .mapAsScalaMap
23
+ import scala .concurrent .duration ._
24
+
22
25
import org .apache .spark .sql .catalyst .dsl .plans ._
23
- import org .apache .spark .sql .catalyst .expressions .{Alias , CurrentDate , CurrentTimestamp , CurrentTimeZone , Literal , LocalTimestamp }
26
+ import org .apache .spark .sql .catalyst .expressions .{Alias , CurrentDate , CurrentTimestamp , CurrentTimeZone , InSubquery , ListQuery , Literal , LocalTimestamp , Now }
24
27
import org .apache .spark .sql .catalyst .plans .PlanTest
25
- import org .apache .spark .sql .catalyst .plans .logical .{LocalRelation , LogicalPlan , Project }
28
+ import org .apache .spark .sql .catalyst .plans .logical .{Filter , LocalRelation , LogicalPlan , Project }
26
29
import org .apache .spark .sql .catalyst .rules .RuleExecutor
27
30
import org .apache .spark .sql .catalyst .util .DateTimeUtils
28
31
import org .apache .spark .sql .internal .SQLConf
@@ -41,11 +44,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
41
44
val plan = Optimize .execute(in.analyze).asInstanceOf [Project ]
42
45
val max = (System .currentTimeMillis() + 1 ) * 1000
43
46
44
- val lits = new scala.collection.mutable.ArrayBuffer [Long ]
45
- plan.transformAllExpressions { case e : Literal =>
46
- lits += e.value.asInstanceOf [Long ]
47
- e
48
- }
47
+ val lits = literals[Long ](plan)
49
48
assert(lits.size == 2 )
50
49
assert(lits(0 ) >= min && lits(0 ) <= max)
51
50
assert(lits(1 ) >= min && lits(1 ) <= max)
@@ -59,11 +58,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
59
58
val plan = Optimize .execute(in.analyze).asInstanceOf [Project ]
60
59
val max = DateTimeUtils .currentDate(ZoneId .systemDefault())
61
60
62
- val lits = new scala.collection.mutable.ArrayBuffer [Int ]
63
- plan.transformAllExpressions { case e : Literal =>
64
- lits += e.value.asInstanceOf [Int ]
65
- e
66
- }
61
+ val lits = literals[Int ](plan)
67
62
assert(lits.size == 2 )
68
63
assert(lits(0 ) >= min && lits(0 ) <= max)
69
64
assert(lits(1 ) >= min && lits(1 ) <= max)
@@ -73,13 +68,9 @@ class ComputeCurrentTimeSuite extends PlanTest {
73
68
test(" SPARK-33469: Add current_timezone function") {
  // Optimize a single-column projection of CurrentTimeZone() and inspect the literals
  // found in the resulting plan.
  val projection = Project(Seq(Alias(CurrentTimeZone(), " c")()), LocalRelation())
  val optimized = Optimize.execute(projection.analyze).asInstanceOf[Project]
  val found = literals[UTF8String](optimized)

  // Exactly one literal is expected, and it must equal the session-local time zone id.
  assert(found.size == 1)
  val expectedZone = UTF8String.fromString(SQLConf.get.sessionLocalTimeZone)
  assert(found.head == expectedZone)
}
84
75
85
76
test(" analyzer should replace localtimestamp with literals" ) {
@@ -92,14 +83,66 @@ class ComputeCurrentTimeSuite extends PlanTest {
92
83
val plan = Optimize .execute(in.analyze).asInstanceOf [Project ]
93
84
val max = DateTimeUtils .localDateTimeToMicros(LocalDateTime .now(zoneId))
94
85
95
- val lits = new scala.collection.mutable.ArrayBuffer [Long ]
96
- plan.transformAllExpressions { case e : Literal =>
97
- lits += e.value.asInstanceOf [Long ]
98
- e
99
- }
86
+ val lits = literals[Long ](plan)
100
87
assert(lits.size == 2 )
101
88
assert(lits(0 ) >= min && lits(0 ) <= max)
102
89
assert(lits(1 ) >= min && lits(1 ) <= max)
103
90
assert(lits(0 ) == lits(1 ))
104
91
}
92
+
93
test(" analyzer should use equal timestamps across subqueries") {
  // Inner side: a subquery that projects a single LocalTimestamp.
  val subqueryPlan = Project(Seq(Alias(LocalTimestamp(), " timestamp1")()), LocalRelation())
  val subqueryList = ListQuery(subqueryPlan)
  // Outer side: another LocalTimestamp that is searched for in the subquery result.
  val searchedValues = Seq(Alias(LocalTimestamp(), " timestamp2")())
  val membership = InSubquery(searchedValues, subqueryList)
  val input = Project(Nil, Filter(membership, LocalRelation()))

  val optimized = Optimize.execute(input.analyze).asInstanceOf[Project]
  val found = literals[Long](optimized)

  // transformDownWithSubqueries covers the inner timestamp twice
  assert(found.size == 3)
  // Every collected value must be the same.
  assert(found.toSet.size == 1)
}
106
+
107
test(" analyzer should use consistent timestamps for different timezones") {
  // One LocalTimestamp per key of ZoneId.SHORT_IDS, each aliased by its own zone id.
  val perZoneTimestamps = mapAsScalaMap(ZoneId.SHORT_IDS).toSeq.map {
    case (zoneId, _) => Alias(LocalTimestamp(Some(zoneId)), zoneId)()
  }
  val input = Project(perZoneTimestamps, LocalRelation())

  val optimized = Optimize.execute(input).asInstanceOf[Project]
  val found = literals[Long](optimized)
  assert(found.size === perZoneTimestamps.size)

  // there are timezones with a 30 or 45 minute offset, so the values are compared by
  // their remainder within a quarter of an hour rather than directly.
  val quarterHourMicros = 15.minutes.toMicros
  val remainders = found.map(_ % quarterHourMicros).toSet
  assert(remainders.size == 1)
}
120
+
121
test(" analyzer should use consistent timestamps for different timestamp functions") {
  // Mix the three timestamp-producing expressions in a single projection.
  val differentTimestamps = Seq(
    Alias(CurrentTimestamp(), " currentTimestamp")(),
    Alias(Now(), " now")(),
    // The zone id must be written without surrounding whitespace: "PLT" is a key of
    // java.time.ZoneId.SHORT_IDS, whereas " PLT" is not a resolvable zone id.
    Alias(LocalTimestamp(Some("PLT")), " localTimestampWithTimezone")()
  )
  val input = Project(differentTimestamps, LocalRelation())

  val plan = Optimize.execute(input).asInstanceOf[Project]

  val lits = literals[Long](plan)
  assert(lits.size === differentTimestamps.size)
  // there are timezones with a 30 or 45 minute offset, so the values are compared by
  // their remainder within a quarter of an hour rather than directly.
  val quarterHourMicros = Duration(15, MINUTES).toMicros
  val offsetsFromQuarterHour = lits.map(_ % quarterHourMicros).toSet
  assert(offsetsFromQuarterHour.size == 1)
}
137
+
138
/**
 * Gathers the values of the [[Literal]] expressions encountered while walking `plan`
 * with `transformWithSubqueries`, each cast (unchecked) to `T`.
 *
 * The transform calls are used only for their side effect of filling the buffer; every
 * matched literal is returned unchanged. Because `transformAllExpressions` is invoked
 * for each plan node that the outer walk hands over, a literal may be reported more
 * than once; the subquery test in this suite relies on that count.
 *
 * @param plan the plan to inspect
 * @tparam T the type each literal value is cast to
 * @return the collected values, in the order they were encountered
 */
private def literals[T](plan: LogicalPlan): scala.collection.mutable.ArrayBuffer[T] = {
  val collected = scala.collection.mutable.ArrayBuffer.empty[T]
  plan.transformWithSubqueries { case node =>
    node.transformAllExpressions { case lit: Literal =>
      collected += lit.value.asInstanceOf[T]
      lit
    }
  }
  collected
}
105
148
}
0 commit comments