From fcc3d9bc892927af74bc628daab81347f79afd7f Mon Sep 17 00:00:00 2001
From: Cuong Nguyen
Date: Mon, 11 Nov 2024 18:36:20 -0800
Subject: [PATCH] [Spark] Pass catalog table to DeltaLog API call sites (#3862)

#### Which Delta project/connector is this regarding?

- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

This PR changes the DeltaLog API to take in an extra parameter for the catalog table and switches some call sites (more to come) over to the new API versions.

Delta log API changes:
+ Added `forTable(SparkSession, CatalogTable, Map[String, String])`
+ Added `forTableWithSnapshot(SparkSession, CatalogTable, Map[String, String])`
+ Modified `withFreshSnapshot` to take in a catalog table.

## How was this patch tested?

Unit tests

## Does this PR introduce _any_ user-facing changes?

No
---
 .../spark/sql/delta/hudi/HudiConverter.scala  | 14 +++--
 .../icebergShaded/IcebergConverter.scala      |  2 +-
 .../spark/sql/delta/DeltaAnalysis.scala       |  5 +-
 .../org/apache/spark/sql/delta/DeltaLog.scala | 55 ++++++++++++++++---
 .../spark/sql/delta/SnapshotManagement.scala  | 16 +++---
 .../sql/delta/catalog/DeltaTableV2.scala      |  5 +-
 .../sql/delta/commands/DeltaCommand.scala     | 15 ++---
 .../delta/commands/DeltaGenerateCommand.scala | 10 ++--
 .../DescribeDeltaDetailsCommand.scala         |  9 ++-
 .../delta/commands/RestoreTableCommand.scala  | 11 +++-
 .../sql/delta/hooks/CheckpointHook.scala      |  2 +-
 .../spark/sql/delta/DeltaLogSuite.scala       |  2 +-
 12 files changed, 98 insertions(+), 48 deletions(-)

diff --git a/hudi/src/main/scala/org/apache/spark/sql/delta/hudi/HudiConverter.scala b/hudi/src/main/scala/org/apache/spark/sql/delta/hudi/HudiConverter.scala
index f793cf890e7..03e06b9fae8 100644
--- a/hudi/src/main/scala/org/apache/spark/sql/delta/hudi/HudiConverter.scala
+++ b/hudi/src/main/scala/org/apache/spark/sql/delta/hudi/HudiConverter.scala
@@ -177,7 +177,7 @@ class HudiConverter(spark: SparkSession)
     if (!UniversalFormat.hudiEnabled(snapshotToConvert.metadata)) {
       return None
     }
-    convertSnapshot(snapshotToConvert, None, Option.apply(catalogTable.identifier.table))
+    convertSnapshot(snapshotToConvert, None, Some(catalogTable))
   }
 
   /**
@@ -193,7 +193,7 @@ class HudiConverter(spark: SparkSession)
     if (!UniversalFormat.hudiEnabled(snapshotToConvert.metadata)) {
       return None
     }
-    convertSnapshot(snapshotToConvert, Some(txn), txn.catalogTable.map(_.identifier.table))
+    convertSnapshot(snapshotToConvert, Some(txn), txn.catalogTable)
   }
 
   /**
@@ -208,11 +208,13 @@ class HudiConverter(spark: SparkSession)
   private def convertSnapshot(
       snapshotToConvert: Snapshot,
       txnOpt: Option[OptimisticTransactionImpl],
-      tableName: Option[String]): Option[(Long, Long)] =
+      catalogTable: Option[CatalogTable]): Option[(Long, Long)] =
       recordFrameProfile("Delta", "HudiConverter.convertSnapshot") {
     val log = snapshotToConvert.deltaLog
-    val metaClient = loadTableMetaClient(snapshotToConvert.deltaLog.dataPath.toString,
-      tableName, snapshotToConvert.metadata.partitionColumns,
+    val metaClient = loadTableMetaClient(
+      snapshotToConvert.deltaLog.dataPath.toString,
+      catalogTable.flatMap(ct => Option(ct.identifier.table)),
+      snapshotToConvert.metadata.partitionColumns,
       new HadoopStorageConfiguration(log.newDeltaHadoopConf()))
     val lastDeltaVersionConverted: Option[Long] = loadLastDeltaVersionConverted(metaClient)
     val maxCommitsToConvert =
@@ -233,7 +235,7 @@ class HudiConverter(spark: SparkSession)
       try {
         // TODO: We can optimize this by providing a checkpointHint to getSnapshotAt.
Check if // txn.snapshot.version < version. If true, use txn.snapshot's checkpoint as a hint. - Some(log.getSnapshotAt(version)) + Some(log.getSnapshotAt(version, catalogTableOpt = catalogTable)) } catch { // If we can't load the file since the last time Hudi was converted, it's likely that // the commit file expired. Treat this like a new Hudi table conversion. diff --git a/iceberg/src/main/scala/org/apache/spark/sql/delta/icebergShaded/IcebergConverter.scala b/iceberg/src/main/scala/org/apache/spark/sql/delta/icebergShaded/IcebergConverter.scala index ec4385336ab..075e3ddcb1a 100644 --- a/iceberg/src/main/scala/org/apache/spark/sql/delta/icebergShaded/IcebergConverter.scala +++ b/iceberg/src/main/scala/org/apache/spark/sql/delta/icebergShaded/IcebergConverter.scala @@ -312,7 +312,7 @@ class IcebergConverter(spark: SparkSession) try { // TODO: We can optimize this by providing a checkpointHint to getSnapshotAt. Check if // txn.snapshot.version < version. If true, use txn.snapshot's checkpoint as a hint. - Some(log.getSnapshotAt(version)) + Some(log.getSnapshotAt(version, catalogTableOpt = Some(catalogTable))) } catch { // If we can't load the file since the last time Iceberg was converted, it's likely that // the commit file expired. Treat this like a new Iceberg table conversion. diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala index 53d11c7dacf..f20e79a58ed 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala @@ -1144,11 +1144,12 @@ class DeltaAnalysis(session: SparkSession) session, dataSourceV1.options ).foreach { rootSchemaTrackingLocation => assert(dataSourceV1.options.contains("path"), "Path for Delta table must be defined") - val log = DeltaLog.forTable(session, new Path(dataSourceV1.options("path"))) + val tableId = + dataSourceV1.options("path").replace(":", "").replace("/", "_") val sourceIdOpt = dataSourceV1.options.get(DeltaOptions.STREAMING_SOURCE_TRACKING_ID) val schemaTrackingLocation = DeltaSourceMetadataTrackingLog.fullMetadataTrackingLocation( - rootSchemaTrackingLocation, log.tableId, sourceIdOpt) + rootSchemaTrackingLocation, tableId, sourceIdOpt) // Make sure schema location is under checkpoint if (!allowSchemaLocationOutsideOfCheckpoint && !(schemaTrackingLocation.stripPrefix("file:").stripSuffix("/") + "/") diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala index 815946727b6..c8f1b5c9b08 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala @@ -803,6 +803,16 @@ object DeltaLog extends DeltaLogging { } } + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, table: CatalogTable, options: Map[String, String]): DeltaLog = { + apply( + spark, + logPathFor(new Path(table.location)), + options, + Some(table.identifier), + new SystemClock) + } + /** Helper for creating a log for the table. 
*/ def forTable(spark: SparkSession, table: CatalogTable, clock: Clock): DeltaLog = { apply(spark, logPathFor(new Path(table.location)), Some(table.identifier), clock) @@ -818,25 +828,50 @@ object DeltaLog extends DeltaLogging { /** Helper for getting a log, as well as the latest snapshot, of the table */ def forTableWithSnapshot(spark: SparkSession, dataPath: String): (DeltaLog, Snapshot) = - withFreshSnapshot { forTable(spark, new Path(dataPath), _) } + withFreshSnapshot { clock => + (forTable(spark, new Path(dataPath), clock), None) + } /** Helper for getting a log, as well as the latest snapshot, of the table */ def forTableWithSnapshot(spark: SparkSession, dataPath: Path): (DeltaLog, Snapshot) = - withFreshSnapshot { forTable(spark, dataPath, _) } + withFreshSnapshot { clock => + (forTable(spark, dataPath, clock), None) + } /** Helper for getting a log, as well as the latest snapshot, of the table */ def forTableWithSnapshot( spark: SparkSession, - tableName: TableIdentifier): (DeltaLog, Snapshot) = - withFreshSnapshot { forTable(spark, tableName, _) } + tableName: TableIdentifier): (DeltaLog, Snapshot) = { + withFreshSnapshot { clock => + if (DeltaTableIdentifier.isDeltaPath(spark, tableName)) { + (forTable(spark, new Path(tableName.table)), None) + } else { + val catalogTable = spark.sessionState.catalog.getTableMetadata(tableName) + (forTable(spark, catalogTable, clock), Some(catalogTable)) + } + } + } /** Helper for getting a log, as well as the latest snapshot, of the table */ def forTableWithSnapshot( spark: SparkSession, dataPath: Path, options: Map[String, String]): (DeltaLog, Snapshot) = - withFreshSnapshot { - apply(spark, logPathFor(dataPath), options, initialTableIdentifier = None, _) + withFreshSnapshot { clock => + val deltaLog = + apply(spark, logPathFor(dataPath), options, initialTableIdentifier = None, clock) + (deltaLog, None) + } + + /** Helper for getting a log, as well as the latest snapshot, of the table */ + def forTableWithSnapshot( + spark: SparkSession, + table: CatalogTable, + options: Map[String, String]): (DeltaLog, Snapshot) = + withFreshSnapshot { clock => + val deltaLog = + apply(spark, logPathFor(new Path(table.location)), options, Some(table.identifier), clock) + (deltaLog, Some(table)) } /** @@ -844,11 +879,13 @@ object DeltaLog extends DeltaLogging { * partially applied DeltaLog.forTable call, which we can then wrap around with a * snapshot update. We use the system clock to avoid back-to-back updates. 
*/ - private[delta] def withFreshSnapshot(thunk: Clock => DeltaLog): (DeltaLog, Snapshot) = { + private[delta] def withFreshSnapshot( + thunk: Clock => (DeltaLog, Option[CatalogTable])): (DeltaLog, Snapshot) = { val clock = new SystemClock val ts = clock.getTimeMillis() - val deltaLog = thunk(clock) - val snapshot = deltaLog.update(checkIfUpdatedSinceTs = Some(ts)) + val (deltaLog, catalogTableOpt) = thunk(clock) + val snapshot = + deltaLog.update(checkIfUpdatedSinceTs = Some(ts), catalogTableOpt = catalogTableOpt) (deltaLog, snapshot) } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala b/spark/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala index 333835d6786..ab0a3cc0a23 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala @@ -43,6 +43,7 @@ import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.internal.MDC import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -1016,7 +1017,7 @@ trait SnapshotManagement { self: DeltaLog => def update( stalenessAcceptable: Boolean = false, checkIfUpdatedSinceTs: Option[Long] = None, - tableIdentifierOpt: Option[TableIdentifier] = None): Snapshot = { + catalogTableOpt: Option[CatalogTable] = None): Snapshot = { val startTimeMs = System.currentTimeMillis() // currentSnapshot is volatile. Make a local copy of it at the start of the update call, so // that there's no chance of a race condition changing the snapshot partway through the update. @@ -1049,7 +1050,7 @@ trait SnapshotManagement { self: DeltaLog => withSnapshotLockInterruptibly { val newSnapshot = updateInternal( isAsync = false, - tableIdentifierOpt) + catalogTableOpt.map(_.identifier)) sendEvent(newSnapshot = capturedSnapshot.snapshot) newSnapshot } @@ -1067,7 +1068,7 @@ trait SnapshotManagement { self: DeltaLog => interruptOnCancel = true) tryUpdate( isAsync = true, - tableIdentifierOpt) + catalogTableOpt.map(_.identifier)) } } catch { case NonFatal(e) if !Utils.isTesting => @@ -1338,12 +1339,12 @@ trait SnapshotManagement { self: DeltaLog => def getSnapshotAt( version: Long, lastCheckpointHint: Option[CheckpointInstance] = None, - tableIdentifierOpt: Option[TableIdentifier] = None): Snapshot = { + catalogTableOpt: Option[CatalogTable] = None): Snapshot = { getSnapshotAt( version, lastCheckpointHint, lastCheckpointProvider = None, - tableIdentifierOpt) + catalogTableOpt) } /** @@ -1354,7 +1355,7 @@ trait SnapshotManagement { self: DeltaLog => version: Long, lastCheckpointHint: Option[CheckpointInstance], lastCheckpointProvider: Option[CheckpointProvider], - tableIdentifierOpt: Option[TableIdentifier]): Snapshot = { + catalogTableOpt: Option[CatalogTable]): Snapshot = { // See if the version currently cached on the cluster satisfies the requirement val currentSnapshot = unsafeVolatileSnapshot @@ -1363,7 +1364,7 @@ trait SnapshotManagement { self: DeltaLog => // upper bound. 
currentSnapshot } else { - val latestSnapshot = update(tableIdentifierOpt = tableIdentifierOpt) + val latestSnapshot = update(catalogTableOpt = catalogTableOpt) if (latestSnapshot.version < version) { throwNonExistentVersionError(version) } @@ -1385,6 +1386,7 @@ trait SnapshotManagement { self: DeltaLog => .map(manuallyLoadCheckpoint) lastCheckpointInfoForListing -> None } + val tableIdentifierOpt = catalogTableOpt.map(_.identifier) val logSegmentOpt = createLogSegment( versionToLoad = Some(version), oldCheckpointProviderOpt = lastCheckpointProviderOpt, diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaTableV2.scala b/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaTableV2.scala index f5792b0060e..d88f4c2c830 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaTableV2.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaTableV2.scala @@ -156,11 +156,12 @@ case class DeltaTableV2( "queriedVersion" -> version, "accessType" -> accessType )) - deltaLog.getSnapshotAt(version) + deltaLog.getSnapshotAt(version, catalogTableOpt = catalogTable) }.getOrElse( deltaLog.update( stalenessAcceptable = true, - checkIfUpdatedSinceTs = Some(creationTimeMs) + checkIfUpdatedSinceTs = Some(creationTimeMs), + catalogTableOpt = catalogTable ) ) } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaCommand.scala index c3c37ab1738..62f88328b9f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaCommand.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaCommand.scala @@ -224,7 +224,7 @@ trait DeltaCommand extends DeltaLogging { /** * Utility method to return the [[DeltaLog]] of an existing Delta table referred - * by either the given [[path]] or [[tableIdentifier]. + * by either the given [[path]] or [[tableIdentifier]]. * * @param spark [[SparkSession]] reference to use. * @param path Table location. Expects a non-empty [[tableIdentifier]] or [[path]]. 
@@ -241,18 +241,18 @@ trait DeltaCommand extends DeltaLogging { tableIdentifier: Option[TableIdentifier], operationName: String, hadoopConf: Map[String, String] = Map.empty): DeltaLog = { - val tablePath = + val (deltaLog, catalogTable) = if (path.nonEmpty) { - new Path(path.get) + (DeltaLog.forTable(spark, new Path(path.get), hadoopConf), None) } else if (tableIdentifier.nonEmpty) { val sessionCatalog = spark.sessionState.catalog lazy val metadata = sessionCatalog.getTableMetadata(tableIdentifier.get) DeltaTableIdentifier(spark, tableIdentifier.get) match { case Some(id) if id.path.nonEmpty => - new Path(id.path.get) + (DeltaLog.forTable(spark, new Path(id.path.get), hadoopConf), None) case Some(id) if id.table.nonEmpty => - new Path(metadata.location) + (DeltaLog.forTable(spark, metadata, hadoopConf), Some(metadata)) case _ => if (metadata.tableType == CatalogTableType.VIEW) { throw DeltaErrors.viewNotSupported(operationName) @@ -264,8 +264,9 @@ trait DeltaCommand extends DeltaLogging { } val startTime = Some(System.currentTimeMillis) - val deltaLog = DeltaLog.forTable(spark, tablePath, hadoopConf) - if (deltaLog.update(checkIfUpdatedSinceTs = startTime).version < 0) { + if (deltaLog + .update(checkIfUpdatedSinceTs = startTime, catalogTableOpt = catalogTable) + .version < 0) { throw DeltaErrors.notADeltaTableException( operationName, DeltaTableIdentifier(path, tableIdentifier)) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaGenerateCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaGenerateCommand.scala index 42b3914ca25..b9842390740 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaGenerateCommand.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaGenerateCommand.scala @@ -39,14 +39,16 @@ case class DeltaGenerateCommand( throw DeltaErrors.unsupportedGenerateModeException(modeName) } - val tablePath = DeltaTableIdentifier(sparkSession, tableId) match { + val deltaLog = DeltaTableIdentifier(sparkSession, tableId) match { case Some(id) if id.path.isDefined => - new Path(id.path.get) + DeltaLog.forTable(sparkSession, new Path(id.path.get), options) case _ => - new Path(sparkSession.sessionState.catalog.getTableMetadata(tableId).location) + DeltaLog.forTable( + sparkSession, + sparkSession.sessionState.catalog.getTableMetadata(tableId), + options) } - val deltaLog = DeltaLog.forTable(sparkSession, tablePath, options) if (!deltaLog.tableExists) { throw DeltaErrors.notADeltaTableException("GENERATE") } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/DescribeDeltaDetailsCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/DescribeDeltaDetailsCommand.scala index 3642d6c0456..d79a2c2d0df 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/DescribeDeltaDetailsCommand.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/DescribeDeltaDetailsCommand.scala @@ -84,15 +84,14 @@ case class DescribeDeltaDetailCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val tableMetadata = getTableCatalogTable(child, DescribeDeltaDetailCommand.CMD_NAME) val (_, path) = getTablePathOrIdentifier(child, DescribeDeltaDetailCommand.CMD_NAME) - val basePath = tableMetadata match { - case Some(metadata) => new Path(metadata.location) - case _ if path.isDefined => new Path(path.get) + val deltaLog = (tableMetadata, path) match { + case (Some(metadata), _) => DeltaLog.forTable(sparkSession, metadata, hadoopConf) + case (_, 
Some(path)) => DeltaLog.forTable(sparkSession, new Path(path), hadoopConf) case _ => throw DeltaErrors.missingTableIdentifierException(DescribeDeltaDetailCommand.CMD_NAME) } - val deltaLog = DeltaLog.forTable(sparkSession, basePath, hadoopConf) recordDeltaOperation(deltaLog, "delta.ddl.describeDetails") { - val snapshot = deltaLog.update() + val snapshot = deltaLog.update(catalogTableOpt = tableMetadata) if (snapshot.version == -1) { if (path.nonEmpty) { val fs = new Path(path.get).getFileSystem(deltaLog.newDeltaHadoopConf()) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala index 6bed351e488..ec086f4799c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala @@ -92,6 +92,7 @@ case class RestoreTableCommand(sourceTable: DeltaTableV2) override def run(spark: SparkSession): Seq[Row] = { val deltaLog = sourceTable.deltaLog + val catalogTableOpt = sourceTable.catalogTable val version = sourceTable.timeTravelOpt.get.version val timestamp = getTimestamp() recordDeltaOperation(deltaLog, "delta.restore") { @@ -105,14 +106,18 @@ case class RestoreTableCommand(sourceTable: DeltaTableV2) .version } - val latestVersion = deltaLog.update().version + val latestVersion = deltaLog + .update(catalogTableOpt = catalogTableOpt) + .version require(versionToRestore < latestVersion, s"Version to restore ($versionToRestore)" + s"should be less then last available version ($latestVersion)") - deltaLog.withNewTransaction(sourceTable.catalogTable) { txn => + deltaLog.withNewTransaction(catalogTableOpt) { txn => val latestSnapshot = txn.snapshot - val snapshotToRestore = deltaLog.getSnapshotAt(versionToRestore) + val snapshotToRestore = deltaLog.getSnapshotAt( + versionToRestore, + catalogTableOpt = txn.catalogTable) val latestSnapshotFiles = latestSnapshot.allFiles val snapshotToRestoreFiles = snapshotToRestore.allFiles diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/hooks/CheckpointHook.scala b/spark/src/main/scala/org/apache/spark/sql/delta/hooks/CheckpointHook.scala index c02071e9806..51993e230ed 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/hooks/CheckpointHook.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/hooks/CheckpointHook.scala @@ -40,7 +40,7 @@ object CheckpointHook extends PostCommitHook { committedVersion, lastCheckpointHint = None, lastCheckpointProvider = Some(cp), - tableIdentifierOpt = txn.catalogTable.map(_.identifier)) + catalogTableOpt = txn.catalogTable) txn.deltaLog.checkpoint(snapshotToCheckpoint, txn.catalogTable) } } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaLogSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaLogSuite.scala index 93a7e275378..1423d85db0e 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaLogSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaLogSuite.scala @@ -594,7 +594,7 @@ class DeltaLogSuite extends QueryTest Iterator(JsonUtils.toJson(add.wrap)), overwrite = false, deltaLog.newDeltaHadoopConf()) - deltaLog + (deltaLog, None) } assert(snapshot.version === 0)
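
For reviewers, a minimal usage sketch (not part of this patch) of how a call site can thread the `CatalogTable` through the new overloads described above. The table name `my_delta_table` and the empty options map are placeholders, not code from this PR.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.delta.DeltaLog

val spark: SparkSession = SparkSession.active

// Resolve the table through the session catalog so a CatalogTable is available.
val catalogTable =
  spark.sessionState.catalog.getTableMetadata(TableIdentifier("my_delta_table"))

// New overload: build the DeltaLog from the CatalogTable plus file-system options.
val deltaLog = DeltaLog.forTable(spark, catalogTable, Map.empty[String, String])

// New overload: get the log and its latest snapshot in one call.
val (log, snapshot) =
  DeltaLog.forTableWithSnapshot(spark, catalogTable, Map.empty[String, String])

// Call sites that previously passed a TableIdentifier now pass the catalog table.
val refreshed = deltaLog.update(catalogTableOpt = Some(catalogTable))
val atVersionZero = deltaLog.getSnapshotAt(0L, catalogTableOpt = Some(catalogTable))
```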