@@ -24,7 +24,7 @@ import scala.collection.mutable
2424
2525import org .scalatest .BeforeAndAfter
2626
27- import org .apache .spark .sql .{AnalysisException , DataFrame , QueryTest , Row , SQLContext , SaveMode , SparkSession }
27+ import org .apache .spark .sql .{AnalysisException , DataFrame , QueryTest , Row , SaveMode , SparkSession , SQLContext }
2828import org .apache .spark .sql .connector .catalog .{SupportsWrite , Table , TableCapability , TableProvider }
2929import org .apache .spark .sql .connector .expressions .{FieldReference , IdentityTransform , Transform }
3030import org .apache .spark .sql .connector .write .{SupportsOverwrite , SupportsTruncate , V1WriteBuilder , WriteBuilder }
@@ -53,7 +53,11 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
5353 test(" append fallback" ) {
5454 val df = Seq ((1 , " x" ), (2 , " y" ), (3 , " z" )).toDF(" a" , " b" )
5555 df.write.mode(" append" ).option(" name" , " t1" ).format(v2Format).save()
56+
5657 checkAnswer(InMemoryV1Provider .getTableData(spark, " t1" ), df)
58+ assert(InMemoryV1Provider .tables(" t1" ).schema === df.schema.asNullable)
59+ assert(InMemoryV1Provider .tables(" t1" ).partitioning.isEmpty)
60+
5761 df.write.mode(" append" ).option(" name" , " t1" ).format(v2Format).save()
5862 checkAnswer(InMemoryV1Provider .getTableData(spark, " t1" ), df.union(df))
5963 }
@@ -66,6 +70,59 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
6670 df2.write.mode(" overwrite" ).option(" name" , " t1" ).format(v2Format).save()
6771 checkAnswer(InMemoryV1Provider .getTableData(spark, " t1" ), df2)
6872 }
73+
74+ SaveMode .values().foreach { mode =>
75+ test(s " save: new table creations with partitioning for table - mode: $mode" ) {
76+ val format = classOf [InMemoryV1Provider ].getName
77+ val df = Seq ((1 , " x" ), (2 , " y" ), (3 , " z" )).toDF(" a" , " b" )
78+ df.write.mode(mode).option(" name" , " t1" ).format(format).partitionBy(" a" ).save()
79+
80+ checkAnswer(InMemoryV1Provider .getTableData(spark, " t1" ), df)
81+ assert(InMemoryV1Provider .tables(" t1" ).schema === df.schema.asNullable)
82+ assert(InMemoryV1Provider .tables(" t1" ).partitioning.sameElements(
83+ Array (IdentityTransform (FieldReference (Seq (" a" ))))))
84+ }
85+ }
86+
87+ test(" save: default mode is ErrorIfExists" ) {
88+ val format = classOf [InMemoryV1Provider ].getName
89+ val df = Seq ((1 , " x" ), (2 , " y" ), (3 , " z" )).toDF(" a" , " b" )
90+
91+ df.write.option(" name" , " t1" ).format(format).partitionBy(" a" ).save()
92+ // default is ErrorIfExists, and since a table already exists we throw an exception
93+ val e = intercept[AnalysisException ] {
94+ df.write.option(" name" , " t1" ).format(format).partitionBy(" a" ).save()
95+ }
96+ assert(e.getMessage.contains(" already exists" ))
97+ }
98+
99+ test(" save: Ignore mode" ) {
100+ val format = classOf [InMemoryV1Provider ].getName
101+ val df = Seq ((1 , " x" ), (2 , " y" ), (3 , " z" )).toDF(" a" , " b" )
102+
103+ df.write.option(" name" , " t1" ).format(format).partitionBy(" a" ).save()
104+ // no-op
105+ df.write.option(" name" , " t1" ).format(format).mode(" ignore" ).partitionBy(" a" ).save()
106+
107+ checkAnswer(InMemoryV1Provider .getTableData(spark, " t1" ), df)
108+ }
109+
110+ test(" save: tables can perform schema and partitioning checks if they already exist" ) {
111+ val format = classOf [InMemoryV1Provider ].getName
112+ val df = Seq ((1 , " x" ), (2 , " y" ), (3 , " z" )).toDF(" a" , " b" )
113+
114+ df.write.option(" name" , " t1" ).format(format).partitionBy(" a" ).save()
115+ val e2 = intercept[IllegalArgumentException ] {
116+ df.write.mode(" append" ).option(" name" , " t1" ).format(format).partitionBy(" b" ).save()
117+ }
118+ assert(e2.getMessage.contains(" partitioning" ))
119+
120+ val e3 = intercept[IllegalArgumentException ] {
121+ Seq ((1 , " x" )).toDF(" c" , " d" ).write.mode(" append" ).option(" name" , " t1" ).format(format)
122+ .save()
123+ }
124+ assert(e3.getMessage.contains(" schema" ))
125+ }
69126}
70127
71128class V1WriteFallbackSessionCatalogSuite
@@ -142,33 +199,45 @@ class InMemoryV1Provider
142199
143200 val partitioning = parameters.get(DataSourceUtils .PARTITIONING_COLUMNS_KEY ).map { value =>
144201 DataSourceUtils .decodePartitioningColumns(value).map { partitioningColumn =>
145-
202+ IdentityTransform ( FieldReference (partitioningColumn))
146203 }
147- }
204+ }.getOrElse( Nil )
148205
149- val table = new InMemoryTableWithV1Fallback (
206+ val tableName = parameters(" name" )
207+ val tableOpt = InMemoryV1Provider .tables.get(tableName)
208+ val table = tableOpt.getOrElse(new InMemoryTableWithV1Fallback (
150209 " InMemoryTableWithV1Fallback" ,
151210 data.schema.asNullable,
152- Array .empty ,
211+ partitioning.toArray ,
153212 Map .empty[String , String ].asJava
154- )
213+ ))
214+ if (tableOpt.isEmpty) {
215+ InMemoryV1Provider .tables.put(tableName, table)
216+ } else {
217+ if (data.schema.asNullable != table.schema) {
218+ throw new IllegalArgumentException (" Wrong schema provided" )
219+ }
220+ if (! partitioning.sameElements(table.partitioning)) {
221+ throw new IllegalArgumentException (" Wrong partitioning provided" )
222+ }
223+ }
155224
156225 def getRelation : BaseRelation = new BaseRelation {
157226 override def sqlContext : SQLContext = _sqlContext
158227 override def schema : StructType = table.schema
159228 }
160229
161- if (mode == SaveMode .ErrorIfExists && dataMap.nonEmpty ) {
230+ if (mode == SaveMode .ErrorIfExists && tableOpt.isDefined ) {
162231 throw new AnalysisException (" Table already exists" )
163- } else if (mode == SaveMode .Ignore && dataMap.nonEmpty ) {
232+ } else if (mode == SaveMode .Ignore && tableOpt.isDefined ) {
164233 // do nothing
165234 return getRelation
166235 }
167- val writer = new FallbackWriteBuilder (new CaseInsensitiveStringMap (parameters.asJava))
236+ val writer = table.newWriteBuilder (new CaseInsensitiveStringMap (parameters.asJava))
168237 if (mode == SaveMode .Overwrite ) {
169- writer.truncate()
238+ writer.asInstanceOf [ SupportsTruncate ]. truncate()
170239 }
171- writer.buildForV1Write().insert(data, overwrite = false )
240+ writer.asInstanceOf [ V1WriteBuilder ]. buildForV1Write().insert(data, overwrite = false )
172241 getRelation
173242 }
174243}
0 commit comments