Skip to content

Commit

Permalink
[SPARK] Add common sql test utility methods to DeltaSQLTestUtils (d…
Browse files Browse the repository at this point in the history
…elta-io#4131)

<!--
Thanks for sending a pull request!  Here are some tips for you:
1. If this is your first time, please read our contributor guidelines:
https://github.com/delta-io/delta/blob/master/CONTRIBUTING.md
2. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP]
Your PR title ...'.
  3. Be sure to keep the PR description updated to reflect all changes.
  4. Please write your PR title to summarize what this PR proposes.
5. If possible, provide a concise example to reproduce the issue for a
faster review.
6. If applicable, include the corresponding issue number in the PR title
and link it in the body.
-->

#### Which Delta project/connector is this regarding?
<!--
Please add the component selected below to the beginning of the pull
request title
For example: [Spark] Title of my pull request
-->

- [X] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description
Add the most commonly used test helpers (getting the Delta log, snapshot,
protocol, stats, columns, etc.) to `DeltaSQLTestUtils`, so that test suites
can mix in this trait instead of reimplementing the same helpers over and
over again.

## How was this patch tested?
No testing required.

## Does this PR introduce _any_ user-facing changes?
No.
  • Loading branch information
stefankandic authored Feb 20, 2025
1 parent 71cf788 commit e621c5d
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,6 @@ class DeltaVariantSuite

import testImplicits._

/** Returns the protocol of the `unsafeVolatileSnapshot` of the Delta log for `table`. */
private def getProtocolForTable(table: String): Protocol = {
  val tableId = TableIdentifier(table)
  DeltaLog.forTable(spark, tableId).unsafeVolatileSnapshot.protocol
}

private def assertVariantTypeTableFeatures(
tableName: String,
expectPreviewFeature: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,17 @@ package org.apache.spark.sql.delta.test

import java.io.File

import scala.util.Random

import org.apache.spark.sql.delta.{DeltaColumnMappingTestUtilsBase, DeltaLog, DeltaTable, Snapshot, TableFeature}
import org.apache.spark.sql.delta.actions.Protocol
import org.apache.spark.sql.delta.stats.{DeltaStatistics, PreparedDeltaFileIndex}
import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}

import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

trait DeltaSQLTestUtils extends SQLTestUtils {
Expand Down Expand Up @@ -74,4 +84,224 @@ trait DeltaSQLTestUtils extends SQLTestUtils {
files.foreach(Utils.deleteRecursively)
}
}

/**
 * Returns a random alphanumeric string of 10 characters to be used as a table name.
 * The name is random, so uniqueness is overwhelmingly likely but not guaranteed.
 */
def uniqueTableName: String = {
  val nameLength = 10
  Random.alphanumeric.iterator.take(nameLength).mkString
}

/**
 * Returns the snapshot produced by calling `update()` on the `DeltaLog` of the table
 * identified by `tableName`.
 */
def getSnapshot(tableName: String): Snapshot = {
  val identifier = TableIdentifier(tableName)
  val deltaLog = DeltaLog.forTable(spark, identifier)
  deltaLog.update()
}

/** Returns the protocol recorded in the latest snapshot of `tableName`. */
def getProtocolForTable(tableName: String): Protocol = {
  val latestSnapshot = getSnapshot(tableName)
  latestSnapshot.protocol
}
/**
 * Gets the `StructField` located at `columnPath` within `schema`.
 *
 * The lookup is delegated to `findNestedField` with `includeCollections = true`.
 *
 * @param schema schema to search.
 * @param columnPath name parts of the column, outermost first.
 * @return the field found at `columnPath`.
 * @throws NoSuchElementException if `schema` has no field at `columnPath`.
 */
final def getColumnField(schema: StructType, columnPath: Seq[String]): StructField = {
  schema
    .findNestedField(columnPath, includeCollections = true)
    .map { case (_, field) => field }
    .getOrElse {
      // Fail with an actionable message rather than the bare `None.get` error.
      throw new NoSuchElementException(
        s"Column '${columnPath.mkString(".")}' not found in schema: ${schema.simpleString}")
    }
}

