Commit afef6ce

Implement years, months, hours, and bucket transforms for tests.

1 parent fd65dd4 commit afef6ce
File tree: 4 files changed, +40 -17 lines

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala

Lines changed: 39 additions & 8 deletions
@@ -17,7 +17,8 @@
 package org.apache.spark.sql.connector
 
-import java.time.ZoneId
+import java.time.{Instant, ZoneId}
+import java.time.temporal.ChronoUnit
 import java.util
 
 import scala.collection.JavaConverters._
@@ -28,7 +29,7 @@ import org.scalatest.Assertions._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.{DaysTransform, IdentityTransform, Transform}
+import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
@@ -48,11 +49,15 @@ class InMemoryTable(
   private val allowUnsupportedTransforms =
     properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean
 
-  partitioning.foreach { t =>
-    if (!t.isInstanceOf[IdentityTransform] && !t.isInstanceOf[DaysTransform] &&
-        !allowUnsupportedTransforms) {
-      throw new IllegalArgumentException(s"Transform $t must be IdentityTransform or DaysTransform")
-    }
+  partitioning.foreach {
+    case _: IdentityTransform =>
+    case _: YearsTransform =>
+    case _: MonthsTransform =>
+    case _: DaysTransform =>
+    case _: HoursTransform =>
+    case _: BucketTransform =>
+    case t if !allowUnsupportedTransforms =>
+      throw new IllegalArgumentException(s"Transform $t is not a supported transform")
   }
 
   // The key `Seq[Any]` is the partition values.
@@ -69,6 +74,9 @@ class InMemoryTable(
     }
   }
 
+  private val UTC = ZoneId.of("UTC")
+  private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate
+
   private def getKey(row: InternalRow): Seq[Any] = {
     def extractor(
         fieldNames: Array[String],
@@ -91,13 +99,36 @@ class InMemoryTable(
     partitioning.map {
       case IdentityTransform(ref) =>
         extractor(ref.fieldNames, schema, row)._1
+      case YearsTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (days: Int, DateType) =>
+            ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
+          case (micros: Long, TimestampType) =>
+            val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
+            ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
+        }
+      case MonthsTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (days: Int, DateType) =>
+            ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
+          case (micros: Long, TimestampType) =>
+            val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
+            ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate)
+        }
       case DaysTransform(ref) =>
         extractor(ref.fieldNames, schema, row) match {
           case (days, DateType) =>
             days
           case (micros: Long, TimestampType) =>
-            DateTimeUtils.microsToDays(micros, ZoneId.of("UTC"))
+            ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
+        }
+      case HoursTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (micros: Long, TimestampType) =>
+            ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
         }
+      case BucketTransform(numBuckets, ref) =>
+        (extractor(ref.fieldNames, schema, row).hashCode() & Integer.MAX_VALUE) % numBuckets
     }
   }
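As a reference for the new match arms, here is a minimal standalone sketch (not part of the commit) of the epoch-relative arithmetic they perform. The sample timestamp mirrors the one used in the test suite below; the expected values in the comments are hand-computed.

import java.time.{Instant, ZoneId}
import java.time.temporal.ChronoUnit

object EpochTransformDemo {
  private val UTC = ZoneId.of("UTC")
  private val EpochDate = Instant.EPOCH.atZone(UTC).toLocalDate

  def main(args: Array[String]): Unit = {
    val ts = Instant.parse("2019-06-01T10:00:00Z")
    val date = ts.atZone(UTC).toLocalDate

    // Each transform reduces a date/timestamp to whole units elapsed since the epoch.
    println(ChronoUnit.YEARS.between(EpochDate, date))   // 49
    println(ChronoUnit.MONTHS.between(EpochDate, date))  // 593 = 49 * 12 + 5
    println(ChronoUnit.DAYS.between(Instant.EPOCH, ts))  // 18048
    println(ChronoUnit.HOURS.between(Instant.EPOCH, ts)) // 433162 = 18048 * 24 + 10
  }
}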

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ case class BatchScanExec(
 
   override def hashCode(): Int = batch.hashCode()
 
-  override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()
+  @transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()
 
   override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()
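The @transient annotation keeps the planned partitions out of the serialized plan node, so they are recomputed where needed instead of being shipped. A minimal sketch (plain JVM serialization, no Spark classes, all names hypothetical) of how a @transient lazy val behaves across a round trip:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

class Plan extends Serializable {
  // Excluded from the serialized form; recomputed lazily after deserialization.
  @transient lazy val partitions: Seq[Int] = {
    println("planning partitions")
    Seq(0, 1, 2)
  }
}

object TransientLazyValDemo {
  def main(args: Array[String]): Unit = {
    val plan = new Plan
    plan.partitions // prints "planning partitions"

    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(plan)
    out.close()

    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    val copy = in.readObject().asInstanceOf[Plan]
    copy.partitions // prints again: the value was recomputed, not deserialized
  }
}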

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala

Lines changed: 0 additions & 7 deletions
@@ -336,7 +336,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     spark.table("source")
       .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(years($"ts"))
       .create()
 
@@ -350,7 +349,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     spark.table("source")
       .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(months($"ts"))
       .create()
 
@@ -364,7 +362,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     spark.table("source")
       .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(days($"ts"))
       .create()
 
@@ -378,7 +375,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     spark.table("source")
       .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(hours($"ts"))
       .create()
 
@@ -391,7 +387,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
   test("Create: partitioned by bucket(4, id)") {
     spark.table("source")
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(bucket(4, $"id"))
       .create()
 
@@ -596,7 +591,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
         lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
         lit("America/Los_Angeles") as "timezone"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(
         years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"),
         years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified")
@@ -624,7 +618,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
         lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
         lit("America/Los_Angeles") as "timezone"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(bucket(4, $"ts.timezone"))
       .create()
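Each deletion above removes the same escape hatch: now that InMemoryTable accepts these transforms, the tests partition by them directly. A condensed sketch of the API being exercised, assuming (as the suite's unqualified calls suggest) the transform helpers come from org.apache.spark.sql.functions, plus a SparkSession named spark, import spark.implicits._, and a source table with an id column:

import org.apache.spark.sql.functions._

spark.table("source")
  .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
  .writeTo("testcat.table_name")
  .partitionedBy(years($"ts"), months($"ts"), days($"ts"), hours($"ts"), bucket(4, $"id"))
  .create()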

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala

Lines changed: 0 additions & 1 deletion
@@ -1650,7 +1650,6 @@ class DataSourceV2SQLSuite
       """
         |CREATE TABLE testcat.t (id int, `a.b` string) USING foo
         |CLUSTERED BY (`a.b`) INTO 4 BUCKETS
-        |OPTIONS ('allow-unsupported-transforms'=true)
       """.stripMargin)
 
     val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog]
