
[SPARK-18087] [SQL] Optimize insert to not require REPAIR TABLE #15633

Closed (wants to merge 3 commits)
DataSource.scala
@@ -528,7 +528,7 @@ case class DataSource(
columns,
bucketSpec,
format,
() => Unit, // No existing table needs to be refreshed.
_ => Unit, // No existing table needs to be refreshed.
options,
data.logicalPlan,
mode)
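As context for the one-line change above: later in this PR the refresh callback's type changes from () => Unit to Seq[TablePartitionSpec] => Unit, so the no-op here now takes (and ignores) an argument. A minimal, hypothetical sketch of such a no-op callback, assuming only a local TablePartitionSpec alias rather than the real CatalogTypes import, looks like this:

// Hypothetical sketch (not PR code): a no-op refresh callback under the new signature.
object NoOpRefreshCallbackSketch {
  type TablePartitionSpec = Map[String, String]

  // No existing table needs to be refreshed, so the updated partitions are ignored.
  // `_ => ()` is the conventional no-op; `_ => Unit` (as in the diff) also compiles,
  // because the returned `Unit` companion value is simply discarded.
  val refreshFunction: Seq[TablePartitionSpec] => Unit = _ => ()

  def main(args: Array[String]): Unit = {
    refreshFunction(Seq(Map("ds" -> "2016-10-30")))  // does nothing
  }
}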
DataSourceStrategy.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, Inte
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, DDLUtils, ExecutedCommandExec}
import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -179,24 +180,30 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
"Cannot overwrite a path that is also being read from.")
}

def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
if (l.catalogTable.isDefined &&
[Inline review thread on this line]

cloud-fan (Contributor), Oct 30, 2016:
Shall we move this if out of the function? e.g.

val refreshPartitionsCallback = if (...) {
  ...
} else {
  _ => ()
}

(For a self-contained illustration of the two shapes being compared, see the sketch after this file's diff.)

Author (Contributor):
IMO that is a little harder to read, since you have two anonymous function declarations instead of one.

l.catalogTable.get.partitionColumnNames.nonEmpty &&
l.catalogTable.get.partitionProviderIsHive) {
val metastoreUpdater = AlterTableAddPartitionCommand(
[Inline review thread on this line]

Contributor:
Shall we just copy the main logic of AlterTableAddPartitionCommand here? Otherwise we have to fetch the table metadata from the metastore every time.

Author (Contributor):
I'd rather keep it, since the fetch overhead is pretty small.

l.catalogTable.get.identifier,
updatedPartitions.map(p => (p, None)),
ifNotExists = true)
metastoreUpdater.run(t.sparkSession)
}
t.location.refresh()
}

val insertCmd = InsertIntoHadoopFsRelationCommand(
outputPath,
query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver),
t.bucketSpec,
t.fileFormat,
() => t.location.refresh(),
refreshPartitionsCallback,
t.options,
query,
mode)

if (l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty &&
l.catalogTable.get.partitionProviderIsHive) {
// TODO(ekl) we should be more efficient here and only recover the newly added partitions
val recoverPartitionCmd = AlterTableRecoverPartitionsCommand(l.catalogTable.get.identifier)
Union(insertCmd, recoverPartitionCmd)
} else {
insertCmd
}
insertCmd
}
}

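As referenced in the review thread above, here is a hypothetical, self-contained sketch (not the PR's code; the predicate and the metastore call are stand-ins) contrasting the two callback shapes that were discussed: the PR's single function with the catalog check inside its body, and the reviewer's suggestion of choosing between two function literals up front.

// Hypothetical sketch contrasting the two callback shapes discussed in the review thread.
object RefreshCallbackShapesSketch {
  type TablePartitionSpec = Map[String, String]

  // Stand-ins for the real catalog checks and the metastore update command.
  val isHivePartitionedCatalogTable: Boolean = true
  def addPartitionsToMetastore(parts: Seq[TablePartitionSpec]): Unit =
    println(s"ALTER TABLE ... ADD IF NOT EXISTS ${parts.mkString(", ")}")

  // Shape used in the PR: one named function, with the check inside its body.
  def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
    if (isHivePartitionedCatalogTable) {
      addPartitionsToMetastore(updatedPartitions)
    }
    // ...then refresh the file index (t.location.refresh() in the real code).
  }

  // Shape suggested in review: decide once and produce one of two function literals.
  val refreshPartitionsCallbackAlt: Seq[TablePartitionSpec] => Unit =
    if (isHivePartitionedCatalogTable) {
      parts => addPartitionsToMetastore(parts)
    } else {
      _ => ()
    }

  def main(args: Array[String]): Unit = {
    val parts = Seq(Map("ds" -> "2016-10-30", "hr" -> "0"))
    refreshPartitionsCallback(parts)
    refreshPartitionsCallbackAlt(parts)
  }
}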
InsertIntoHadoopFsRelationCommand.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
@@ -40,7 +41,7 @@ case class InsertIntoHadoopFsRelationCommand(
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
refreshFunction: () => Unit,
refreshFunction: (Seq[TablePartitionSpec]) => Unit,
options: Map[String, String],
@transient query: LogicalPlan,
mode: SaveMode)
PartitioningUtils.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.util.Shell
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types._

@@ -244,6 +245,17 @@ object PartitioningUtils {
}
}

/**
* Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, returns a parsed spec
* for that fragment, e.g. `Map(("fieldOne", "1"), ("fieldTwo", "2"))`.
*/
def parsePathFragment(pathFragment: String): TablePartitionSpec = {
pathFragment.split("/").map { kv =>
val pair = kv.split("=", 2)
(unescapePathName(pair(0)), unescapePathName(pair(1)))
}.toMap
}

/**
* Normalize the column names in partition specification, w.r.t. the real partition column names
* and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
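For reference, a standalone sketch of the parsing behavior documented for parsePathFragment above; this version skips the unescapePathName calls that the real helper applies to each key and value, so it is illustrative only.

// Hypothetical, simplified re-implementation of the parsePathFragment behavior
// (without unescapePathName), just to show the expected input and output.
object ParsePathFragmentSketch {
  type TablePartitionSpec = Map[String, String]

  def parsePathFragment(pathFragment: String): TablePartitionSpec = {
    pathFragment.split("/").map { kv =>
      val pair = kv.split("=", 2)
      (pair(0), pair(1))
    }.toMap
  }

  def main(args: Array[String]): Unit = {
    // Prints Map(fieldOne -> 1, fieldTwo -> 2)
    println(parsePathFragment("fieldOne=1/fieldTwo=2"))
  }
}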
WriteOutput.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources

import java.util.{Date, UUID}

import scala.collection.mutable

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
@@ -30,6 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
@@ -85,7 +88,7 @@ object WriteOutput extends Logging {
hadoopConf: Configuration,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
refreshFunction: () => Unit,
refreshFunction: (Seq[TablePartitionSpec]) => Unit,
options: Map[String, String],
isAppend: Boolean): Unit = {

@@ -120,19 +123,19 @@
val committer = setupDriverCommitter(job, outputPath.toString, isAppend)

try {
sparkSession.sparkContext.runJob(queryExecution.toRdd,
val updatedPartitions = sparkSession.sparkContext.runJob(queryExecution.toRdd,
(taskContext: TaskContext, iter: Iterator[InternalRow]) => {
executeTask(
description = description,
sparkStageId = taskContext.stageId(),
sparkPartitionId = taskContext.partitionId(),
sparkAttemptNumber = taskContext.attemptNumber(),
iterator = iter)
})
}).flatten.distinct

committer.commitJob(job)
logInfo(s"Job ${job.getJobID} committed.")
refreshFunction()
refreshFunction(updatedPartitions.map(PartitioningUtils.parsePathFragment))
} catch { case cause: Throwable =>
logError(s"Aborting job ${job.getJobID}.", cause)
committer.abortJob(job, JobStatus.State.FAILED)
@@ -147,7 +150,7 @@
sparkStageId: Int,
sparkPartitionId: Int,
sparkAttemptNumber: Int,
iterator: Iterator[InternalRow]): Unit = {
iterator: Iterator[InternalRow]): Set[String] = {

val jobId = SparkHadoopWriter.createJobID(new Date, sparkStageId)
val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -187,11 +190,12 @@
try {
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Execute the task to write rows out
writeTask.execute(iterator)
val outputPaths = writeTask.execute(iterator)
writeTask.releaseResources()

// Commit the task
SparkHadoopMapRedUtil.commitTask(committer, taskAttemptContext, jobId.getId, taskId.getId)
outputPaths
})(catchBlock = {
// If there is an error, release resource and then abort the task
try {
@@ -213,7 +217,7 @@
* automatically trigger task aborts.
*/
private trait ExecuteWriteTask {
def execute(iterator: Iterator[InternalRow]): Unit
def execute(iterator: Iterator[InternalRow]): Set[String]
def releaseResources(): Unit

final def filePrefix(split: Int, uuid: String, bucketId: Option[Int]): String = {
@@ -240,11 +244,12 @@
outputWriter
}

override def execute(iter: Iterator[InternalRow]): Unit = {
override def execute(iter: Iterator[InternalRow]): Set[String] = {
while (iter.hasNext) {
val internalRow = iter.next()
outputWriter.writeInternal(internalRow)
}
Set.empty
}

override def releaseResources(): Unit = {
@@ -327,7 +332,7 @@
newWriter
}

override def execute(iter: Iterator[InternalRow]): Unit = {
override def execute(iter: Iterator[InternalRow]): Set[String] = {
// We should first sort by partition columns, then bucket id, and finally sorting columns.
val sortingExpressions: Seq[Expression] =
description.partitionColumns ++ bucketIdExpression ++ sortColumns
@@ -375,6 +380,7 @@

// If anything below fails, we should abort the task.
var currentKey: UnsafeRow = null
val updatedPartitions = mutable.Set[String]()
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
@@ -386,13 +392,18 @@
logDebug(s"Writing partition: $currentKey")

currentWriter = newOutputWriter(currentKey, getPartitionString)
val partitionPath = getPartitionString(currentKey).getString(0)
if (partitionPath.nonEmpty) {
updatedPartitions.add(partitionPath)
}
}
currentWriter.writeInternal(sortedIterator.getValue)
}
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
updatedPartitions.toSet
}

override def releaseResources(): Unit = {
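Taken together, the WriteOutput changes above replace the recover-all-partitions step (the AlterTableRecoverPartitionsCommand union removed earlier) with targeted updates: each write task now returns the set of partition path fragments it actually wrote, the driver flattens and de-duplicates them after the job commits, and the refresh callback receives the parsed partition specs. A hypothetical end-to-end sketch of that reporting flow (plain Scala, no Spark dependencies; the names and sample data are illustrative) is:

// Hypothetical end-to-end sketch (not Spark code) of the new reporting flow:
// tasks return the partition path fragments they touched, the driver merges
// them, and the refresh callback receives parsed partition specs.
object UpdatedPartitionsFlowSketch {
  type TablePartitionSpec = Map[String, String]

  // Same parsing idea as PartitioningUtils.parsePathFragment (unescaping omitted).
  def parsePathFragment(fragment: String): TablePartitionSpec =
    fragment.split("/").map { kv =>
      val pair = kv.split("=", 2)
      (pair(0), pair(1))
    }.toMap

  def main(args: Array[String]): Unit = {
    // One Set[String] per task, as the modified executeTask would return them.
    val perTaskOutputs: Seq[Set[String]] = Seq(
      Set("ds=2016-10-30/hr=0"),
      Set("ds=2016-10-30/hr=0", "ds=2016-10-30/hr=1"),
      Set.empty // a single-directory (non-partitioned) task reports nothing
    )

    // Stand-in for refreshPartitionsCallback: add only the touched partitions.
    val refreshFunction: Seq[TablePartitionSpec] => Unit = specs =>
      specs.foreach(spec => println(s"ADD IF NOT EXISTS PARTITION $spec"))

    // Driver side, mirroring runJob(...).flatten.distinct and the callback invocation.
    val updatedPartitions: Seq[String] = perTaskOutputs.flatten.distinct
    refreshFunction(updatedPartitions.map(parsePathFragment))
  }
}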