/**
 * Looks up the `StructField` of `columnName` in the latest schema of `tableName`.
 * `columnName` is split on `.` to form the column path.
 */
def getColumnField(tableName: String, columnName: String): StructField = {
  val columnPath = columnName.split("\\.")
  val latestSchema = DeltaLog.forTable(spark, TableIdentifier(tableName)).update().schema
  getColumnField(latestSchema, columnPath)
}

/** Returns the `DataType` of the field found at `columnPath` in `schema`. */
def getColumnType(schema: StructType, columnPath: Seq[String]): DataType = {
  val field = getColumnField(schema, columnPath)
  field.dataType
}

/** Returns the `DataType` of the column `columnName` of the table `tableName`. */
def getColumnType(tableName: String, columnName: String): DataType = {
  val field = getColumnField(tableName, columnName)
  field.dataType
}

/**
 * Gets the stats fields from the AddFiles of `snapshot`, each parsed into a JSON tree. The
 * stats are ordered by the modification time of the files they are associated with.
 *
 * The stats strings are parsed as they are; no further checks are applied here.
 * NOTE(review): a file without a stats string is not special-cased - confirm callers only use
 * this on tables where every file has stats.
 *
 * @param snapshot snapshot whose files are inspected.
 * @return one parsed stats node per file, in ascending `modificationTime` order.
 */
def getUnvalidatedStatsOrderByFileModTime(snapshot: Snapshot): Array[JsonNode] = {
  // A single mapper serves every file; constructing one per file is needless work.
  val mapper = new ObjectMapper()
  snapshot.allFiles
    .orderBy("modificationTime")
    .collect()
    .map(file => mapper.readTree(file.stats))
}

/**
 * Same as the `Snapshot` overload, applied to the latest snapshot of `tableName`: returns the
 * stats of its AddFiles ordered by the modification time of the corresponding files.
 */
def getUnvalidatedStatsOrderByFileModTime(tableName: String): Array[JsonNode] = {
  val latestSnapshot = getSnapshot(tableName)
  getUnvalidatedStatsOrderByFileModTime(latestSnapshot)
}

/**
 * Gets the physical column path if there is column mapping metadata in the schema.
 *
 * `columnName` is split on `.` and resolved against `tableSchema` through
 * `DeltaColumnMappingTestUtilsBase.getPhysicalPathForStats`; `.get` is called on the result,
 * so this throws when no path is available.
 */
def getPhysicalColumnPath(tableSchema: StructType, columnName: String): Seq[String] = {
  val logicalPath = columnName.split("\\.")
  val columnMappingUtils = new DeltaColumnMappingTestUtilsBase {}
  columnMappingUtils.getPhysicalPathForStats(logicalPath, tableSchema).get
}

/**
 * Walks `path` key by key starting at `stats` and returns the node reached, or `None` as soon
 * as a key is missing (or `stats` itself is null).
 */
def getStatFieldOpt(stats: JsonNode, path: Seq[String]): Option[JsonNode] = {
  val root: Option[JsonNode] = Option(stats)
  path.foldLeft(root) { (current, key) =>
    current.filter(_.has(key)).flatMap(node => Option(node.get(key)))
  }
}

/**
 * Reads the min and max stats of `columnName` out of `stats`, each as text, when present.
 *
 * The column must be of `StringType` in the latest schema of `tableName`; this is asserted.
 * The stats are looked up under the column's physical path.
 */
private def getMinMaxStatsOpt(
    tableName: String,
    stats: JsonNode,
    columnName: String): (Option[String], Option[String]) = {
  val tableSchema = getSnapshot(tableName).schema
  val dataType = getColumnType(tableSchema, columnName.split('.'))
  assert(dataType.isInstanceOf[StringType], s"Expected StringType, got $dataType")

  val physicalPath = getPhysicalColumnPath(tableSchema, columnName)

  // Looks up one kind of stat (min or max) under the column's physical path.
  def statText(statName: String): Option[String] =
    getStatFieldOpt(stats, statName +: physicalPath).map(_.asText())

  (statText(DeltaStatistics.MIN), statText(DeltaStatistics.MAX))
}

