
Commit cef1705

save so far

1 parent a665313

2 files changed, +49 -29 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala

Lines changed: 3 additions & 8 deletions
@@ -20,13 +20,9 @@ package org.apache.spark.sql.execution.datasources.v2
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
-<<<<<<< HEAD
-import org.apache.spark.sql.sources.v2.{SupportsWrite, Table}
-import org.apache.spark.sql.sources.v2.TableCapability._
-=======
->>>>>>> c56a012bc839cd2f92c2be41faea91d1acfba4eb
 import org.apache.spark.sql.types.BooleanType
 
 /**
@@ -37,9 +33,8 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {
 
   private def failAnalysis(msg: String): Unit = throw new AnalysisException(msg)
 
-  private def supportsBatchWrite(table: Table): Boolean = table match {
-    case supportsWrite: SupportsWrite => supportsWrite.supportsAny(BATCH_WRITE, V1_BATCH_WRITE)
-    case _ => false
+  private def supportsBatchWrite(table: Table): Boolean = {
+    table.supportsAny(BATCH_WRITE, V1_BATCH_WRITE)
   }
 
   override def apply(plan: LogicalPlan): Unit = {
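The resolution above settles the merge conflict in favor of the renamed connector.catalog package and drops the stale org.apache.spark.sql.sources.v2 imports, and supportsBatchWrite no longer pattern-matches on SupportsWrite: the capability query is asked of the Table itself. That call presumably resolves through an implicit enrichment of Table; a minimal sketch of such a helper, modeled on Spark's DataSourceV2Implicits (the class name and placement here are assumptions, not part of this commit):

import org.apache.spark.sql.connector.catalog.{Table, TableCapability}

// Hypothetical implicit wrapper that lets table.supportsAny(...) resolve.
implicit class TableHelper(table: Table) {
  // True if the table advertises the given capability.
  def supports(capability: TableCapability): Boolean =
    table.capabilities.contains(capability)

  // True if the table advertises at least one of the given capabilities.
  def supportsAny(capabilities: TableCapability*): Boolean =
    capabilities.exists(supports)
}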

sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala

Lines changed: 46 additions & 21 deletions
@@ -28,6 +28,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SQLCo
 import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider}
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
 import org.apache.spark.sql.connector.write.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}
+import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -114,8 +115,12 @@ private object InMemoryV1Provider {
   }
 }
 
-class InMemoryV1Provider extends TableProvider with DataSourceRegister {
+class InMemoryV1Provider
+  extends TableProvider
+  with DataSourceRegister
+  with CreatableRelationProvider {
   override def getTable(options: CaseInsensitiveStringMap): Table = {
+
     InMemoryV1Provider.tables.getOrElse(options.get("name"), {
       new InMemoryTableWithV1Fallback(
         "InMemoryTableWithV1Fallback",
@@ -127,6 +132,45 @@ class InMemoryV1Provider extends TableProvider with DataSourceRegister {
   }
 
   override def shortName(): String = "in-memory"
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    val _sqlContext = sqlContext
+
+    val partitioning = parameters.get(DataSourceUtils.PARTITIONING_COLUMNS_KEY).map { value =>
+      DataSourceUtils.decodePartitioningColumns(value).map { partitioningColumn =>
+
+      }
+    }
+
+    val table = new InMemoryTableWithV1Fallback(
+      "InMemoryTableWithV1Fallback",
+      data.schema.asNullable,
+      Array.empty,
+      Map.empty[String, String].asJava
+    )
+
+    def getRelation: BaseRelation = new BaseRelation {
+      override def sqlContext: SQLContext = _sqlContext
+      override def schema: StructType = table.schema
+    }
+
+    if (mode == SaveMode.ErrorIfExists && dataMap.nonEmpty) {
+      throw new AnalysisException("Table already exists")
+    } else if (mode == SaveMode.Ignore && dataMap.nonEmpty) {
+      // do nothing
+      return getRelation
+    }
+    val writer = new FallbackWriteBuilder(new CaseInsensitiveStringMap(parameters.asJava))
+    if (mode == SaveMode.Overwrite) {
+      writer.truncate()
+    }
+    writer.buildForV1Write().insert(data, overwrite = false)
+    getRelation
+  }
 }
 
 class InMemoryTableWithV1Fallback(
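Two loose ends in the new provider-level createRelation mark this as the work in progress the commit message promises: the inner closure over partitioningColumn is empty and the resulting partitioning value is never used, and dataMap is referenced even though it is declared on InMemoryTableWithV1Fallback rather than on the provider. The presumable intent of the partition handling, sketched as a guess from the imports this suite already carries (this code is not part of the commit):

// Hypothetical completion: decode the writer-supplied partition columns and
// model each one as an identity partition transform on that field.
val partitioning = parameters.get(DataSourceUtils.PARTITIONING_COLUMNS_KEY).map { value =>
  DataSourceUtils.decodePartitioningColumns(value).map { partitioningColumn =>
    IdentityTransform(FieldReference(partitioningColumn))
  }
}.getOrElse(Nil)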
@@ -135,8 +179,7 @@ class InMemoryTableWithV1Fallback(
     override val partitioning: Array[Transform],
     override val properties: util.Map[String, String])
   extends Table
-  with SupportsWrite
-  with CreatableRelationProvider {
+  with SupportsWrite {
 
   partitioning.foreach { t =>
     if (!t.isInstanceOf[IdentityTransform]) {
@@ -159,24 +202,6 @@
     new FallbackWriteBuilder(options)
   }
 
-  override def createRelation(
-      sqlContext: SQLContext,
-      mode: SaveMode,
-      parameters: Map[String, String],
-      data: DataFrame): BaseRelation = {
-    if (mode == SaveMode.ErrorIfExists && dataMap.nonEmpty) {
-      throw new AnalysisException("Table already exists")
-    } else if (mode == SaveMode.Ignore && dataMap.nonEmpty) {
-      // do nothing
-    } else if (mode == SaveMode.Overwrite) {
-      val writer = new FallbackWriteBuilder(new CaseInsensitiveStringMap(parameters.asJava))
-      writer.truncate()
-      writer.buildForV1Write().insert(data, overwrite = false)
-    } else {
-
-    }
-  }
-
   private class FallbackWriteBuilder(options: CaseInsensitiveStringMap)
     extends WriteBuilder
     with V1WriteBuilder