@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.connector
 
-import java.time.ZoneId
+import java.time.{Instant, ZoneId}
+import java.time.temporal.ChronoUnit
 import java.util
 
 import scala.collection.JavaConverters._
@@ -28,7 +29,7 @@ import org.scalatest.Assertions._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.{DaysTransform, IdentityTransform, Transform}
+import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
@@ -48,11 +49,15 @@ class InMemoryTable(
   private val allowUnsupportedTransforms =
     properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean
 
-  partitioning.foreach { t =>
-    if (!t.isInstanceOf[IdentityTransform] && !t.isInstanceOf[DaysTransform] &&
-        !allowUnsupportedTransforms) {
-      throw new IllegalArgumentException(s"Transform $t must be IdentityTransform or DaysTransform")
-    }
+  partitioning.foreach {
+    case _: IdentityTransform =>
+    case _: YearsTransform =>
+    case _: MonthsTransform =>
+    case _: DaysTransform =>
+    case _: HoursTransform =>
+    case _: BucketTransform =>
+    case t if !allowUnsupportedTransforms =>
+      throw new IllegalArgumentException(s"Transform $t is not a supported transform")
   }
 
   // The key `Seq[Any]` is the partition values.
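The rewritten check is a pattern-match whitelist: each supported transform gets an empty-bodied case, and the guarded final case throws only when an unrecognized transform arrives while `allow-unsupported-transforms` is off. A minimal standalone sketch of the same shape, using a hypothetical `Transform` ADT rather than Spark's connector expressions (a trailing catch-all is added here, beyond what the patch itself has, so the match stays total when the escape hatch is enabled):

```scala
sealed trait Transform
case object Identity extends Transform
case object Days extends Transform
final case class Unknown(name: String) extends Transform

def validate(partitioning: Seq[Transform], allowUnsupported: Boolean): Unit =
  partitioning.foreach {
    case Identity => // supported, nothing to do
    case Days =>     // supported, nothing to do
    case t if !allowUnsupported =>
      throw new IllegalArgumentException(s"Transform $t is not a supported transform")
    case _ =>        // unsupported but explicitly allowed (catch-all added for this sketch)
  }

validate(Seq(Identity, Days), allowUnsupported = false)       // ok
validate(Seq(Unknown("zorder")), allowUnsupported = true)     // ok
// validate(Seq(Unknown("zorder")), allowUnsupported = false) // throws
```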
@@ -69,6 +74,9 @@ class InMemoryTable(
     }
   }
 
+  private val UTC = ZoneId.of("UTC")
+  private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate
+
   private def getKey(row: InternalRow): Seq[Any] = {
     def extractor(
         fieldNames: Array[String],
@@ -91,13 +99,36 @@ class InMemoryTable(
     partitioning.map {
       case IdentityTransform(ref) =>
         extractor(ref.fieldNames, schema, row)._1
+      case YearsTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (days: Int, DateType) =>
+            ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
+          case (micros: Long, TimestampType) =>
+            val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
+            ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
+        }
+      case MonthsTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (days: Int, DateType) =>
+            ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
+          case (micros: Long, TimestampType) =>
+            val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
+            ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate)
+        }
       case DaysTransform(ref) =>
         extractor(ref.fieldNames, schema, row) match {
           case (days, DateType) =>
             days
           case (micros: Long, TimestampType) =>
-            DateTimeUtils.microsToDays(micros, ZoneId.of("UTC"))
+            ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
+        }
+      case HoursTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (micros: Long, TimestampType) =>
+            ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
         }
+      case BucketTransform(numBuckets, ref) =>
+        (extractor(ref.fieldNames, schema, row).hashCode() & Integer.MAX_VALUE) % numBuckets
     }
   }
 
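For orientation, here are the epoch-relative values the new cases compute, reproduced outside Spark with plain `java.time` (no `DateTimeUtils`; the example timestamp is arbitrary, and the commented results follow from it):

```scala
import java.time.{Instant, ZoneId}
import java.time.temporal.ChronoUnit

val UTC = ZoneId.of("UTC")
val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate

val ts = Instant.parse("2020-03-15T10:30:00Z")
val localDate = ts.atZone(UTC).toLocalDate

// years/months count whole units between LocalDates, relative to 1970-01-01
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)   // 50
ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate)  // 602

// days/hours count whole units directly on the Instant, relative to the epoch
ChronoUnit.DAYS.between(Instant.EPOCH, ts)              // 18336
ChronoUnit.HOURS.between(Instant.EPOCH, ts)             // 440074

// bucket ids: clear the sign bit of the hash, then take the remainder
val numBuckets = 16
(localDate.hashCode() & Integer.MAX_VALUE) % numBuckets
```

Counting on a `LocalDate` for years/months but on the raw `Instant` for days/hours matches the patch: the former need calendar arithmetic in UTC, while the latter are fixed-width units since the epoch.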