/**
 * Gets the min/max stats of `columnName` from `stats`.
 *
 * @param tableName table the column belongs to.
 * @param stats parsed stats node of a single file.
 * @param columnName logical, dot-separated column name.
 * @return the `(min, max)` stats values as text.
 * @throws NoSuchElementException if the min or the max stat is absent from `stats`.
 */
def getMinMaxStats(
    tableName: String,
    stats: JsonNode,
    columnName: String): (String, String) = {
  val (minOpt, maxOpt) = getMinMaxStatsOpt(tableName, stats, columnName)

  // Unwraps a stat, failing with a message that says which stat of which column is missing.
  def required(stat: Option[String], statKind: String): String = stat.getOrElse {
    throw new NoSuchElementException(
      s"No $statKind stats found for column '$columnName' of table '$tableName' in: $stats")
  }

  (required(minOpt, "min"), required(maxOpt, "max"))
}

/**
 * Asserts that both the min and the max stats of `columnName` are present in `stats` when
 * `expectStats` is true, and that both are absent when it is false.
 */
def assertMinMaxStatsPresence(
    tableName: String,
    stats: JsonNode,
    columnName: String,
    expectStats: Boolean): Unit = {
  val (minOpt, maxOpt) = getMinMaxStatsOpt(tableName, stats, columnName)
  val hasMin = minOpt.nonEmpty
  val hasMax = maxOpt.nonEmpty
  assert(hasMin === expectStats)
  assert(hasMax === expectStats)
}

/**
 * Asserts that the min and max stats of `columnName` in `stats` equal `expectedMin` and
 * `expectedMax` respectively.
 */
def assertMinMaxStats(
    tableName: String,
    stats: JsonNode,
    columnName: String,
    expectedMin: String,
    expectedMax: String): Unit = {
  getMinMaxStats(tableName, stats, columnName) match {
    case (actualMin, actualMax) =>
      assert(actualMin === expectedMin, s"Expected $expectedMin, got $actualMin")
      assert(actualMax === expectedMax, s"Expected $expectedMax, got $actualMax")
  }
}

/** Asserts that `protocol` has exactly the given minimum reader and writer versions. */
def assertProtocolVersion(
    protocol: Protocol,
    minReaderVersion: Int,
    minWriterVersion: Int): Unit = {
  val actualReaderVersion = protocol.minReaderVersion
  val actualWriterVersion = protocol.minWriterVersion
  assert(actualReaderVersion === minReaderVersion)
  assert(actualWriterVersion === minWriterVersion)
}

/** Asserts that `columnName` of `tableName` has the data type `expectedDataType`. */
def assertColumnDataType(
    tableName: String,
    columnName: String,
    expectedDataType: DataType): Unit = {
  val actualDataType = getColumnType(tableName, columnName)
  assert(actualDataType === expectedDataType)
}

/**
 * Asserts that selecting `columnName` from `tableName` fails with an `AnalysisException`
 * whose message reports that the column cannot be resolved.
 */
def assertColumnNotExist(tableName: String, columnName: String): Unit = {
  val analysisError = intercept[AnalysisException] {
    sql(s"SELECT $columnName FROM $tableName")
  }
  val expectedFragment = s"`$columnName` cannot be resolved"
  assert(analysisError.getMessage.contains(expectedFragment))
}

/**
 * Builds `SELECT * FROM tableName WHERE predicate` and checks both the number of rows it
 * returns and the number of files it reads.
 */
def assertSelectQueryResults(
    tableName: String,
    predicate: String,
    numRows: Int,
    numFilesRead: Int): Unit = {
  val queryText = s"SELECT * FROM $tableName WHERE $predicate"
  assertSelectQueryResults(sql(queryText), numRows, numFilesRead)
}

