
Commit eb7ee68

brkyvz authored and cloud-fan committed
[SPARK-29062][SQL] Add V1_BATCH_WRITE to the TableCapabilityChecks
### What changes were proposed in this pull request?

Currently the checks in the Analyzer require that V2 tables have BATCH_WRITE defined for all tables that have V1 write fallbacks. This is confusing, as these tables may not have the V2 writer interface implemented yet. This PR adds the V1_BATCH_WRITE table capability to these checks.

In addition, this allows V2 tables to leverage the V1 APIs for DataFrameWriter.save if they extend the V1_BATCH_WRITE capability. This way, these tables can continue to receive partitioning information, perform checks for the existence of tables, and support all SaveModes.

### Why are the changes needed?

Partitioned saves through DataFrame.write are otherwise broken for V2 tables that support the V1 write API.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

V1WriteFallbackSuite

Closes #25767 from brkyvz/bwcheck.

Authored-by: Burak Yavuz <brkyvz@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
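As a quick illustration of the user-facing effect, here is a minimal sketch mirroring the new tests in V1WriteFallbackSuite. `InMemoryV1Provider` is the test-only provider added in this PR; in real code you would substitute your own `TableProvider` with a V1 write fallback.

```scala
// Minimal sketch of what this change enables, mirroring the new tests in
// V1WriteFallbackSuite. InMemoryV1Provider is the test-only provider from that
// suite; any TableProvider whose table declares only V1_BATCH_WRITE behaves alike.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.InMemoryV1Provider // test-only class from this PR

val spark = SparkSession.builder().master("local[2]").appName("v1-fallback-demo").getOrCreate()
import spark.implicits._

val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

// Previously this failed with "Cannot write data to TableProvider implementation
// if partition columns are specified."; now the partitioning columns flow to the
// V1 write path and all SaveModes are honored.
df.write
  .mode("append")
  .option("name", "t1")
  .partitionBy("a")
  .format(classOf[InMemoryV1Provider].getName)
  .save()
```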
1 parent ec8a1a8 commit eb7ee68

4 files changed: +153 -28 lines

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 4 additions & 5 deletions

@@ -253,11 +253,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

     val maybeV2Provider = lookupV2Provider()
     if (maybeV2Provider.isDefined) {
-      if (partitioningColumns.nonEmpty) {
-        throw new AnalysisException(
-          "Cannot write data to TableProvider implementation if partition columns are specified.")
-      }
-
       val provider = maybeV2Provider.get
       val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
         provider, df.sparkSession.sessionState.conf)
@@ -267,6 +262,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
       provider.getTable(dsOptions) match {
         case table: SupportsWrite if table.supports(BATCH_WRITE) =>
+          if (partitioningColumns.nonEmpty) {
+            throw new AnalysisException("Cannot write data to TableProvider implementation " +
+              "if partition columns are specified.")
+          }
           lazy val relation = DataSourceV2Relation.create(table, dsOptions)
           modeForDSV2 match {
             case SaveMode.Append =>
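With the guard now scoped to the pure BATCH_WRITE branch, partitionBy() columns for a V1-fallback table are no longer rejected up front; they reach the provider through the write options. A minimal sketch of decoding them on the provider side, modeled on what the test provider in this PR does in createRelation (the helper name is illustrative):

```scala
// Sketch (illustrative helper name) of recovering partitionBy() columns inside a
// CreatableRelationProvider, following InMemoryV1Provider.createRelation in this
// PR's test suite: DataFrameWriter encodes them into the options under
// DataSourceUtils.PARTITIONING_COLUMNS_KEY.
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.execution.datasources.DataSourceUtils

def partitionTransforms(parameters: Map[String, String]): Seq[Transform] = {
  parameters.get(DataSourceUtils.PARTITIONING_COLUMNS_KEY).map { value =>
    DataSourceUtils.decodePartitioningColumns(value).map { col =>
      IdentityTransform(FieldReference(col))
    }
  }.getOrElse(Nil)
}
```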

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala

Lines changed: 9 additions & 5 deletions

@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
 import org.apache.spark.sql.types.BooleanType
@@ -32,6 +33,10 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {

   private def failAnalysis(msg: String): Unit = throw new AnalysisException(msg)

+  private def supportsBatchWrite(table: Table): Boolean = {
+    table.supportsAny(BATCH_WRITE, V1_BATCH_WRITE)
+  }
+
   override def apply(plan: LogicalPlan): Unit = {
     plan foreach {
       case r: DataSourceV2Relation if !r.table.supports(BATCH_READ) =>
@@ -43,8 +48,7 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {

       // TODO: check STREAMING_WRITE capability. It's not doable now because we don't have a
       // a logical plan for streaming write.
-
-      case AppendData(r: DataSourceV2Relation, _, _, _) if !r.table.supports(BATCH_WRITE) =>
+      case AppendData(r: DataSourceV2Relation, _, _, _) if !supportsBatchWrite(r.table) =>
         failAnalysis(s"Table ${r.table.name()} does not support append in batch mode.")

       case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _)
@@ -54,13 +58,13 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {
       case OverwriteByExpression(r: DataSourceV2Relation, expr, _, _, _) =>
         expr match {
           case Literal(true, BooleanType) =>
-            if (!r.table.supports(BATCH_WRITE) ||
-              !r.table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)) {
+            if (!supportsBatchWrite(r.table) ||
+                !r.table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)) {
               failAnalysis(
                 s"Table ${r.table.name()} does not support truncate in batch mode.")
             }
           case _ =>
-            if (!r.table.supports(BATCH_WRITE) || !r.table.supports(OVERWRITE_BY_FILTER)) {
+            if (!supportsBatchWrite(r.table) || !r.table.supports(OVERWRITE_BY_FILTER)) {
               failAnalysis(s"Table ${r.table.name()} does not support " +
                 "overwrite by filter in batch mode.")
             }
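For context, a minimal sketch (names are illustrative, not part of this PR) of a table definition that the relaxed check now accepts for batch writes even though it never declares BATCH_WRITE:

```scala
// Hypothetical table that TableCapabilityCheck now lets through for append,
// truncate, and overwrite-by-filter plans: it advertises V1_BATCH_WRITE only.
import java.util
import scala.collection.JavaConverters._
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.write.WriteBuilder
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class V1OnlyWriteTable extends Table with SupportsWrite {
  override def name(): String = "v1_only_write_table"
  override def schema(): StructType = new StructType().add("a", IntegerType)
  override def capabilities(): util.Set[TableCapability] =
    Set(TableCapability.V1_BATCH_WRITE, TableCapability.OVERWRITE_BY_FILTER).asJava
  // The write builder is elided; a real table would return a V1WriteBuilder here.
  override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = ???
}
```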

sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala

Lines changed: 18 additions & 10 deletions

@@ -98,16 +98,19 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
   }

   test("AppendData: check correct capabilities") {
-    val plan = AppendData.byName(
-      DataSourceV2Relation.create(CapabilityTable(BATCH_WRITE), CaseInsensitiveStringMap.empty),
-      TestRelation)
+    Seq(BATCH_WRITE, V1_BATCH_WRITE).foreach { write =>
+      val plan = AppendData.byName(
+        DataSourceV2Relation.create(CapabilityTable(write), CaseInsensitiveStringMap.empty),
+        TestRelation)

-    TableCapabilityCheck.apply(plan)
+      TableCapabilityCheck.apply(plan)
+    }
   }

   test("Truncate: check missing capabilities") {
     Seq(CapabilityTable(),
       CapabilityTable(BATCH_WRITE),
+      CapabilityTable(V1_BATCH_WRITE),
       CapabilityTable(TRUNCATE),
       CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>

@@ -125,7 +128,9 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {

   test("Truncate: check correct capabilities") {
     Seq(CapabilityTable(BATCH_WRITE, TRUNCATE),
-      CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table =>
+      CapabilityTable(V1_BATCH_WRITE, TRUNCATE),
+      CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER),
+      CapabilityTable(V1_BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table =>

       val plan = OverwriteByExpression.byName(
         DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
@@ -137,6 +142,7 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {

   test("OverwriteByExpression: check missing capabilities") {
     Seq(CapabilityTable(),
+      CapabilityTable(V1_BATCH_WRITE),
       CapabilityTable(BATCH_WRITE),
       CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>

@@ -153,12 +159,14 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
   }

   test("OverwriteByExpression: check correct capabilities") {
-    val table = CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)
-    val plan = OverwriteByExpression.byName(
-      DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
-      EqualTo(AttributeReference("x", LongType)(), Literal(5)))
+    Seq(BATCH_WRITE, V1_BATCH_WRITE).foreach { write =>
+      val table = CapabilityTable(write, OVERWRITE_BY_FILTER)
+      val plan = OverwriteByExpression.byName(
+        DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
+        EqualTo(AttributeReference("x", LongType)(), Literal(5)))

-    TableCapabilityCheck.apply(plan)
+      TableCapabilityCheck.apply(plan)
+    }
   }

   test("OverwritePartitionsDynamic: check missing capabilities") {

sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala

Lines changed: 122 additions & 8 deletions

@@ -24,11 +24,12 @@ import scala.collection.mutable

 import org.scalatest.BeforeAndAfter

-import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode, SparkSession}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
 import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider}
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
 import org.apache.spark.sql.connector.write.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}
-import org.apache.spark.sql.sources.{DataSourceRegister, Filter, InsertableRelation}
+import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils}
+import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -52,7 +53,11 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with BeforeAndAfter {
   test("append fallback") {
     val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
     df.write.mode("append").option("name", "t1").format(v2Format).save()
+
     checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
+    assert(InMemoryV1Provider.tables("t1").schema === df.schema.asNullable)
+    assert(InMemoryV1Provider.tables("t1").partitioning.isEmpty)
+
     df.write.mode("append").option("name", "t1").format(v2Format).save()
     checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df.union(df))
   }
@@ -65,6 +70,59 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with BeforeAndAfter {
     df2.write.mode("overwrite").option("name", "t1").format(v2Format).save()
     checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df2)
   }
+
+  SaveMode.values().foreach { mode =>
+    test(s"save: new table creations with partitioning for table - mode: $mode") {
+      val format = classOf[InMemoryV1Provider].getName
+      val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+      df.write.mode(mode).option("name", "t1").format(format).partitionBy("a").save()
+
+      checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
+      assert(InMemoryV1Provider.tables("t1").schema === df.schema.asNullable)
+      assert(InMemoryV1Provider.tables("t1").partitioning.sameElements(
+        Array(IdentityTransform(FieldReference(Seq("a"))))))
+    }
+  }
+
+  test("save: default mode is ErrorIfExists") {
+    val format = classOf[InMemoryV1Provider].getName
+    val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+
+    df.write.option("name", "t1").format(format).partitionBy("a").save()
+    // default is ErrorIfExists, and since a table already exists we throw an exception
+    val e = intercept[AnalysisException] {
+      df.write.option("name", "t1").format(format).partitionBy("a").save()
+    }
+    assert(e.getMessage.contains("already exists"))
+  }
+
+  test("save: Ignore mode") {
+    val format = classOf[InMemoryV1Provider].getName
+    val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+
+    df.write.option("name", "t1").format(format).partitionBy("a").save()
+    // no-op
+    df.write.option("name", "t1").format(format).mode("ignore").partitionBy("a").save()
+
+    checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
+  }
+
+  test("save: tables can perform schema and partitioning checks if they already exist") {
+    val format = classOf[InMemoryV1Provider].getName
+    val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+
+    df.write.option("name", "t1").format(format).partitionBy("a").save()
+    val e2 = intercept[IllegalArgumentException] {
+      df.write.mode("append").option("name", "t1").format(format).partitionBy("b").save()
+    }
+    assert(e2.getMessage.contains("partitioning"))
+
+    val e3 = intercept[IllegalArgumentException] {
+      Seq((1, "x")).toDF("c", "d").write.mode("append").option("name", "t1").format(format)
+        .save()
+    }
+    assert(e3.getMessage.contains("schema"))
+  }
 }

 class V1WriteFallbackSessionCatalogSuite
@@ -114,26 +172,83 @@ private object InMemoryV1Provider {
   }
 }

-class InMemoryV1Provider extends TableProvider with DataSourceRegister {
+class InMemoryV1Provider
+  extends TableProvider
+  with DataSourceRegister
+  with CreatableRelationProvider {
   override def getTable(options: CaseInsensitiveStringMap): Table = {
-    InMemoryV1Provider.tables.getOrElseUpdate(options.get("name"), {
+
+    InMemoryV1Provider.tables.getOrElse(options.get("name"), {
       new InMemoryTableWithV1Fallback(
         "InMemoryTableWithV1Fallback",
-        new StructType().add("a", IntegerType).add("b", StringType),
-        Array(IdentityTransform(FieldReference(Seq("a")))),
+        new StructType(),
+        Array.empty,
         options.asCaseSensitiveMap()
       )
     })
   }

   override def shortName(): String = "in-memory"
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    val _sqlContext = sqlContext
+
+    val partitioning = parameters.get(DataSourceUtils.PARTITIONING_COLUMNS_KEY).map { value =>
+      DataSourceUtils.decodePartitioningColumns(value).map { partitioningColumn =>
+        IdentityTransform(FieldReference(partitioningColumn))
+      }
+    }.getOrElse(Nil)
+
+    val tableName = parameters("name")
+    val tableOpt = InMemoryV1Provider.tables.get(tableName)
+    val table = tableOpt.getOrElse(new InMemoryTableWithV1Fallback(
+      "InMemoryTableWithV1Fallback",
+      data.schema.asNullable,
+      partitioning.toArray,
+      Map.empty[String, String].asJava
+    ))
+    if (tableOpt.isEmpty) {
+      InMemoryV1Provider.tables.put(tableName, table)
+    } else {
+      if (data.schema.asNullable != table.schema) {
+        throw new IllegalArgumentException("Wrong schema provided")
+      }
+      if (!partitioning.sameElements(table.partitioning)) {
+        throw new IllegalArgumentException("Wrong partitioning provided")
+      }
+    }
+
+    def getRelation: BaseRelation = new BaseRelation {
+      override def sqlContext: SQLContext = _sqlContext
+      override def schema: StructType = table.schema
+    }
+
+    if (mode == SaveMode.ErrorIfExists && tableOpt.isDefined) {
+      throw new AnalysisException("Table already exists")
+    } else if (mode == SaveMode.Ignore && tableOpt.isDefined) {
+      // do nothing
+      return getRelation
+    }
+    val writer = table.newWriteBuilder(new CaseInsensitiveStringMap(parameters.asJava))
+    if (mode == SaveMode.Overwrite) {
+      writer.asInstanceOf[SupportsTruncate].truncate()
+    }
+    writer.asInstanceOf[V1WriteBuilder].buildForV1Write().insert(data, overwrite = false)
+    getRelation
+  }
 }

 class InMemoryTableWithV1Fallback(
     override val name: String,
     override val schema: StructType,
     override val partitioning: Array[Transform],
-    override val properties: util.Map[String, String]) extends Table with SupportsWrite {
+    override val properties: util.Map[String, String])
+  extends Table
+  with SupportsWrite {

   partitioning.foreach { t =>
     if (!t.isInstanceOf[IdentityTransform]) {
@@ -142,7 +257,6 @@ class InMemoryTableWithV1Fallback(
   }

   override def capabilities: util.Set[TableCapability] = Set(
-    TableCapability.BATCH_WRITE,
     TableCapability.V1_BATCH_WRITE,
     TableCapability.OVERWRITE_BY_FILTER,
     TableCapability.TRUNCATE).asJava
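For reference, the V1 fallback contract this suite exercises boils down to a WriteBuilder that also implements V1WriteBuilder; a minimal sketch with illustrative names, not part of the PR:

```scala
// Minimal sketch of the V1 write fallback contract exercised above: the table's
// WriteBuilder also mixes in V1WriteBuilder, and buildForV1Write() hands back an
// InsertableRelation that performs the write through the V1 API. Names are
// illustrative; a real builder would also honor truncate/overwrite, as
// InMemoryV1Provider.createRelation expects via SupportsTruncate.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connector.write.V1WriteBuilder
import org.apache.spark.sql.sources.InsertableRelation

class ExampleV1FallbackWriteBuilder extends V1WriteBuilder {
  override def buildForV1Write(): InsertableRelation = new InsertableRelation {
    override def insert(data: DataFrame, overwrite: Boolean): Unit = {
      // Hand the rows off to existing V1 write code; collect() is a stand-in.
      data.collect()
    }
  }
}
```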
