@@ -24,11 +24,12 @@ import scala.collection.mutable
2424
2525import org .scalatest .BeforeAndAfter
2626
27- import org .apache .spark .sql .{DataFrame , QueryTest , Row , SaveMode , SparkSession }
27+ import org .apache .spark .sql .{AnalysisException , DataFrame , QueryTest , Row , SaveMode , SparkSession , SQLContext }
2828import org .apache .spark .sql .connector .catalog .{SupportsWrite , Table , TableCapability , TableProvider }
2929import org .apache .spark .sql .connector .expressions .{FieldReference , IdentityTransform , Transform }
3030import org .apache .spark .sql .connector .write .{SupportsOverwrite , SupportsTruncate , V1WriteBuilder , WriteBuilder }
31- import org .apache .spark .sql .sources .{DataSourceRegister , Filter , InsertableRelation }
31+ import org .apache .spark .sql .execution .datasources .{DataSource , DataSourceUtils }
32+ import org .apache .spark .sql .sources ._
3233import org .apache .spark .sql .test .SharedSparkSession
3334import org .apache .spark .sql .types .{IntegerType , StringType , StructType }
3435import org .apache .spark .sql .util .CaseInsensitiveStringMap
@@ -52,7 +53,11 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
// Appends the same DataFrame twice through `v2Format`. After the first write the stored table
// must hold exactly the input rows, expose the nullable form of the input schema and have no
// partitioning; after the second write the rows must appear twice.
test("append fallback") {
  val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

  def appendOnce(): Unit =
    df.write.mode("append").option("name", "t1").format(v2Format).save()

  appendOnce()
  checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)

  val stored = InMemoryV1Provider.tables("t1")
  assert(stored.schema === df.schema.asNullable)
  assert(stored.partitioning.isEmpty)

  appendOnce()
  checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df.union(df))
}
@@ -65,6 +70,59 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
6570 df2.write.mode(" overwrite" ).option(" name" , " t1" ).format(v2Format).save()
6671 checkAnswer(InMemoryV1Provider .getTableData(spark, " t1" ), df2)
6772 }
73+
// Registers one test per SaveMode value. Each test saves a DataFrame partitioned by column
// "a" through InMemoryV1Provider and checks the rows, the nullable schema and the identity
// partitioning recorded for table "t1".
SaveMode.values().foreach { mode =>
  test(s"save: new table creations with partitioning for table - mode: $mode") {
    val format = classOf[InMemoryV1Provider].getName
    val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

    df.write.mode(mode).option("name", "t1").format(format).partitionBy("a").save()

    checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)

    val stored = InMemoryV1Provider.tables("t1")
    val expectedPartitioning = Array(IdentityTransform(FieldReference(Seq("a"))))
    assert(stored.schema === df.schema.asNullable)
    assert(stored.partitioning.sameElements(expectedPartitioning))
  }
}
86+
// Saves twice without calling `mode(...)`. The first save creates table "t1"; the second must
// fail with an AnalysisException mentioning "already exists" and must not change the rows.
test("save: default mode is ErrorIfExists") {
  val format = classOf[InMemoryV1Provider].getName
  val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

  def saveWithDefaultMode(): Unit =
    df.write.option("name", "t1").format(format).partitionBy("a").save()

  saveWithDefaultMode()

  // default is ErrorIfExists, and since a table already exists we throw an exception
  val alreadyExists = intercept[AnalysisException] {
    saveWithDefaultMode()
  }
  assert(alreadyExists.getMessage.contains("already exists"))

  // The rejected write must leave the originally stored rows as they were.
  checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
}
98+
// Creates table "t1", then saves again in "ignore" mode. The second save uses different rows
// with the same schema, so the final check fails if the ignored write either appended to or
// replaced the stored data.
test("save: Ignore mode") {
  val format = classOf[InMemoryV1Provider].getName
  val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
  val ignored = Seq((4, "w")).toDF("a", "b")

  df.write.option("name", "t1").format(format).partitionBy("a").save()
  // no-op
  ignored.write.option("name", "t1").format(format).mode("ignore").partitionBy("a").save()

  checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
}
109+
// Creates table "t1" partitioned by "a", then verifies that appends with a different
// partitioning column or with different column names are rejected with an
// IllegalArgumentException and leave the stored rows untouched.
test("save: tables can perform schema and partitioning checks if they already exist") {
  val format = classOf[InMemoryV1Provider].getName
  val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

  df.write.option("name", "t1").format(format).partitionBy("a").save()

  val wrongPartitioning = intercept[IllegalArgumentException] {
    df.write.mode("append").option("name", "t1").format(format).partitionBy("b").save()
  }
  assert(wrongPartitioning.getMessage.contains("partitioning"))

  val wrongSchema = intercept[IllegalArgumentException] {
    Seq((1, "x")).toDF("c", "d").write.mode("append").option("name", "t1").format(format)
      .save()
  }
  assert(wrongSchema.getMessage.contains("schema"))

  // Neither rejected append may have modified the table.
  checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
}
68126}
69127
70128class V1WriteFallbackSessionCatalogSuite
@@ -114,26 +172,83 @@ private object InMemoryV1Provider {
114172 }
115173}
116174
/**
 * Test data source backed by the table registry `InMemoryV1Provider.tables`.
 *
 * It implements `TableProvider` (lookup through `getTable`) as well as the V1
 * `CreatableRelationProvider` (creation and writes through `createRelation`).
 */