/**
 * Runs `query` and verifies the number of rows returned and files read.
 *
 * @param query query to execute.
 * @param numRows expected number of rows returned by `query`.
 * @param numFilesRead expected number of files read, as reported by `getNumReadFiles`.
 */
def assertSelectQueryResults(
    query: DataFrame,
    numRows: Int,
    numFilesRead: Int): Unit = {
  // Count once: the failure message must report the very value that was compared, and
  // re-evaluating `count()` for the message would execute the query a second time.
  val actualRows = query.count()
  assert(actualRows === numRows, s"Expected $numRows rows, got $actualRows")
  val filesRead = getNumReadFiles(query)
  assert(filesRead === numFilesRead, s"Expected $numFilesRead files read, got $filesRead")
}

/** Returns the number of files read by the query given as SQL text. */
def getNumReadFiles(queryText: String): Int = {
  val df = sql(queryText)
  getNumReadFiles(df)
}

/**
 * Returns the number of files in the prepared Delta scan of `df`.
 *
 * The optimized plan is searched for Delta tables backed by a `PreparedDeltaFileIndex`;
 * exactly one such scan must be present, which is asserted.
 */
def getNumReadFiles(df: DataFrame): Int = {
  val optimizedPlan = df.queryExecution.optimizedPlan
  val preparedScans = optimizedPlan.collect {
    case DeltaTable(fileIndex: PreparedDeltaFileIndex) => fileIndex.preparedScan
  }
  assert(preparedScans.size == 1)
  preparedScans.head.files.length
}

/** Drops `columnName` from `tableName` and asserts that the column no longer resolves. */
def dropColumn(tableName: String, columnName: String): Unit = {
  val dropStatement = s"ALTER TABLE $tableName DROP COLUMN $columnName"
  sql(dropStatement)
  assertColumnNotExist(tableName, columnName)
}

/** Issues an `ALTER COLUMN ... TYPE` statement changing `columnName` to `newType`. */
def alterColumnType(tableName: String, columnName: String, newType: String): Unit = {
  val alterStatement = s"ALTER TABLE $tableName ALTER COLUMN $columnName TYPE $newType"
  sql(alterStatement)
}

/** Whether the protocol of `tableName` reports `tableFeature` as supported. */
def isFeatureSupported(tableName: String, tableFeature: TableFeature): Boolean =
  getProtocolForTable(tableName).isFeatureSupported(tableFeature)

/**
 * Whether `featureName` appears among the reader or writer feature names of the protocol of
 * `tableName`.
 */
def isFeatureSupported(tableName: String, featureName: String): Boolean = {
  val tableProtocol = getProtocolForTable(tableName)
  val listedAsReaderFeature = tableProtocol.readerFeatureNames.contains(featureName)
  // Writer features are only consulted when the reader features do not list the name.
  listedAsReaderFeature || tableProtocol.writerFeatureNames.contains(featureName)
}

/**
 * Marks `featureName` as supported on `tableName` through the `delta.feature.<name>` table
 * property, then asserts that the feature is reported as supported.
 */
def enableTableFeature(tableName: String, featureName: String): Unit = {
  val enableStatement =
    s"""
       |ALTER TABLE $tableName
       |SET TBLPROPERTIES('delta.feature.$featureName' = 'supported')
       |""".stripMargin
  sql(enableStatement)
  assert(isFeatureSupported(tableName, featureName))
}

/**
 * Drops `featureName` from `tableName` with `ALTER TABLE ... DROP FEATURE`, then asserts that
 * the feature is no longer reported as supported.
 */
def dropTableFeature(tableName: String, featureName: String): Unit = {
  val dropStatement = s"ALTER TABLE $tableName DROP FEATURE `$featureName`"
  sql(dropStatement)
  assert(!isFeatureSupported(tableName, featureName))
}
}

0 comments on commit e621c5d

Please sign in to comment.