[Spark] Fix a data loss bug in MergeIntoCommand
This is a cherry-pick of delta-io#2128 to the master branch.

#### Which Delta project/connector is this regarding?

- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

Fix a data loss bug in `MergeIntoCommand`.

The bug is caused by `PreprocessTableMerge` and `MergeIntoCommand` using different Spark session config objects, which can happen when multiple Spark sessions are running concurrently. If the source DataFrame has more columns than the target table, the auto schema merge feature adds the additional columns as nullable columns to the target table schema. The updated output projection is built in `PreprocessTableMerge`, so `matchedClauses` and `notMatchedClauses` contain the additional columns, but the target table schema seen by `MergeIntoCommand` does not include them.

As a result, the index used below no longer points at the delete flag column, which sits at `numFields - 2` in the joined output row:
```scala
      def shouldDeleteRow(row: InternalRow): Boolean =
        row.getBoolean(outputRowEncoder.schema.fields.size)
```

Because the row is read as an `UnsafeRow`, `row.getBoolean` effectively returns `getByte() != 0` at whatever offset that index lands on, so rows are dropped essentially at random:
- matched rows in the target table are lost.

In addition, because auto schema merge does not take effect in `MergeIntoCommand`:
- data for the newly added source column is lost.

The sketch below illustrates the ordinal mismatch.
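A minimal sketch of that mismatch, assuming a hypothetical row layout (the column names `_row_dropped_` and `_incr_row_count_`, the schema, and the values are illustrative, not the exact Delta internals):

```scala
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.types._

// Hypothetical joined output row: target columns, the auto-merged column "y",
// then the delete flag and a metrics column appended by the merge machinery.
val joinedSchema = StructType(Seq(
  StructField("id", LongType),
  StructField("x", LongType),
  StructField("y", LongType),                   // column added by schema auto-merge
  StructField("_row_dropped_", BooleanType),    // delete flag at numFields - 2
  StructField("_incr_row_count_", BooleanType)))

val toUnsafe = UnsafeProjection.create(joinedSchema)
val row = toUnsafe(new GenericInternalRow(Array[Any](1L, 2L, 3L, false, true)))

// Ordinal derived from a schema that includes the auto-merged column: correct.
row.getBoolean(3)  // false -> the row is kept

// Ordinal derived from the stale target schema (id, x), i.e. fields.size == 2:
// reads the first byte of the long column "y" (3 != 0 -> true), dropping the row.
row.getBoolean(2)
```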

The fix makes sure `MergeIntoCommand` runs with the same Spark session, and therefore the same config object, as `PreprocessTableMerge`: the `io.delta.tables` entry points now set their own session as the active one before executing.
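For reference, a rough sketch of the session-pinning pattern the fix relies on, written against public APIs only (the actual change wraps the `private[sql]` `SparkSession.withActive` in a small `DeltaTableUtils.withActiveSession` helper, visible in the diff below):

```scala
import org.apache.spark.sql.SparkSession

// Run `body` with `spark` installed as the active session, so that any code that
// resolves SQLConf.get or SparkSession.active downstream sees the caller's session
// and configuration rather than whichever session happens to be active on the thread.
def withActiveSession[T](spark: SparkSession)(body: => T): T = {
  val previous = SparkSession.getActiveSession
  SparkSession.setActiveSession(spark)
  try body
  finally previous match {
    case Some(s) => SparkSession.setActiveSession(s)
    case None    => SparkSession.clearActiveSession()
  }
}
```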

Fixes delta-io#2104

## How was this patch tested?

I confirmed that delta-io#2104 is fixed with this change.
Without the change, I confirmed the following via debug log messages:

1. `matchedClauses` contains more columns than the target schema after preprocessing.
2. `row.getBoolean(outputRowEncoder.schema.fields.size)` reads a random column value (it is an unsafe read).
3. `canMergeSchema` is false in `MergeIntoCommand`, although it was true in `PreprocessTableMerge`.

## Does this PR introduce _any_ user-facing changes?
Yes, it fixes the data loss issue.

Closes delta-io#2162

Co-authored-by: Chungmin Lee <lee@chungmin.dev>
Signed-off-by: Johan Lasperas <johan.lasperas@databricks.com>
GitOrigin-RevId: 49acacf8ff1c71d7e6bcb2dc2f709c325211430a
2 people authored and vkorukanti committed Nov 27, 2023
1 parent b92885c commit 683a730
Showing 9 changed files with 149 additions and 70 deletions.
30 changes: 30 additions & 0 deletions python/delta/tests/test_deltatable.py
@@ -20,6 +20,7 @@

import unittest
import os
from multiprocessing.pool import ThreadPool
from typing import List, Set, Dict, Optional, Any, Callable, Union, Tuple

from pyspark.sql import DataFrame, Row
@@ -471,6 +472,35 @@ def reset_table() -> None:
with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
dt.merge(source, "key = k").whenNotMatchedBySourceDelete(1) # type: ignore[arg-type]

def test_merge_with_inconsistent_sessions(self) -> None:
source_path = os.path.join(self.tempFile, "source")
target_path = os.path.join(self.tempFile, "target")
spark = self.spark

def f(spark):
spark.range(20) \
.withColumn("x", col("id")) \
.withColumn("y", col("id")) \
.write.mode("overwrite").format("delta").save(source_path)
spark.range(1) \
.withColumn("x", col("id")) \
.write.mode("overwrite").format("delta").save(target_path)
target = DeltaTable.forPath(spark, target_path)
source = spark.read.format("delta").load(source_path).alias("s")
target.alias("t") \
.merge(source, "t.id = s.id") \
.whenMatchedUpdate(set={"t.x": "t.x + 1"}) \
.whenNotMatchedInsertAll() \
.execute()
assert(spark.read.format("delta").load(target_path).count() == 20)

pool = ThreadPool(3)
spark.conf.set("spark.databricks.delta.schema.autoMerge.enabled", "true")
try:
pool.starmap(f, [(spark,)])
finally:
spark.conf.unset("spark.databricks.delta.schema.autoMerge.enabled")

def test_history(self) -> None:
self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
self.__overwriteDeltaTable([('a', 3), ('b', 2), ('c', 1)])
46 changes: 25 additions & 21 deletions spark/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala
@@ -20,6 +20,7 @@ import scala.collection.JavaConverters._
import scala.collection.Map

import org.apache.spark.sql.delta.{DeltaErrors, PostHocResolveUpCast, PreprocessTableMerge}
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.DeltaViewHelper
import org.apache.spark.sql.delta.commands.MergeIntoCommand
import org.apache.spark.sql.delta.util.AnalysisHelper
@@ -265,30 +266,33 @@
*/
def execute(): Unit = improveUnsupportedOpError {
val sparkSession = targetTable.toDF.sparkSession
// Note: We are explicitly resolving DeltaMergeInto plan rather than going to through the
// Analyzer using `Dataset.ofRows()` because the Analyzer incorrectly resolves all
// references in the DeltaMergeInto using both source and target child plans, even before
// DeltaAnalysis rule kicks in. This is because the Analyzer understands only MergeIntoTable,
// and handles that separately by skipping resolution (for Delta) and letting the
// DeltaAnalysis rule do the resolving correctly. This can be solved by generating
// MergeIntoTable instead, which blocked by the different issue with MergeIntoTable as explained
// in the function `mergePlan` and https://issues.apache.org/jira/browse/SPARK-34962.
val resolvedMergeInto =
withActiveSession(sparkSession) {
// Note: We are explicitly resolving DeltaMergeInto plan rather than going to through the
// Analyzer using `Dataset.ofRows()` because the Analyzer incorrectly resolves all
// references in the DeltaMergeInto using both source and target child plans, even before
// DeltaAnalysis rule kicks in. This is because the Analyzer understands only MergeIntoTable,
// and handles that separately by skipping resolution (for Delta) and letting the
// DeltaAnalysis rule do the resolving correctly. This can be solved by generating
// MergeIntoTable instead, which blocked by the different issue with MergeIntoTable as
// explained in the function `mergePlan` and
// https://issues.apache.org/jira/browse/SPARK-34962.
val resolvedMergeInto =
DeltaMergeInto.resolveReferencesAndSchema(mergePlan, sparkSession.sessionState.conf)(
tryResolveReferencesForExpressions(sparkSession))
if (!resolvedMergeInto.resolved) {
throw DeltaErrors.analysisException("Failed to resolve\n", plan = Some(resolvedMergeInto))
if (!resolvedMergeInto.resolved) {
throw DeltaErrors.analysisException("Failed to resolve\n", plan = Some(resolvedMergeInto))
}
val strippedMergeInto = resolvedMergeInto.copy(
target = DeltaViewHelper.stripTempViewForMerge(resolvedMergeInto.target, SQLConf.get)
)
// Preprocess the actions and verify
var mergeIntoCommand =
PreprocessTableMerge(sparkSession.sessionState.conf)(strippedMergeInto)
// Resolve UpCast expressions that `PreprocessTableMerge` may have introduced.
mergeIntoCommand = PostHocResolveUpCast(sparkSession).apply(mergeIntoCommand)
sparkSession.sessionState.analyzer.checkAnalysis(mergeIntoCommand)
mergeIntoCommand.asInstanceOf[MergeIntoCommand].run(sparkSession)
}
val strippedMergeInto = resolvedMergeInto.copy(
target = DeltaViewHelper.stripTempViewForMerge(resolvedMergeInto.target, SQLConf.get)
)
// Preprocess the actions and verify
var mergeIntoCommand =
PreprocessTableMerge(sparkSession.sessionState.conf)(strippedMergeInto)
// Resolve UpCast expressions that `PreprocessTableMerge` may have introduced.
mergeIntoCommand = PostHocResolveUpCast(sparkSession).apply(mergeIntoCommand)
sparkSession.sessionState.analyzer.checkAnalysis(mergeIntoCommand)
mergeIntoCommand.asInstanceOf[MergeIntoCommand].run(sparkSession)
}

/**
31 changes: 17 additions & 14 deletions spark/src/main/scala/io/delta/tables/DeltaOptimizeBuilder.scala
@@ -17,6 +17,7 @@
package io.delta.tables

// scalastyle:off import.ordering.noEmptyLine
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.commands.DeltaOptimizeContext
import org.apache.spark.sql.delta.commands.OptimizeTableCommand
@@ -81,21 +82,23 @@ class DeltaOptimizeBuilder private(table: DeltaTableV2) extends AnalysisHelper {

private def execute(zOrderBy: Seq[UnresolvedAttribute]): DataFrame = {
val sparkSession = table.spark
val tableId: TableIdentifier = sparkSession
.sessionState
.sqlParser
.parseTableIdentifier(tableIdentifier)
val id = Identifier.of(tableId.database.toArray, tableId.identifier)
val catalogPlugin = sparkSession.sessionState.catalogManager.currentCatalog
val catalog = catalogPlugin match {
case tableCatalog: TableCatalog => tableCatalog
case _ => throw new IllegalArgumentException(
s"Catalog ${catalogPlugin.name} does not support tables")
withActiveSession(sparkSession) {
val tableId: TableIdentifier = sparkSession
.sessionState
.sqlParser
.parseTableIdentifier(tableIdentifier)
val id = Identifier.of(tableId.database.toArray, tableId.identifier)
val catalogPlugin = sparkSession.sessionState.catalogManager.currentCatalog
val catalog = catalogPlugin match {
case tableCatalog: TableCatalog => tableCatalog
case _ => throw new IllegalArgumentException(
s"Catalog ${catalogPlugin.name} does not support tables")
}
val resolvedTable = ResolvedTable.create(catalog, id, table)
val optimize = OptimizeTableCommand(
resolvedTable, partitionFilter, DeltaOptimizeContext())(zOrderBy = zOrderBy)
toDataset(sparkSession, optimize)
}
val resolvedTable = ResolvedTable.create(catalog, id, table)
val optimize = OptimizeTableCommand(
resolvedTable, partitionFilter, DeltaOptimizeContext())(zOrderBy = zOrderBy)
toDataset(sparkSession, optimize)
}
}

22 changes: 12 additions & 10 deletions spark/src/main/scala/io/delta/tables/DeltaTable.scala
@@ -19,6 +19,7 @@ package io.delta.tables
import scala.collection.JavaConverters._

import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.actions.{Protocol, TableFeatureProtocolUtils}
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.commands.AlterTableSetPropertiesDeltaCommand
@@ -530,15 +531,16 @@ class DeltaTable private[tables](
*
* @since 0.8.0
*/
def upgradeTableProtocol(readerVersion: Int, writerVersion: Int): Unit = {
val alterTableCmd = AlterTableSetPropertiesDeltaCommand(
table,
DeltaConfigs.validateConfigurations(
Map(
"delta.minReaderVersion" -> readerVersion.toString,
"delta.minWriterVersion" -> writerVersion.toString)))
toDataset(sparkSession, alterTableCmd)
}
def upgradeTableProtocol(readerVersion: Int, writerVersion: Int): Unit =
withActiveSession(sparkSession) {
val alterTableCmd = AlterTableSetPropertiesDeltaCommand(
table,
DeltaConfigs.validateConfigurations(
Map(
"delta.minReaderVersion" -> readerVersion.toString,
"delta.minWriterVersion" -> writerVersion.toString)))
toDataset(sparkSession, alterTableCmd)
}

/**
* Modify the protocol to add a supported feature, and if the table does not support table
@@ -550,7 +552,7 @@
*
* @since 2.3.0
*/
def addFeatureSupport(featureName: String): Unit = {
def addFeatureSupport(featureName: String): Unit = withActiveSession(sparkSession) {
// Do not check for the correctness of the provided feature name. The ALTER TABLE command will
// do that in a transaction.
val alterTableCmd = AlterTableSetPropertiesDeltaCommand(
3 changes: 2 additions & 1 deletion spark/src/main/scala/io/delta/tables/DeltaTableBuilder.scala
@@ -19,6 +19,7 @@ package io.delta.tables
import scala.collection.mutable

import org.apache.spark.sql.delta.{DeltaErrors, DeltaTableUtils}
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import io.delta.tables.execution._

@@ -302,7 +303,7 @@ class DeltaTableBuilder private[tables](
* @since 1.0.0
*/
@Evolving
def execute(): DeltaTable = {
def execute(): DeltaTable = withActiveSession(spark) {
if (identifier == null && location.isEmpty) {
throw DeltaErrors.analysisException("Table name or location has to be specified")
}
@@ -16,6 +16,7 @@

package io.delta.tables.execution

import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.commands.ConvertToDeltaCommand
import io.delta.tables.DeltaTable

@@ -28,7 +29,7 @@ trait DeltaConvertBase {
spark: SparkSession,
tableIdentifier: TableIdentifier,
partitionSchema: Option[StructType],
deltaPath: Option[String]): DeltaTable = {
deltaPath: Option[String]): DeltaTable = withActiveSession(spark) {
val cvt = ConvertToDeltaCommand(tableIdentifier, partitionSchema, collectStats = true,
deltaPath)
cvt.run(spark)
@@ -20,6 +20,7 @@ import scala.collection.Map

import org.apache.spark.sql.catalyst.TimeTravel
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.commands.{DeltaGenerateCommand, DescribeDeltaDetailCommand, VacuumCommand}
import org.apache.spark.sql.delta.util.AnalysisHelper
@@ -39,59 +40,64 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable =>

protected def executeDelete(condition: Option[Expression]): Unit = improveUnsupportedOpError {
val delete = DeleteFromTable(
self.toDF.queryExecution.analyzed,
condition.getOrElse(Literal.TrueLiteral))
toDataset(sparkSession, delete)
withActiveSession(sparkSession) {
val delete = DeleteFromTable(
self.toDF.queryExecution.analyzed,
condition.getOrElse(Literal.TrueLiteral))
toDataset(sparkSession, delete)
}
}

protected def executeHistory(
deltaLog: DeltaLog,
limit: Option[Int] = None,
tableId: Option[TableIdentifier] = None): DataFrame = {
tableId: Option[TableIdentifier] = None): DataFrame = withActiveSession(sparkSession) {
val history = deltaLog.history
val spark = self.toDF.sparkSession
spark.createDataFrame(history.getHistory(limit))
sparkSession.createDataFrame(history.getHistory(limit))
}

protected def executeDetails(
path: String,
tableIdentifier: Option[TableIdentifier]): DataFrame = {
tableIdentifier: Option[TableIdentifier]): DataFrame = withActiveSession(sparkSession) {
val details = DescribeDeltaDetailCommand(Option(path), tableIdentifier, self.deltaLog.options)
toDataset(sparkSession, details)
}

protected def executeGenerate(tblIdentifier: String, mode: String): Unit = {
val tableId: TableIdentifier = sparkSession
.sessionState
.sqlParser
.parseTableIdentifier(tblIdentifier)
val generate = DeltaGenerateCommand(mode, tableId, self.deltaLog.options)
toDataset(sparkSession, generate)
}
protected def executeGenerate(tblIdentifier: String, mode: String): Unit =
withActiveSession(sparkSession) {
val tableId: TableIdentifier = sparkSession
.sessionState
.sqlParser
.parseTableIdentifier(tblIdentifier)
val generate = DeltaGenerateCommand(mode, tableId, self.deltaLog.options)
toDataset(sparkSession, generate)
}

protected def executeUpdate(
set: Map[String, Column],
condition: Option[Column]): Unit = improveUnsupportedOpError {
val assignments = set.map { case (targetColName, column) =>
Assignment(UnresolvedAttribute.quotedString(targetColName), column.expr)
}.toSeq
val update = UpdateTable(self.toDF.queryExecution.analyzed, assignments, condition.map(_.expr))
toDataset(sparkSession, update)
withActiveSession(sparkSession) {
val assignments = set.map { case (targetColName, column) =>
Assignment(UnresolvedAttribute.quotedString(targetColName), column.expr)
}.toSeq
val update =
UpdateTable(self.toDF.queryExecution.analyzed, assignments, condition.map(_.expr))
toDataset(sparkSession, update)
}
}

protected def executeVacuum(
deltaLog: DeltaLog,
retentionHours: Option[Double],
tableId: Option[TableIdentifier] = None): DataFrame = {
tableId: Option[TableIdentifier] = None): DataFrame = withActiveSession(sparkSession) {
VacuumCommand.gc(sparkSession, deltaLog, false, retentionHours)
sparkSession.emptyDataFrame
}

protected def executeRestore(
table: DeltaTableV2,
versionAsOf: Option[Long],
timestampAsOf: Option[String]): DataFrame = {
timestampAsOf: Option[String]): DataFrame = withActiveSession(sparkSession) {
val identifier = table.getTableIdentifierIfExists.map(
id => Identifier.of(id.database.toArray, id.table))
val sourceRelation = DataSourceV2Relation.create(table, None, identifier)
@@ -479,6 +479,9 @@ object DeltaTableUtils extends PredicateHelper
IdentityTransform(FieldReference(Seq(col)))
}

// Workaround for withActive not being visible in io/delta.
def withActiveSession[T](spark: SparkSession)(body: => T): T = spark.withActive(body)

/**
* Uses org.apache.hadoop.fs.Path(Path, String) to concatenate a base path
* and a relative child path and safely handles the case where the base path represents
@@ -672,6 +672,35 @@ abstract class MergeIntoSuiteBase
}
}

test("Merge should use the same SparkSession consistently") {
withTempDir { dir =>
withSQLConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key -> "false") {
val r = dir.getCanonicalPath
val sourcePath = s"$r/source"
val targetPath = s"$r/target"
val numSourceRecords = 20
spark.range(numSourceRecords)
.withColumn("x", $"id")
.withColumn("y", $"id")
.write.mode("overwrite").format("delta").save(sourcePath)
spark.range(1)
.withColumn("x", $"id")
.write.mode("overwrite").format("delta").save(targetPath)
val spark2 = spark.newSession
spark2.conf.set(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key, "true")
val target = io.delta.tables.DeltaTable.forPath(spark2, targetPath)
val source = spark.read.format("delta").load(sourcePath).alias("s")
val merge = target.alias("t")
.merge(source, "t.id = s.id")
.whenMatched.updateExpr(Map("t.x" -> "t.x + 1"))
.whenNotMatched.insertAll()
.execute()
// The target table should have the same number of rows as the source after the merge
assert(spark.read.format("delta").load(targetPath).count() == numSourceRecords)
}
}
}

// Enable this test in OSS when Spark has the change to report better errors
// when MERGE is not supported.
ignore("Negative case - non-delta target") {
Expand Down