class InMemoryV1Provider
  extends TableProvider
  with DataSourceRegister
  with CreatableRelationProvider {

  /**
   * Looks up the table registered under the "name" option.
   *
   * When no table is registered, a placeholder with an empty schema and no partitioning is
   * returned. The placeholder is not added to the registry.
   */
  override def getTable(options: CaseInsensitiveStringMap): Table = {
    InMemoryV1Provider.tables.getOrElse(options.get("name"), {
      new InMemoryTableWithV1Fallback(
        "InMemoryTableWithV1Fallback",
        new StructType(),
        Array.empty,
        options.asCaseSensitiveMap()
      )
    })
  }

  override def shortName(): String = "in-memory"

  /**
   * Saves `data` into the table named by `parameters("name")`, honouring `mode`.
   *
   *  - `ErrorIfExists` on an existing table throws an `AnalysisException`.
   *  - `Ignore` on an existing table writes nothing and returns a relation over that table.
   *  - Otherwise the table is created and registered if missing, or validated against the
   *    schema and partitioning of `data` if present, and `data` is then inserted. `Overwrite`
   *    truncates the table first.
   *
   * The save mode is evaluated before the schema/partitioning validation, so `Ignore` stays a
   * no-op and `ErrorIfExists` reports "already exists" even when `data` would not match the
   * existing table.
   *
   * @param sqlContext context exposed by the returned relation
   * @param mode save mode requested by the writer
   * @param parameters writer options; must contain "name" and may contain the encoded
   *                   partitioning columns under `DataSourceUtils.PARTITIONING_COLUMNS_KEY`
   * @param data rows to store
   * @return a relation whose schema is the schema of the target table
   * @throws AnalysisException if `mode` is `ErrorIfExists` and the table already exists
   * @throws IllegalArgumentException if an existing table is written to with a different
   *                                  schema or partitioning
   */
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    // Alias is required: the anonymous BaseRelation below declares its own `sqlContext`.
    val outerSqlContext = sqlContext

    val partitioning = parameters
      .get(DataSourceUtils.PARTITIONING_COLUMNS_KEY)
      .map { encoded =>
        DataSourceUtils.decodePartitioningColumns(encoded).map { column =>
          IdentityTransform(FieldReference(column))
        }
      }
      .getOrElse(Nil)

    val tableName = parameters("name")
    val existing = InMemoryV1Provider.tables.get(tableName)

    def relationOver(table: InMemoryTableWithV1Fallback): BaseRelation = new BaseRelation {
      override def sqlContext: SQLContext = outerSqlContext
      override def schema: StructType = table.schema
    }

    (mode, existing) match {
      case (SaveMode.ErrorIfExists, Some(_)) =>
        throw new AnalysisException("Table already exists")

      case (SaveMode.Ignore, Some(table)) =>
        // do nothing
        relationOver(table)

      case _ =>
        val table = existing match {
          case Some(found) =>
            if (data.schema.asNullable != found.schema) {
              throw new IllegalArgumentException("Wrong schema provided")
            }
            if (!partitioning.sameElements(found.partitioning)) {
              throw new IllegalArgumentException("Wrong partitioning provided")
            }
            found
          case None =>
            val created = new InMemoryTableWithV1Fallback(
              "InMemoryTableWithV1Fallback",
              data.schema.asNullable,
              partitioning.toArray,
              Map.empty[String, String].asJava
            )
            InMemoryV1Provider.tables.put(tableName, created)
            created
        }

        val writer = table.newWriteBuilder(new CaseInsensitiveStringMap(parameters.asJava))
        if (mode == SaveMode.Overwrite) {
          writer match {
            // The value returned by truncate() is not used; the write below goes through
            // the same `writer` instance.
            case truncatable: SupportsTruncate => truncatable.truncate()
            case other =>
              throw new UnsupportedOperationException(
                s"${other.getClass.getName} does not support truncation")
          }
        }
        writer match {
          case v1: V1WriteBuilder => v1.buildForV1Write().insert(data, overwrite = false)
          case other =>
            throw new UnsupportedOperationException(
              s"${other.getClass.getName} does not support V1 writes")
        }
        relationOver(table)
    }
  }
}
131244
132245class InMemoryTableWithV1Fallback (
133246 override val name : String ,
134247 override val schema : StructType ,
135248 override val partitioning : Array [Transform ],
136- override val properties : util.Map [String , String ]) extends Table with SupportsWrite {
249+ override val properties : util.Map [String , String ])
250+ extends Table
251+ with SupportsWrite {
137252
138253 partitioning.foreach { t =>
139254 if (! t.isInstanceOf [IdentityTransform ]) {
@@ -142,7 +257,6 @@ class InMemoryTableWithV1Fallback(
142257 }
143258
144259 override def capabilities : util.Set [TableCapability ] = Set (
145- TableCapability .BATCH_WRITE ,
146260 TableCapability .V1_BATCH_WRITE ,
147261 TableCapability .OVERWRITE_BY_FILTER ,
148262 TableCapability .TRUNCATE ).asJava
0 commit comments