@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.types.{DataType, IntegerType}

/**
* Base class for expressions that are converted to v2 partition transforms.
*
 * Subclasses represent transform functions whose concrete implementations are supplied by the
 * data source. Because the concrete implementation is not known here, these expressions are
 * [[Unevaluable]].
*
* These expressions are used to pass transformations from the DataFrame API:
*
* {{{
* df.writeTo("catalog.db.table").partitionedBy($"category", days($"timestamp")).create()
* }}}
*/
abstract class PartitionTransformExpression extends Expression with Unevaluable {
override def nullable: Boolean = true
}

/**
* Expression for the v2 partition transform years.
*/
case class Years(child: Expression) extends PartitionTransformExpression {
[Review comment, Contributor]: do you want to implement ExpectsInputTypes for these classes for auto analysis checks?

override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Seq(child)
}

/**
* Expression for the v2 partition transform months.
*/
case class Months(child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Seq(child)
}

/**
* Expression for the v2 partition transform days.
*/
case class Days(child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Seq(child)
}

/**
* Expression for the v2 partition transform hours.
*/
case class Hours(child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Seq(child)
}

/**
* Expression for the v2 partition transform bucket.
*/
case class Bucket(numBuckets: Literal, child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Seq(numBuckets, child)
}
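
Following up on the review comment above: a minimal sketch of what adopting ExpectsInputTypes could look like, assuming a date/timestamp input constraint. The class name and the accepted types are assumptions for illustration, not part of this change:

// Sketch only: ExpectsInputTypes lives in the same catalyst expressions package,
// and TypeCollection in org.apache.spark.sql.types; mixing the trait in lets the
// analyzer reject a child expression that is not a date or timestamp column.
import org.apache.spark.sql.types.{AbstractDataType, DateType, TimestampType, TypeCollection}

case class YearsWithInputCheck(child: Expression)
  extends PartitionTransformExpression with ExpectsInputTypes {
  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DateType, TimestampType))
  override def dataType: DataType = IntegerType
  override def children: Seq[Expression] = Seq(child)
}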
@@ -2564,7 +2564,7 @@ class Analyzer(
*/
object ResolveOutputRelation extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
-      case append @ AppendData(table, query, isByName)
+      case append @ AppendData(table, query, _, isByName)
if table.resolved && query.resolved && !append.outputResolved =>
val projection =
TableOutputResolver.resolveOutputColumns(
@@ -2576,7 +2576,7 @@
append
}

-      case overwrite @ OverwriteByExpression(table, _, query, isByName)
+      case overwrite @ OverwriteByExpression(table, _, query, _, isByName)
if table.resolved && query.resolved && !overwrite.outputResolved =>
val projection =
TableOutputResolver.resolveOutputColumns(
@@ -2588,7 +2588,7 @@
overwrite
}

-      case overwrite @ OverwritePartitionsDynamic(table, query, isByName)
+      case overwrite @ OverwritePartitionsDynamic(table, query, _, isByName)
if table.resolved && query.resolved && !overwrite.outputResolved =>
val projection =
TableOutputResolver.resolveOutputColumns(
@@ -488,7 +488,7 @@ case class ReplaceTableAsSelect(
override def tableSchema: StructType = query.schema
override def children: Seq[LogicalPlan] = Seq(query)

-  override lazy val resolved: Boolean = {
+  override lazy val resolved: Boolean = childrenResolved && {
// the table schema is created from the query schema, so the only resolution needed is to check
// that the columns referenced by the table's partitioning exist in the query schema
val references = partitioning.flatMap(_.references).toSet
@@ -506,15 +506,22 @@
case class AppendData(
table: NamedRelation,
query: LogicalPlan,
+    writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand

object AppendData {
-  def byName(table: NamedRelation, df: LogicalPlan): AppendData = {
-    new AppendData(table, df, isByName = true)
+  def byName(
+      table: NamedRelation,
+      df: LogicalPlan,
+      writeOptions: Map[String, String] = Map.empty): AppendData = {
+    new AppendData(table, df, writeOptions, isByName = true)
}

-  def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = {
-    new AppendData(table, query, isByName = false)
+  def byPosition(
+      table: NamedRelation,
+      query: LogicalPlan,
+      writeOptions: Map[String, String] = Map.empty): AppendData = {
+    new AppendData(table, query, writeOptions, isByName = false)
}
}

@@ -525,19 +532,26 @@ case class OverwriteByExpression(
table: NamedRelation,
deleteExpr: Expression,
query: LogicalPlan,
+    writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand {
override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved
}

object OverwriteByExpression {
def byName(
-      table: NamedRelation, df: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = {
-    OverwriteByExpression(table, deleteExpr, df, isByName = true)
+      table: NamedRelation,
+      df: LogicalPlan,
+      deleteExpr: Expression,
+      writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = {
+    OverwriteByExpression(table, deleteExpr, df, writeOptions, isByName = true)
}

def byPosition(
-      table: NamedRelation, query: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = {
-    OverwriteByExpression(table, deleteExpr, query, isByName = false)
+      table: NamedRelation,
+      query: LogicalPlan,
+      deleteExpr: Expression,
+      writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = {
+    OverwriteByExpression(table, deleteExpr, query, writeOptions, isByName = false)
}
}

@@ -547,15 +561,22 @@ object OverwriteByExpression {
case class OverwritePartitionsDynamic(
table: NamedRelation,
query: LogicalPlan,
+    writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand

object OverwritePartitionsDynamic {
-  def byName(table: NamedRelation, df: LogicalPlan): OverwritePartitionsDynamic = {
-    OverwritePartitionsDynamic(table, df, isByName = true)
+  def byName(
+      table: NamedRelation,
+      df: LogicalPlan,
+      writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = {
+    OverwritePartitionsDynamic(table, df, writeOptions, isByName = true)
}

-  def byPosition(table: NamedRelation, query: LogicalPlan): OverwritePartitionsDynamic = {
-    OverwritePartitionsDynamic(table, query, isByName = false)
+  def byPosition(
+      table: NamedRelation,
+      query: LogicalPlan,
+      writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = {
+    OverwritePartitionsDynamic(table, query, writeOptions, isByName = false)
}
}
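
Taken together, AppendData, OverwriteByExpression, and OverwritePartitionsDynamic now all accept an optional writeOptions map through their builders. A minimal usage sketch, assuming table: NamedRelation and df: LogicalPlan are in scope and using a made-up option key:

// Sketch: per-write options ride along on the logical plan via the new parameter.
val writeOptions = Map("compression" -> "zstd")  // hypothetical option key
val append   = AppendData.byName(table, df, writeOptions)
val truncate = OverwriteByExpression.byPosition(table, df, Literal(true), writeOptions)
val dynamic  = OverwritePartitionsDynamic.byName(table, df, writeOptions)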

@@ -17,8 +17,11 @@

package org.apache.spark.sql.execution.datasources.v2

+import scala.collection.JavaConverters._
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.connector.catalog.{SupportsDelete, SupportsRead, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap

object DataSourceV2Implicits {
implicit class TableHelper(table: Table) {
@@ -53,4 +56,10 @@ object DataSourceV2Implicits {

def supportsAny(capabilities: TableCapability*): Boolean = capabilities.exists(supports)
}

+  implicit class OptionsHelper(options: Map[String, String]) {
+    def asOptions: CaseInsensitiveStringMap = {
+      new CaseInsensitiveStringMap(options.asJava)
+    }
+  }
}
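
A quick usage sketch of the new implicit; the key is made up, and the lookup works because CaseInsensitiveStringMap normalizes key case:

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._

// Sketch: convert a Scala Map into the options map that v2 sources receive.
val opts = Map("Compression" -> "zstd").asOptions
assert(opts.get("compression") == "zstd")  // lookups ignore key case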
@@ -41,8 +41,11 @@ class InMemoryTable(
override val properties: util.Map[String, String])
extends Table with SupportsRead with SupportsWrite with SupportsDelete {

+  private val allowUnsupportedTransforms =
+    properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean

partitioning.foreach { t =>
-    if (!t.isInstanceOf[IdentityTransform]) {
+    if (!t.isInstanceOf[IdentityTransform] && !allowUnsupportedTransforms) {
throw new IllegalArgumentException(s"Transform $t must be IdentityTransform")
}
}
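
With this escape hatch, tests can construct an InMemoryTable over non-identity transforms. A sketch, assuming the constructor takes (name, schema, partitioning, properties) as shown above and that the connector Expressions factory is available:

import java.util
import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.types.StructType

// Sketch: the property disables the IdentityTransform check above.
val props = new util.HashMap[String, String]()
props.put("allow-unsupported-transforms", "true")
val table = new InMemoryTable(
  "t",
  new StructType().add("id", "long").add("ts", "timestamp"),
  Array(Expressions.bucket(4, "id")),  // would throw without the property
  props)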
@@ -271,13 +271,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
modeForDSV2 match {
case SaveMode.Append =>
runCommand(df.sparkSession, "save") {
-          AppendData.byName(relation, df.logicalPlan)
+          AppendData.byName(relation, df.logicalPlan, extraOptions.toMap)
}

case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) =>
// truncate the table
runCommand(df.sparkSession, "save") {
-          OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true))
+          OverwriteByExpression.byName(
+            relation, df.logicalPlan, Literal(true), extraOptions.toMap)
}

case other =>
@@ -382,17 +383,17 @@

val command = modeForDSV2 match {
case SaveMode.Append =>
-        AppendData.byPosition(table, df.logicalPlan)
+        AppendData.byPosition(table, df.logicalPlan, extraOptions.toMap)

case SaveMode.Overwrite =>
val conf = df.sparkSession.sessionState.conf
val dynamicPartitionOverwrite = table.table.partitioning.size > 0 &&
conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC

if (dynamicPartitionOverwrite) {
-          OverwritePartitionsDynamic.byPosition(table, df.logicalPlan)
+          OverwritePartitionsDynamic.byPosition(table, df.logicalPlan, extraOptions.toMap)
} else {
-          OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true))
+          OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true), extraOptions.toMap)
}

case other =>
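
End to end, the effect of this change is that options set on a DataFrameWriter now travel with the v2 write plans instead of being dropped for appends and overwrites. A sketch, with a hypothetical source name and option key:

// Sketch: extraOptions set via .option(...) now reach the AppendData plan above.
df.write
  .format("some.v2.source")            // hypothetical v2 data source
  .option("write.batch-size", "4096")  // hypothetical source-specific option
  .mode("append")
  .save()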