diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index ba8e4d69ba755..d21b9d9833e9e 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -23,6 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor} +import org.apache.spark.storage.BlockManagerId /** * :: DeveloperApi :: @@ -95,6 +96,20 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( shuffleId, this) + /** + * Stores the location of the list of chosen external shuffle services for handling the + * shuffle merge requests from mappers in this shuffle map stage. + */ + private[spark] var mergerLocs: Seq[BlockManagerId] = Nil + + def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = { + if (mergerLocs != null) { + this.mergerLocs = mergerLocs + } + } + + def getMergerLocs: Seq[BlockManagerId] = mergerLocs + _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4bc49514fc5ad..b38d0e5c617b9 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1945,4 +1945,51 @@ package object config { .version("3.0.1") .booleanConf .createWithDefault(false) + + private[spark] val PUSH_BASED_SHUFFLE_ENABLED = + ConfigBuilder("spark.shuffle.push.enabled") + .doc("Set to 'true' to enable push-based shuffle on the client side and this works in " + + "conjunction with the server side flag spark.shuffle.server.mergedShuffleFileManagerImpl " + + "which needs to be set with the appropriate " + + "org.apache.spark.network.shuffle.MergedShuffleFileManager implementation for push-based " + + "shuffle to be enabled") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + private[spark] val SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS = + ConfigBuilder("spark.shuffle.push.maxRetainedMergerLocations") + .doc("Maximum number of shuffle push merger locations cached for push based shuffle. " + + "Currently, shuffle push merger locations are nothing but external shuffle services " + + "which are responsible for handling pushed blocks and merging them and serving " + + "merged blocks for later shuffle fetch.") + .version("3.1.0") + .intConf + .createWithDefault(500) + + private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO = + ConfigBuilder("spark.shuffle.push.mergersMinThresholdRatio") + .doc("The minimum number of shuffle merger locations required to enable push based " + + "shuffle for a stage. This is specified as a ratio of the number of partitions in " + + "the child stage. For example, a reduce stage which has 100 partitions and uses the " + + "default value 0.05 requires at least 5 unique merger locations to enable push based " + + "shuffle. 
Merger locations are currently defined as external shuffle services.") + .version("3.1.0") + .doubleConf + .createWithDefault(0.05) + + private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD = + ConfigBuilder("spark.shuffle.push.mergersMinStaticThreshold") + .doc(s"The static threshold for number of shuffle push merger locations should be " + + "available in order to enable push based shuffle for a stage. Note this config " + + s"works in conjunction with ${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key}. " + + "Maximum of spark.shuffle.push.mergersMinStaticThreshold and " + + s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} ratio number of mergers needed to " + + "enable push based shuffle for a stage. For eg: with 1000 partitions for the child " + + "stage with spark.shuffle.push.mergersMinStaticThreshold as 5 and " + + s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} set to 0.05, we would need " + + "at least 50 mergers to enable push based shuffle for that stage.") + .version("3.1.0") + .doubleConf + .createWithDefault(5) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 13b766e654832..6fb0fb93f253b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -249,6 +249,8 @@ private[spark] class DAGScheduler( private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf) + /** * Called by the TaskSetManager to report task's starting. */ @@ -1252,6 +1254,33 @@ private[spark] class DAGScheduler( execCores.map(cores => properties.setProperty(EXECUTOR_CORES_LOCAL_PROPERTY, cores)) } + /** + * If push based shuffle is enabled, set the shuffle services to be used for the given + * shuffle map stage for block push/merge. + * + * Even with dynamic resource allocation kicking in and significantly reducing the number + * of available active executors, we would still be able to get sufficient shuffle service + * locations for block push/merge by getting the historical locations of past executors. + */ + private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = { + // TODO(SPARK-32920) Handle stage reuse/retry cases separately as without finalize + // TODO changes we cannot disable shuffle merge for the retry/reuse cases + val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations( + stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId) + + if (mergerLocs.nonEmpty) { + stage.shuffleDep.setMergerLocs(mergerLocs) + logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" + + s" ${stage.shuffleDep.getMergerLocs.size} merger locations") + + logDebug("List of shuffle push merger locations " + + s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}") + } else { + logInfo("No available merger locations." + + s" Push-based shuffle disabled for $stage (${stage.name})") + } + } + /** Called when stage's parents are available and we can now do its task. 
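// [Editor's note] A minimal configuration sketch (not part of the patch) showing how the new
// client-side keys introduced above might be set. The values are illustrative; push-based
// shuffle additionally requires the external shuffle service, and the server-side
// spark.shuffle.server.mergedShuffleFileManagerImpl must point at a concrete
// MergedShuffleFileManager implementation, which is outside the scope of this diff.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true")
  .set("spark.shuffle.push.enabled", "true")
  // Optional tuning knobs added by this change (defaults: 500, 0.05, 5).
  .set("spark.shuffle.push.maxRetainedMergerLocations", "500")
  .set("spark.shuffle.push.mergersMinThresholdRatio", "0.05")
  .set("spark.shuffle.push.mergersMinStaticThreshold", "5")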
*/ private def submitMissingTasks(stage: Stage, jobId: Int): Unit = { logDebug("submitMissingTasks(" + stage + ")") @@ -1281,6 +1310,12 @@ private[spark] class DAGScheduler( stage match { case s: ShuffleMapStage => outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1) + // Only generate merger location for a given shuffle dependency once. This way, even if + // this stage gets retried, it would still be merging blocks using the same set of + // shuffle services. + if (pushBasedShuffleEnabled) { + prepareShuffleServicesForShuffleMapStage(s) + } case s: ResultStage => outputCommitCoordinator.stageStart( stage = s.id, maxPartitionId = s.rdd.partitions.length - 1) @@ -2027,6 +2062,11 @@ private[spark] class DAGScheduler( if (!executorFailureEpoch.contains(execId) || executorFailureEpoch(execId) < currentEpoch) { executorFailureEpoch(execId) = currentEpoch logInfo(s"Executor lost: $execId (epoch $currentEpoch)") + if (pushBasedShuffleEnabled) { + // Remove fetchFailed host in the shuffle push merger list for push based shuffle + hostToUnregisterOutputs.foreach( + host => blockManagerMaster.removeShufflePushMergerLocation(host)) + } blockManagerMaster.removeExecutor(execId) clearCacheLocs() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index a566d0a04387c..b2acdb3e12a6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import org.apache.spark.resource.ResourceProfile +import org.apache.spark.storage.BlockManagerId /** * A backend interface for scheduling systems that allows plugging in different ones under @@ -92,4 +93,16 @@ private[spark] trait SchedulerBackend { */ def maxNumConcurrentTasks(rp: ResourceProfile): Int + /** + * Get the list of host locations for push based shuffle + * + * Currently push based shuffle is disabled for both stage retry and stage reuse cases + * (for eg: in the case where few partitions are lost due to failure). Hence this method + * should be invoked only once for a ShuffleDependency. 
+ * @return List of external shuffle services locations + */ + def getShufflePushMergerLocations( + numPartitions: Int, + resourceProfileId: Int): Seq[BlockManagerId] = Nil + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 49e32d04d450a..c6a4457d8f910 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -145,4 +145,6 @@ private[spark] object BlockManagerId { def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { blockManagerIdCache.get(id) } + + private[spark] val SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f544d47b8e13c..fe1a5aef9499c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -125,6 +125,26 @@ class BlockManagerMaster( driverEndpoint.askSync[Seq[BlockManagerId]](GetPeers(blockManagerId)) } + /** + * Get a list of unique shuffle service locations where an executor is successfully + * registered in the past for block push/merge with push based shuffle. + */ + def getShufflePushMergerLocations( + numMergersNeeded: Int, + hostsToFilter: Set[String]): Seq[BlockManagerId] = { + driverEndpoint.askSync[Seq[BlockManagerId]]( + GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter)) + } + + /** + * Remove the host from the candidate list of shuffle push mergers. This can be + * triggered if there is a FetchFailedException on the host + * @param host + */ + def removeShufflePushMergerLocation(host: String): Unit = { + driverEndpoint.askSync[Seq[BlockManagerId]](RemoveShufflePushMergerLocation(host)) + } + def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { driverEndpoint.askSync[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index a7532a9870fae..4d565511704d4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -74,6 +74,14 @@ class BlockManagerMasterEndpoint( // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] + // Mapping from host name to shuffle (mergers) services where the current app + // registered an executor in the past. Older hosts are removed when the + // maxRetainedMergerLocations size is reached in favor of newer locations. 
+ private val shuffleMergerLocations = new mutable.LinkedHashMap[String, BlockManagerId]() + + // Maximum number of merger locations to cache + private val maxRetainedMergerLocations = conf.get(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS) + private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool", 100) private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) @@ -92,6 +100,8 @@ class BlockManagerMasterEndpoint( val defaultRpcTimeout = RpcUtils.askRpcTimeout(conf) + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf) + logInfo("BlockManagerMasterEndpoint up") // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED) // && conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)` @@ -139,6 +149,12 @@ class BlockManagerMasterEndpoint( case GetBlockStatus(blockId, askStorageEndpoints) => context.reply(blockStatus(blockId, askStorageEndpoints)) + case GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter) => + context.reply(getShufflePushMergerLocations(numMergersNeeded, hostsToFilter)) + + case RemoveShufflePushMergerLocation(host) => + context.reply(removeShufflePushMergerLocation(host)) + case IsExecutorAlive(executorId) => context.reply(blockManagerIdByExecutor.contains(executorId)) @@ -360,6 +376,17 @@ class BlockManagerMasterEndpoint( } + private def addMergerLocation(blockManagerId: BlockManagerId): Unit = { + if (!blockManagerId.isDriver && !shuffleMergerLocations.contains(blockManagerId.host)) { + val shuffleServerId = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, + blockManagerId.host, externalShuffleServicePort) + if (shuffleMergerLocations.size >= maxRetainedMergerLocations) { + shuffleMergerLocations -= shuffleMergerLocations.head._1 + } + shuffleMergerLocations(shuffleServerId.host) = shuffleServerId + } + } + private def removeExecutor(execId: String): Unit = { logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) @@ -526,6 +553,10 @@ class BlockManagerMasterEndpoint( blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(), maxOnHeapMemSize, maxOffHeapMemSize, storageEndpoint, externalShuffleServiceBlockStatus) + + if (pushBasedShuffleEnabled) { + addMergerLocation(id) + } } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize, Some(maxOnHeapMemSize), Some(maxOffHeapMemSize))) @@ -657,6 +688,40 @@ class BlockManagerMasterEndpoint( } } + private def getShufflePushMergerLocations( + numMergersNeeded: Int, + hostsToFilter: Set[String]): Seq[BlockManagerId] = { + val blockManagerHosts = blockManagerIdByExecutor.values.map(_.host).toSet + val filteredBlockManagerHosts = blockManagerHosts.filterNot(hostsToFilter.contains(_)) + val filteredMergersWithExecutors = filteredBlockManagerHosts.map( + BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, _, externalShuffleServicePort)) + // Enough mergers are available as part of active executors list + if (filteredMergersWithExecutors.size >= numMergersNeeded) { + filteredMergersWithExecutors.toSeq + } else { + // Delta mergers added from inactive mergers list to the active mergers list + val filteredMergersWithExecutorsHosts = filteredMergersWithExecutors.map(_.host) + val filteredMergersWithoutExecutors = shuffleMergerLocations.values + .filterNot(x => hostsToFilter.contains(x.host)) + .filterNot(x => filteredMergersWithExecutorsHosts.contains(x.host)) + val 
randomFilteredMergersLocations = + if (filteredMergersWithoutExecutors.size > + numMergersNeeded - filteredMergersWithExecutors.size) { + Utils.randomize(filteredMergersWithoutExecutors) + .take(numMergersNeeded - filteredMergersWithExecutors.size) + } else { + filteredMergersWithoutExecutors + } + filteredMergersWithExecutors.toSeq ++ randomFilteredMergersLocations + } + } + + private def removeShufflePushMergerLocation(host: String): Unit = { + if (shuffleMergerLocations.contains(host)) { + shuffleMergerLocations.remove(host) + } + } + /** * Returns an [[RpcEndpointRef]] of the [[BlockManagerReplicaEndpoint]] for sending RPC messages. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index bbc076cea9ba8..afe416a55ed0d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -141,4 +141,10 @@ private[spark] object BlockManagerMessages { case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster case class IsExecutorAlive(executorId: String) extends ToBlockManagerMaster + + case class GetShufflePushMergerLocations(numMergersNeeded: Int, hostsToFilter: Set[String]) + extends ToBlockManagerMaster + + case class RemoveShufflePushMergerLocation(host: String) extends ToBlockManagerMaster + } diff --git a/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala b/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala index a3a528cddee37..4af48d5b9125c 100644 --- a/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala @@ -136,12 +136,53 @@ private[spark] object HadoopFSUtils extends Logging { parallelismMax = 0) (path, leafFiles) }.iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) }.collect() } finally { sc.setJobDescription(previousJobDescription) } - statusMap.toSeq + // turn SerializableFileStatus back to Status + statusMap.map { case (path, serializableStatuses) => + val statuses = serializableStatuses.map { f => + val blockLocations = f.blockLocations.map { loc => + new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) + } + new LocatedFileStatus( + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, + new Path(f.path)), + blockLocations) + } + (new Path(path), statuses) + } } // scalastyle:off argcount @@ -291,4 +332,22 @@ private[spark] object HadoopFSUtils extends Logging { resolvedLeafStatuses } // scalastyle:on argcount + + /** A serializable variant of HDFS's BlockLocation. This is required by Hadoop 2.7. 
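// [Editor's note] Standalone sketch (not part of the patch) modelling the two pieces of the
// BlockManagerMasterEndpoint logic shown above: a bounded, insertion-ordered cache of merger
// hosts (oldest evicted first), and a selection that prefers hosts with active executors and
// tops up from the retained history in random order. All names here are illustrative.
object MergerSelectionSketch {
  import scala.collection.mutable
  import scala.util.Random

  // Bounded, insertion-ordered cache: the oldest host is dropped once the limit is reached.
  def retain(cache: mutable.LinkedHashMap[String, String], host: String, max: Int): Unit = {
    if (!cache.contains(host)) {
      if (cache.size >= max) cache -= cache.head._1
      cache(host) = host
    }
  }

  // Prefer hosts that currently run executors; fill any shortfall from retained hosts.
  def select(
      activeHosts: Set[String],
      retainedHosts: Seq[String],
      hostsToFilter: Set[String],
      numNeeded: Int): Seq[String] = {
    val active = (activeHosts -- hostsToFilter).toSeq
    if (active.size >= numNeeded) {
      active
    } else {
      val extra = retainedHosts
        .filterNot(hostsToFilter.contains)
        .filterNot(active.contains)
      active ++ Random.shuffle(extra).take(numNeeded - active.size)
    }
  }
}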
*/ + private case class SerializableBlockLocation( + names: Array[String], + hosts: Array[String], + offset: Long, + length: Long) + + /** A serializable variant of HDFS's FileStatus. This is required by Hadoop 2.7. */ + private case class SerializableFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long, + blockLocations: Array[SerializableBlockLocation]) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b743ab6507117..6ccf65b737c1a 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2541,6 +2541,14 @@ private[spark] object Utils extends Logging { master == "local" || master.startsWith("local[") } + /** + * Push based shuffle can only be enabled when external shuffle service is enabled. + */ + def isPushBasedShuffleEnabled(conf: SparkConf): Boolean = { + conf.get(PUSH_BASED_SHUFFLE_ENABLED) && + (conf.get(IS_TESTING).getOrElse(false) || conf.get(SHUFFLE_SERVICE_ENABLED)) + } + /** * Return whether dynamic allocation is enabled in the given conf. */ diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 55280fc578310..144489c5f7922 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -100,6 +100,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") .set(STORAGE_UNROLL_MEMORY_THRESHOLD, 512L) .set(Network.RPC_ASK_TIMEOUT, "5s") + .set(PUSH_BASED_SHUFFLE_ENABLED, true) } private def makeSortShuffleManager(): SortShuffleManager = { @@ -1974,6 +1975,48 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } + test("SPARK-32919: Shuffle push merger locations should be bounded with in" + + " spark.shuffle.push.retainedMergerLocations") { + assert(master.getShufflePushMergerLocations(10, Set.empty).isEmpty) + makeBlockManager(100, "execA", + transferService = Some(new MockBlockTransferService(10, "hostA"))) + makeBlockManager(100, "execB", + transferService = Some(new MockBlockTransferService(10, "hostB"))) + makeBlockManager(100, "execC", + transferService = Some(new MockBlockTransferService(10, "hostC"))) + makeBlockManager(100, "execD", + transferService = Some(new MockBlockTransferService(10, "hostD"))) + makeBlockManager(100, "execE", + transferService = Some(new MockBlockTransferService(10, "hostA"))) + assert(master.getShufflePushMergerLocations(10, Set.empty).size == 4) + assert(master.getShufflePushMergerLocations(10, Set.empty).map(_.host).sorted === + Seq("hostC", "hostD", "hostA", "hostB").sorted) + assert(master.getShufflePushMergerLocations(10, Set("hostB")).size == 3) + } + + test("SPARK-32919: Prefer active executor locations for shuffle push mergers") { + makeBlockManager(100, "execA", + transferService = Some(new MockBlockTransferService(10, "hostA"))) + makeBlockManager(100, "execB", + transferService = Some(new MockBlockTransferService(10, "hostB"))) + makeBlockManager(100, "execC", + transferService = Some(new MockBlockTransferService(10, "hostC"))) + makeBlockManager(100, "execD", + transferService = Some(new MockBlockTransferService(10, "hostD"))) + makeBlockManager(100, "execE", + transferService = 
Some(new MockBlockTransferService(10, "hostA"))) + assert(master.getShufflePushMergerLocations(5, Set.empty).size == 4) + + master.removeExecutor("execA") + master.removeExecutor("execE") + + assert(master.getShufflePushMergerLocations(3, Set.empty).size == 3) + assert(master.getShufflePushMergerLocations(3, Set.empty).map(_.host).sorted === + Seq("hostC", "hostB", "hostD").sorted) + assert(master.getShufflePushMergerLocations(4, Set.empty).map(_.host).sorted === + Seq("hostB", "hostA", "hostC", "hostD").sorted) + } + test("SPARK-33387 Support ordered shuffle block migration") { val blocks: Seq[ShuffleBlockInfo] = Seq( ShuffleBlockInfo(1, 0L), @@ -1995,7 +2038,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(sortedBlocks.sameElements(decomManager.shufflesToMigrate.asScala.map(_._1))) } - class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { + class MockBlockTransferService( + val maxFailures: Int, + override val hostName: String = "MockBlockTransferServiceHost") extends BlockTransferService { var numCalls = 0 var tempFileManager: DownloadFileManager = null @@ -2013,8 +2058,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE override def close(): Unit = {} - override def hostName: String = { "MockBlockTransferServiceHost" } - override def port: Int = { 63332 } override def uploadBlock( diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 20624c743bc22..8fb408041ca9d 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -41,6 +41,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.SparkListener import org.apache.spark.util.io.ChunkedByteBufferInputStream @@ -1432,6 +1433,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { }.getMessage assert(message.contains(expected)) } + + test("isPushBasedShuffleEnabled when both PUSH_BASED_SHUFFLE_ENABLED" + + " and SHUFFLE_SERVICE_ENABLED are true") { + val conf = new SparkConf() + assert(Utils.isPushBasedShuffleEnabled(conf) === false) + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, false) + assert(Utils.isPushBasedShuffleEnabled(conf) === false) + conf.set(SHUFFLE_SERVICE_ENABLED, true) + assert(Utils.isPushBasedShuffleEnabled(conf) === true) + } } private class SimpleExtension diff --git a/docs/css/main.css b/docs/css/main.css index 8168a46f9a437..8b279a157c2b6 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -162,6 +162,7 @@ body .container-wrapper { margin-right: auto; border-radius: 15px; position: relative; + min-height: 100vh; } .title { @@ -264,6 +265,7 @@ a:hover code { max-width: 914px; line-height: 1.6; /* Inspired by Github's wiki style */ padding-left: 30px; + min-height: 100vh; } .dropdown-menu { @@ -325,6 +327,7 @@ a.anchorjs-link:hover { text-decoration: none; } border-bottom-width: 0px; margin-top: 0px; width: 210px; + height: 80%; float: left; position: fixed; overflow-y: scroll; diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index fd7208615a09f..870ed0aa0daaa 100644 --- 
a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -135,6 +135,7 @@ The behavior of some SQL functions can be different under ANSI mode (`spark.sql. - `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. - `element_at`: This function throws `NoSuchElementException` if key does not exist in map. - `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. + - `parse_url`: This function throws `IllegalArgumentException` if an input string is not a valid url. ### SQL Operators diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala index 4b9acd0d39f3f..d086c8cdcc589 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -29,7 +29,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., ibmcom/db2:11.5.4.0): * {{{ * DB2_DOCKER_IMAGE_NAME=ibmcom/db2:11.5.4.0 - * ./build/sbt -Pdocker-integration-tests "testOnly *DB2IntegrationSuite" + * ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.DB2IntegrationSuite" * }}} */ @DockerTest diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index f1ffc8f0f3dc7..939a07238934b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -28,7 +28,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., 2019-GA-ubuntu-16.04): * {{{ * MSSQLSERVER_DOCKER_IMAGE_NAME=2019-GA-ubuntu-16.04 - * ./build/sbt -Pdocker-integration-tests "testOnly *MsSqlServerIntegrationSuite" + * ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.MsSqlServerIntegrationSuite" * }}} */ @DockerTest diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 6f96ab33d0fee..68f0dbc057c1f 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -28,7 +28,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., mysql:5.7.31): * {{{ * MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.31 - * ./build/sbt -Pdocker-integration-tests "testOnly *MySQLIntegrationSuite" + * ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.MySQLIntegrationSuite" * }}} */ @DockerTest diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index fa13100b5fdc8..0347c98bba2c4 100644 --- 
a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -30,7 +30,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., postgres:13.0): * {{{ * POSTGRES_DOCKER_IMAGE_NAME=postgres:13.0 - * ./build/sbt -Pdocker-integration-tests "testOnly *PostgresIntegrationSuite" + * ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite" * }}} */ @DockerTest diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index ad1010da5c104..03ebe0299f63f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -39,14 +39,16 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp * The imputation strategy. Currently only "mean" and "median" are supported. * If "mean", then replace missing values using the mean value of the feature. * If "median", then replace missing values using the approximate median value of the feature. + * If "mode", then replace missing using the most frequent value of the feature. * Default: mean * * @group param */ final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " + s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " + - s"If ${Imputer.median}, then replace missing values using the median value of the feature.", - ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median))) + s"If ${Imputer.median}, then replace missing values using the median value of the feature. " + + s"If ${Imputer.mode}, then replace missing values using the most frequent value of " + + s"the feature.", ParamValidators.inArray[String](Imputer.supportedStrategies)) /** @group getParam */ def getStrategy: String = $(strategy) @@ -104,7 +106,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp * For example, if the input column is IntegerType (1, 2, 4, null), * the output will be IntegerType (1, 2, 4, 2) after mean imputation. * - * Note that the mean/median value is computed after filtering out missing values. + * Note that the mean/median/mode value is computed after filtering out missing values. * All Null values in the input columns are treated as missing, and so are also imputed. For * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. */ @@ -132,7 +134,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) def setOutputCols(value: Array[String]): this.type = set(outputCols, value) /** - * Imputation strategy. Available options are ["mean", "median"]. + * Imputation strategy. Available options are ["mean", "median", "mode"]. * @group setParam */ @Since("2.2.0") @@ -151,39 +153,42 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) val spark = dataset.sparkSession val (inputColumns, _) = getInOutCols() - val cols = inputColumns.map { inputCol => when(col(inputCol).equalTo($(missingValue)), null) .when(col(inputCol).isNaN, null) .otherwise(col(inputCol)) - .cast("double") + .cast(DoubleType) .as(inputCol) } + val numCols = cols.length val results = $(strategy) match { case Imputer.mean => // Function avg will ignore null automatically. 
// For a column only containing null, avg will return null. val row = dataset.select(cols.map(avg): _*).head() - Array.range(0, inputColumns.length).map { i => - if (row.isNullAt(i)) { - Double.NaN - } else { - row.getDouble(i) - } - } + Array.tabulate(numCols)(i => if (row.isNullAt(i)) Double.NaN else row.getDouble(i)) case Imputer.median => // Function approxQuantile will ignore null automatically. // For a column only containing null, approxQuantile will return an empty array. dataset.select(cols: _*).stat.approxQuantile(inputColumns, Array(0.5), $(relativeError)) - .map { array => - if (array.isEmpty) { - Double.NaN - } else { - array.head - } - } + .map(_.headOption.getOrElse(Double.NaN)) + + case Imputer.mode => + import spark.implicits._ + // If there is more than one mode, choose the smallest one to keep in line + // with sklearn.impute.SimpleImputer (using scipy.stats.mode). + val modes = dataset.select(cols: _*).flatMap { row => + // Ignore null. + Iterator.range(0, numCols) + .flatMap(i => if (row.isNullAt(i)) None else Some((i, row.getDouble(i)))) + }.toDF("index", "value") + .groupBy("index", "value").agg(negate(count(lit(0))).as("negative_count")) + .groupBy("index").agg(min(struct("negative_count", "value")).as("mode")) + .select("index", "mode.value") + .as[(Int, Double)].collect().toMap + Array.tabulate(numCols)(i => modes.getOrElse(i, Double.NaN)) } val emptyCols = inputColumns.zip(results).filter(_._2.isNaN).map(_._1) @@ -212,6 +217,10 @@ object Imputer extends DefaultParamsReadable[Imputer] { /** strategy names that Imputer currently supports. */ private[feature] val mean = "mean" private[feature] val median = "median" + private[feature] val mode = "mode" + + /* Set of strategies that Imputer supports */ + private[feature] val supportedStrategies = Array(mean, median, mode) @Since("2.2.0") override def load(path: String): Imputer = super.load(path) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index dfee2b4029c8b..30887f55638f9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -28,13 +28,14 @@ import org.apache.spark.sql.types._ class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Imputer for Double with default missing Value NaN") { - val df = spark.createDataFrame( Seq( - (0, 1.0, 4.0, 1.0, 1.0, 4.0, 4.0), - (1, 11.0, 12.0, 11.0, 11.0, 12.0, 12.0), - (2, 3.0, Double.NaN, 3.0, 3.0, 10.0, 12.0), - (3, Double.NaN, 14.0, 5.0, 3.0, 14.0, 14.0) - )).toDF("id", "value1", "value2", "expected_mean_value1", "expected_median_value1", - "expected_mean_value2", "expected_median_value2") + val df = spark.createDataFrame(Seq( + (0, 1.0, 4.0, 1.0, 1.0, 1.0, 4.0, 4.0, 4.0), + (1, 11.0, 12.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0), + (2, 3.0, Double.NaN, 3.0, 3.0, 3.0, 10.0, 12.0, 4.0), + (3, Double.NaN, 14.0, 5.0, 3.0, 1.0, 14.0, 14.0, 14.0) + )).toDF("id", "value1", "value2", + "expected_mean_value1", "expected_median_value1", "expected_mode_value1", + "expected_mean_value2", "expected_median_value2", "expected_mode_value2") val imputer = new Imputer() .setInputCols(Array("value1", "value2")) .setOutputCols(Array("out1", "out2")) @@ -42,23 +43,25 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column: Imputer for Double with default missing Value NaN") { - val df1 = spark.createDataFrame( Seq( - (0, 1.0, 1.0, 1.0), - (1, 11.0, 
11.0, 11.0), - (2, 3.0, 3.0, 3.0), - (3, Double.NaN, 5.0, 3.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df1 = spark.createDataFrame(Seq( + (0, 1.0, 1.0, 1.0, 1.0), + (1, 11.0, 11.0, 11.0, 11.0), + (2, 3.0, 3.0, 3.0, 3.0), + (3, Double.NaN, 5.0, 3.0, 1.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer1 = new Imputer() .setInputCol("value") .setOutputCol("out") ImputerSuite.iterateStrategyTest(false, imputer1, df1) - val df2 = spark.createDataFrame( Seq( - (0, 4.0, 4.0, 4.0), - (1, 12.0, 12.0, 12.0), - (2, Double.NaN, 10.0, 12.0), - (3, 14.0, 14.0, 14.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df2 = spark.createDataFrame(Seq( + (0, 4.0, 4.0, 4.0, 4.0), + (1, 12.0, 12.0, 12.0, 12.0), + (2, Double.NaN, 10.0, 12.0, 4.0), + (3, 14.0, 14.0, 14.0, 14.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer2 = new Imputer() .setInputCol("value") .setOutputCol("out") @@ -66,12 +69,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") { - val df = spark.createDataFrame( Seq( - (0, 1.0, 1.0, 1.0), - (1, 3.0, 3.0, 3.0), - (2, Double.NaN, Double.NaN, Double.NaN), - (3, -1.0, 2.0, 1.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0, 1.0, 1.0, 1.0), + (1, 3.0, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN, Double.NaN), + (3, -1.0, 2.0, 1.0, 1.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1.0) ImputerSuite.iterateStrategyTest(true, imputer, df) @@ -79,64 +83,69 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Single Column: Imputer should handle NaNs when computing surrogate value," + " if missingValue is not NaN") { - val df = spark.createDataFrame( Seq( - (0, 1.0, 1.0, 1.0), - (1, 3.0, 3.0, 3.0), - (2, Double.NaN, Double.NaN, Double.NaN), - (3, -1.0, 2.0, 1.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0, 1.0, 1.0, 1.0), + (1, 3.0, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN, Double.NaN), + (3, -1.0, 2.0, 1.0, 1.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCol("value").setOutputCol("out") .setMissingValue(-1.0) ImputerSuite.iterateStrategyTest(false, imputer, df) } test("Imputer for Float with missing Value -1.0") { - val df = spark.createDataFrame( Seq( - (0, 1.0F, 1.0F, 1.0F), - (1, 3.0F, 3.0F, 3.0F), - (2, 10.0F, 10.0F, 10.0F), - (3, 10.0F, 10.0F, 10.0F), - (4, -1.0F, 6.0F, 3.0F) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0F, 1.0F, 1.0F, 1.0F), + (1, 3.0F, 3.0F, 3.0F, 3.0F), + (2, 10.0F, 10.0F, 10.0F, 10.0F), + (3, 10.0F, 10.0F, 10.0F, 10.0F), + (4, -1.0F, 6.0F, 3.0F, 10.0F) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1) ImputerSuite.iterateStrategyTest(true, imputer, df) } test("Single Column: Imputer for Float with missing Value 
-1.0") { - val df = spark.createDataFrame( Seq( - (0, 1.0F, 1.0F, 1.0F), - (1, 3.0F, 3.0F, 3.0F), - (2, 10.0F, 10.0F, 10.0F), - (3, 10.0F, 10.0F, 10.0F), - (4, -1.0F, 6.0F, 3.0F) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0F, 1.0F, 1.0F, 1.0F), + (1, 3.0F, 3.0F, 3.0F, 3.0F), + (2, 10.0F, 10.0F, 10.0F, 10.0F), + (3, 10.0F, 10.0F, 10.0F, 10.0F), + (4, -1.0F, 6.0F, 3.0F, 10.0F) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCol("value").setOutputCol("out") .setMissingValue(-1) ImputerSuite.iterateStrategyTest(false, imputer, df) } test("Imputer should impute null as well as 'missingValue'") { - val rawDf = spark.createDataFrame( Seq( - (0, 4.0, 4.0, 4.0), - (1, 10.0, 10.0, 10.0), - (2, 10.0, 10.0, 10.0), - (3, Double.NaN, 8.0, 10.0), - (4, -1.0, 8.0, 10.0) - )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value") + val rawDf = spark.createDataFrame(Seq( + (0, 4.0, 4.0, 4.0, 4.0), + (1, 10.0, 10.0, 10.0, 10.0), + (2, 10.0, 10.0, 10.0, 10.0), + (3, Double.NaN, 8.0, 10.0, 10.0), + (4, -1.0, 8.0, 10.0, 10.0) + )).toDF("id", "rawValue", + "expected_mean_value", "expected_median_value", "expected_mode_value") val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) ImputerSuite.iterateStrategyTest(true, imputer, df) } test("Single Column: Imputer should impute null as well as 'missingValue'") { - val rawDf = spark.createDataFrame( Seq( - (0, 4.0, 4.0, 4.0), - (1, 10.0, 10.0, 10.0), - (2, 10.0, 10.0, 10.0), - (3, Double.NaN, 8.0, 10.0), - (4, -1.0, 8.0, 10.0) - )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value") + val rawDf = spark.createDataFrame(Seq( + (0, 4.0, 4.0, 4.0, 4.0), + (1, 10.0, 10.0, 10.0, 10.0), + (2, 10.0, 10.0, 10.0, 10.0), + (3, Double.NaN, 8.0, 10.0, 10.0), + (4, -1.0, 8.0, 10.0, 10.0) + )).toDF("id", "rawValue", + "expected_mean_value", "expected_median_value", "expected_mode_value") val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value") val imputer = new Imputer().setInputCol("value").setOutputCol("out") ImputerSuite.iterateStrategyTest(false, imputer, df) @@ -187,7 +196,7 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer throws exception when surrogate cannot be computed") { - val df = spark.createDataFrame( Seq( + val df = spark.createDataFrame(Seq( (0, Double.NaN, 1.0, 1.0), (1, Double.NaN, 3.0, 3.0), (2, Double.NaN, Double.NaN, Double.NaN) @@ -205,12 +214,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column: Imputer throws exception when surrogate cannot be computed") { - val df = spark.createDataFrame( Seq( - (0, Double.NaN, 1.0, 1.0), - (1, Double.NaN, 3.0, 3.0), - (2, Double.NaN, Double.NaN, Double.NaN) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") - Seq("mean", "median").foreach { strategy => + val df = spark.createDataFrame(Seq( + (0, Double.NaN, 1.0, 1.0, 1.0), + (1, Double.NaN, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN, Double.NaN) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") + Seq("mean", "median", "mode").foreach { strategy => val imputer = new Imputer().setInputCol("value").setOutputCol("out") .setStrategy(strategy) withClue("Imputer should fail all the values are invalid") { @@ 
-223,12 +233,12 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer input & output column validation") { - val df = spark.createDataFrame( Seq( + val df = spark.createDataFrame(Seq( (0, 1.0, 1.0, 1.0), (1, Double.NaN, 3.0, 3.0), (2, Double.NaN, Double.NaN, Double.NaN) )).toDF("id", "value1", "value2", "value3") - Seq("mean", "median").foreach { strategy => + Seq("mean", "median", "mode").foreach { strategy => withClue("Imputer should fail if inputCols and outputCols are different length") { val e: IllegalArgumentException = intercept[IllegalArgumentException] { val imputer = new Imputer().setStrategy(strategy) @@ -306,13 +316,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer for IntegerType with default missing value null") { - - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (null, 5, 3) - )).toDF("value1", "expected_mean_value1", "expected_median_value1") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (null, 5, 3, 1) + )).toDF("value1", + "expected_mean_value1", "expected_median_value1", "expected_mode_value1") val imputer = new Imputer() .setInputCols(Array("value1")) @@ -327,12 +337,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column Imputer for IntegerType with default missing value null") { - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (null, 5, 3) - )).toDF("value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (null, 5, 3, 1) + )).toDF("value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer() .setInputCol("value") @@ -347,13 +358,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer for IntegerType with missing value -1") { - - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (-1, 5, 3) - )).toDF("value1", "expected_mean_value1", "expected_median_value1") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (-1, 5, 3, 1) + )).toDF("value1", + "expected_mean_value1", "expected_median_value1", "expected_mode_value1") val imputer = new Imputer() .setInputCols(Array("value1")) @@ -369,12 +380,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column: Imputer for IntegerType with missing value -1") { - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (-1, 5, 3) - )).toDF("value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (-1, 5, 3, 1) + )).toDF("value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer() .setInputCol("value") @@ -402,13 +414,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Compare single/multiple column(s) Imputer in pipeline") { - val df = spark.createDataFrame( Seq( + val df = spark.createDataFrame(Seq( (0, 1.0, 4.0), (1, 11.0, 12.0), (2, 3.0, Double.NaN), (3, Double.NaN, 14.0) )).toDF("id", "value1", "value2") - 
Seq("mean", "median").foreach { strategy => + Seq("mean", "median", "mode").foreach { strategy => val multiColsImputer = new Imputer() .setInputCols(Array("value1", "value2")) .setOutputCols(Array("result1", "result2")) @@ -450,11 +462,12 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { object ImputerSuite { /** - * Imputation strategy. Available options are ["mean", "median"]. - * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" + * Imputation strategy. Available options are ["mean", "median", "mode"]. + * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median", + * "expected_mode". */ def iterateStrategyTest(isMultiCol: Boolean, imputer: Imputer, df: DataFrame): Unit = { - Seq("mean", "median").foreach { strategy => + Seq("mean", "median", "mode").foreach { strategy => imputer.setStrategy(strategy) val model = imputer.fit(df) val resultDF = model.transform(df) diff --git a/pom.xml b/pom.xml index 3ae2e7420e154..85cf5a00b0b24 100644 --- a/pom.xml +++ b/pom.xml @@ -164,7 +164,6 @@ 3.2.2 2.12.10 2.12 - -Ywarn-unused-import 2.0.0 --test @@ -2538,7 +2537,6 @@ -deprecation -feature -explaintypes - ${scalac.arg.unused-imports} -target:jvm-1.8 @@ -3262,13 +3260,12 @@ - + scala-2.13 2.13.3 2.13 - -Wconf:cat=unused-imports:e diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 55c87fcb3aaa2..05413b7091ad9 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -221,6 +221,7 @@ object SparkBuild extends PomBuild { Seq( "-Xfatal-warnings", "-deprecation", + "-Ywarn-unused-import", "-P:silencer:globalFilters=.*deprecated.*" //regex to catch deprecation warnings and supress them ) } else { @@ -230,6 +231,8 @@ object SparkBuild extends PomBuild { // see `scalac -Wconf:help` for details "-Wconf:cat=deprecation:wv,any:e", // 2.13-specific warning hits to be muted (as narrowly as possible) and addressed separately + // TODO(SPARK-33499): Enable this option when Scala 2.12 is no longer supported. + // "-Wunused:imports", "-Wconf:cat=lint-multiarg-infix:wv", "-Wconf:cat=other-nullary-override:wv", "-Wconf:cat=other-match-analysis&site=org.apache.spark.sql.catalyst.catalog.SessionCatalog.lookupFunction.catalogFunction:wv", diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 4d898bd5fffa8..82b9a6db1eb92 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1507,7 +1507,8 @@ class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, Has strategy = Param(Params._dummy(), "strategy", "strategy for imputation. If mean, then replace missing values using the mean " "value of the feature. If median, then replace missing values using the " - "median value of the feature.", + "median value of the feature. If mode, then replace missing using the most " + "frequent value of the feature.", typeConverter=TypeConverters.toString) missingValue = Param(Params._dummy(), "missingValue", @@ -1541,7 +1542,7 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable): numeric type. Currently Imputer does not support categorical features and possibly creates incorrect values for a categorical feature. - Note that the mean/median value is computed after filtering out missing values. + Note that the mean/median/mode value is computed after filtering out missing values. All Null values in the input columns are treated as missing, and so are also imputed. 
For computing median, :py:meth:`pyspark.sql.DataFrame.approxQuantile` is used with a relative error of `0.001`. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index b42bdb9816600..22002bb32004d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler.cluster import java.util.EnumSet -import java.util.concurrent.atomic.{AtomicBoolean} +import java.util.concurrent.atomic.AtomicBoolean import javax.servlet.DispatcherType import scala.concurrent.{ExecutionContext, Future} @@ -29,14 +29,14 @@ import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.SparkContext import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config +import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config.UI._ import org.apache.spark.resource.ResourceProfile import org.apache.spark.rpc._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{RpcUtils, ThreadUtils} +import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster} +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** * Abstract Yarn scheduler backend that contains common logic @@ -80,6 +80,18 @@ private[spark] abstract class YarnSchedulerBackend( /** Attempt ID. This is unset for client-mode schedulers */ private var attemptId: Option[ApplicationAttemptId] = None + private val blockManagerMaster: BlockManagerMaster = sc.env.blockManager.master + + private val minMergersThresholdRatio = + conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO) + + private val minMergersStaticThreshold = + conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD) + + private val maxNumExecutors = conf.get(config.DYN_ALLOCATION_MAX_EXECUTORS) + + private val numExecutors = conf.get(config.EXECUTOR_INSTANCES).getOrElse(0) + /** * Bind to YARN. This *must* be done before calling [[start()]]. * @@ -161,6 +173,36 @@ private[spark] abstract class YarnSchedulerBackend( totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } + override def getShufflePushMergerLocations( + numPartitions: Int, + resourceProfileId: Int): Seq[BlockManagerId] = { + // TODO (SPARK-33481) This is a naive way of calculating numMergersDesired for a stage, + // TODO we can use better heuristics to calculate numMergersDesired for a stage. + val maxExecutors = if (Utils.isDynamicAllocationEnabled(sc.getConf)) { + maxNumExecutors + } else { + numExecutors + } + val tasksPerExecutor = sc.resourceProfileManager + .resourceProfileFromId(resourceProfileId).maxTasksPerExecutor(sc.conf) + val numMergersDesired = math.min( + math.max(1, math.ceil(numPartitions / tasksPerExecutor).toInt), maxExecutors) + val minMergersNeeded = math.max(minMergersStaticThreshold, + math.floor(numMergersDesired * minMergersThresholdRatio).toInt) + + // Request for numMergersDesired shuffle mergers to BlockManagerMasterEndpoint + // and if it's less than minMergersNeeded, we disable push based shuffle. 
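// [Editor's note] Worked example with illustrative numbers (not taken from this patch):
// for a child stage with numPartitions = 1000, tasksPerExecutor = 4 and maxExecutors = 200,
// numMergersDesired = min(max(1, ceil(1000 / 4)), 200) = 200. With the defaults
// mergersMinThresholdRatio = 0.05 and mergersMinStaticThreshold = 5,
// minMergersNeeded = max(5, floor(200 * 0.05)) = 10. The call below asks the block manager
// master for the desired number of mergers (active-executor hosts first, then previously
// seen hosts); if it returns fewer than 10 (and fewer than the 200 desired), push-based
// shuffle is disabled for this stage.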
+ val mergerLocations = blockManagerMaster + .getShufflePushMergerLocations(numMergersDesired, scheduler.excludedNodes()) + if (mergerLocations.size < numMergersDesired && mergerLocations.size < minMergersNeeded) { + Seq.empty[BlockManagerId] + } else { + logDebug(s"The number of shuffle mergers desired ${numMergersDesired}" + + s" and available locations are ${mergerLocations.length}") + mergerLocations + } + } + /** * Add filters to the SparkUI. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala new file mode 100644 index 0000000000000..c680502cb328f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.errors + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Expression, GroupingID} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.connector.catalog.TableChange +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{AbstractDataType, DataType, StructType} + +/** + * Object for grouping all error messages of the query compilation. + * Currently it includes all AnalysisExcpetions created and thrown directly in + * org.apache.spark.sql.catalyst.analysis.Analyzer. + */ +object QueryCompilationErrors { + def groupingIDMismatchError(groupingID: GroupingID, groupByExprs: Seq[Expression]): Throwable = { + new AnalysisException( + s"Columns of grouping_id (${groupingID.groupByExprs.mkString(",")}) " + + s"does not match grouping columns (${groupByExprs.mkString(",")})") + } + + def groupingColInvalidError(groupingCol: Expression, groupByExprs: Seq[Expression]): Throwable = { + new AnalysisException( + s"Column of grouping ($groupingCol) can't be found " + + s"in grouping columns ${groupByExprs.mkString(",")}") + } + + def groupingSizeTooLargeError(sizeLimit: Int): Throwable = { + new AnalysisException( + s"Grouping sets size cannot be greater than $sizeLimit") + } + + def unorderablePivotColError(pivotCol: Expression): Throwable = { + new AnalysisException( + s"Invalid pivot column '$pivotCol'. Pivot columns must be comparable." 
+ ) + } + + def nonLiteralPivotValError(pivotVal: Expression): Throwable = { + new AnalysisException( + s"Literal expressions required for pivot values, found '$pivotVal'") + } + + def pivotValDataTypeMismatchError(pivotVal: Expression, pivotCol: Expression): Throwable = { + new AnalysisException( + s"Invalid pivot value '$pivotVal': " + + s"value data type ${pivotVal.dataType.simpleString} does not match " + + s"pivot column data type ${pivotCol.dataType.catalogString}") + } + + def unsupportedIfNotExistsError(tableName: String): Throwable = { + new AnalysisException( + s"Cannot write, IF NOT EXISTS is not supported for table: $tableName") + } + + def nonPartitionColError(partitionName: String): Throwable = { + new AnalysisException( + s"PARTITION clause cannot contain a non-partition column name: $partitionName") + } + + def addStaticValToUnknownColError(staticName: String): Throwable = { + new AnalysisException( + s"Cannot add static value for unknown column: $staticName") + } + + def unknownStaticPartitionColError(name: String): Throwable = { + new AnalysisException(s"Unknown static partition column: $name") + } + + def nestedGeneratorError(trimmedNestedGenerator: Expression): Throwable = { + new AnalysisException( + "Generators are not supported when it's nested in " + + "expressions, but got: " + toPrettySQL(trimmedNestedGenerator)) + } + + def moreThanOneGeneratorError(generators: Seq[Expression], clause: String): Throwable = { + new AnalysisException( + s"Only one generator allowed per $clause clause but found " + + generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) + } + + def generatorOutsideSelectError(plan: LogicalPlan): Throwable = { + new AnalysisException( + "Generators are not supported outside the SELECT clause, but " + + "got: " + plan.simpleString(SQLConf.get.maxToStringFields)) + } + + def legacyStoreAssignmentPolicyError(): Throwable = { + val configKey = SQLConf.STORE_ASSIGNMENT_POLICY.key + new AnalysisException( + "LEGACY store assignment policy is disallowed in Spark data source V2. " + + s"Please set the configuration $configKey to other values.") + } + + def unresolvedUsingColForJoinError( + colName: String, plan: LogicalPlan, side: String): Throwable = { + new AnalysisException( + s"USING column `$colName` cannot be resolved on the $side " + + s"side of the join. 
The $side-side columns: [${plan.output.map(_.name).mkString(", ")}]") + } + + def dataTypeMismatchForDeserializerError( + dataType: DataType, desiredType: String): Throwable = { + val quantifier = if (desiredType.equals("array")) "an" else "a" + new AnalysisException( + s"need $quantifier $desiredType field but got " + dataType.catalogString) + } + + def fieldNumberMismatchForDeserializerError( + schema: StructType, maxOrdinal: Int): Throwable = { + new AnalysisException( + s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does not line up.") + } + + def upCastFailureError( + fromStr: String, from: Expression, to: DataType, walkedTypePath: Seq[String]): Throwable = { + new AnalysisException( + s"Cannot up cast $fromStr from " + + s"${from.dataType.catalogString} to ${to.catalogString}.\n" + + s"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object") + } + + def unsupportedAbstractDataTypeForUpCastError(gotType: AbstractDataType): Throwable = { + new AnalysisException( + s"UpCast only support DecimalType as AbstractDataType yet, but got: $gotType") + } + + def outerScopeFailureForNewInstanceError(className: String): Throwable = { + new AnalysisException( + s"Unable to generate an encoder for inner class `$className` without " + + "access to the scope that this class was defined in.\n" + + "Try moving this class out of its parent class.") + } + + def referenceColNotFoundForAlterTableChangesError( + after: TableChange.After, parentName: String): Throwable = { + new AnalysisException( + s"Couldn't find the reference column for $after at $parentName") + } + +} + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8d95d8cf49d45..53c0ff687c6d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy} @@ -448,9 +449,7 @@ class Analyzer(override val catalogManager: CatalogManager) e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) { Alias(gid, toPrettySQL(e))() } else { - throw new AnalysisException( - s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + - s"grouping columns (${groupByExprs.mkString(",")})") + throw QueryCompilationErrors.groupingIDMismatchError(e, groupByExprs) } case e @ Grouping(col: Expression) => val idx = groupByExprs.indexWhere(_.semanticEquals(col)) @@ -458,8 +457,7 @@ class Analyzer(override val catalogManager: CatalogManager) 
Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), Literal(1L)), ByteType), toPrettySQL(e))() } else { - throw new AnalysisException(s"Column of grouping ($col) can't be found " + - s"in grouping columns ${groupByExprs.mkString(",")}") + throw QueryCompilationErrors.groupingColInvalidError(col, groupByExprs) } } } @@ -575,8 +573,7 @@ class Analyzer(override val catalogManager: CatalogManager) val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs) if (finalGroupByExpressions.size > GroupingID.dataType.defaultSize * 8) { - throw new AnalysisException( - s"Grouping sets size cannot be greater than ${GroupingID.dataType.defaultSize * 8}") + throw QueryCompilationErrors.groupingSizeTooLargeError(GroupingID.dataType.defaultSize * 8) } // Expand works by setting grouping expressions to null as determined by the @@ -712,8 +709,7 @@ class Analyzer(override val catalogManager: CatalogManager) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => if (!RowOrdering.isOrderable(pivotColumn.dataType)) { - throw new AnalysisException( - s"Invalid pivot column '${pivotColumn}'. Pivot columns must be comparable.") + throw QueryCompilationErrors.unorderablePivotColError(pivotColumn) } // Check all aggregate expressions. aggregates.foreach(checkValidAggregateExpression) @@ -724,13 +720,10 @@ class Analyzer(override val catalogManager: CatalogManager) case _ => value.foldable } if (!foldable) { - throw new AnalysisException( - s"Literal expressions required for pivot values, found '$value'") + throw QueryCompilationErrors.nonLiteralPivotValError(value) } if (!Cast.canCast(value.dataType, pivotColumn.dataType)) { - throw new AnalysisException(s"Invalid pivot value '$value': " + - s"value data type ${value.dataType.simpleString} does not match " + - s"pivot column data type ${pivotColumn.dataType.catalogString}") + throw QueryCompilationErrors.pivotValDataTypeMismatchError(value, pivotColumn) } Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) } @@ -1167,8 +1160,7 @@ class Analyzer(override val catalogManager: CatalogManager) case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) if i.query.resolved => // ifPartitionNotExists is append with validation, but validation is not supported if (i.ifPartitionNotExists) { - throw new AnalysisException( - s"Cannot write, IF NOT EXISTS is not supported for table: ${r.table.name}") + throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name) } val partCols = partitionColumnNames(r.table) @@ -1205,8 +1197,7 @@ class Analyzer(override val catalogManager: CatalogManager) partitionColumnNames.find(name => conf.resolver(name, partitionName)) match { case Some(_) => case None => - throw new AnalysisException( - s"PARTITION clause cannot contain a non-partition column name: $partitionName") + throw QueryCompilationErrors.nonPartitionColError(partitionName) } } } @@ -1228,8 +1219,7 @@ class Analyzer(override val catalogManager: CatalogManager) case Some(attr) => attr.name -> staticName case _ => - throw new AnalysisException( - s"Cannot add static value for unknown column: $staticName") + throw QueryCompilationErrors.addStaticValToUnknownColError(staticName) }).toMap val queryColumns = query.output.iterator @@ -1271,7 +1261,7 @@ class Analyzer(override val catalogManager: CatalogManager) // an UnresolvedAttribute. 
EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType)) case None => - throw new AnalysisException(s"Unknown static partition column: $name") + throw QueryCompilationErrors.unknownStaticPartitionColError(name) } }.reduce(And) } @@ -2483,23 +2473,19 @@ class Analyzer(override val catalogManager: CatalogManager) def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get - throw new AnalysisException("Generators are not supported when it's nested in " + - "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator))) + throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) case Project(projectList, _) if projectList.count(hasGenerator) > 1 => val generators = projectList.filter(hasGenerator).map(trimAlias) - throw new AnalysisException("Only one generator allowed per select clause but found " + - generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) + throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "select") case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) => val nestedGenerator = aggList.find(hasNestedGenerator).get - throw new AnalysisException("Generators are not supported when it's nested in " + - "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator))) + throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 => val generators = aggList.filter(hasGenerator).map(trimAlias) - throw new AnalysisException("Only one generator allowed per aggregate clause but found " + - generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) + throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "aggregate") case agg @ Aggregate(groupList, aggList, child) if aggList.forall { case AliasedGenerator(_, _, _) => true @@ -2582,8 +2568,7 @@ class Analyzer(override val catalogManager: CatalogManager) case g: Generate => g case p if p.expressions.exists(hasGenerator) => - throw new AnalysisException("Generators are not supported outside the SELECT clause, but " + - "got: " + p.simpleString(SQLConf.get.maxToStringFields)) + throw QueryCompilationErrors.generatorOutsideSelectError(p) } } @@ -3122,10 +3107,7 @@ class Analyzer(override val catalogManager: CatalogManager) private def validateStoreAssignmentPolicy(): Unit = { // SPARK-28730: LEGACY store assignment policy is disallowed in data source v2. if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) { - val configKey = SQLConf.STORE_ASSIGNMENT_POLICY.key - throw new AnalysisException(s""" - |"LEGACY" store assignment policy is disallowed in Spark data source V2. - |Please set the configuration $configKey to other values.""".stripMargin) + throw QueryCompilationErrors.legacyStoreAssignmentPolicyError() } } @@ -3138,14 +3120,12 @@ class Analyzer(override val catalogManager: CatalogManager) hint: JoinHint) = { val leftKeys = joinNames.map { keyName => left.output.find(attr => resolver(attr.name, keyName)).getOrElse { - throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the left " + - s"side of the join. 
The left-side columns: [${left.output.map(_.name).mkString(", ")}]") + throw QueryCompilationErrors.unresolvedUsingColForJoinError(keyName, left, "left") } } val rightKeys = joinNames.map { keyName => right.output.find(attr => resolver(attr.name, keyName)).getOrElse { - throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the right " + - s"side of the join. The right-side columns: [${right.output.map(_.name).mkString(", ")}]") + throw QueryCompilationErrors.unresolvedUsingColForJoinError(keyName, right, "right") } } val joinPairs = leftKeys.zip(rightKeys) @@ -3208,7 +3188,8 @@ class Analyzer(override val catalogManager: CatalogManager) ExtractValue(child, fieldName, resolver) } case other => - throw new AnalysisException("need an array field but got " + other.catalogString) + throw QueryCompilationErrors.dataTypeMismatchForDeserializerError(other, + "array") } case u: UnresolvedCatalystToExternalMap if u.child.resolved => u.child.dataType match { @@ -3218,7 +3199,7 @@ class Analyzer(override val catalogManager: CatalogManager) ExtractValue(child, fieldName, resolver) } case other => - throw new AnalysisException("need a map field but got " + other.catalogString) + throw QueryCompilationErrors.dataTypeMismatchForDeserializerError(other, "map") } } validateNestedTupleFields(result) @@ -3227,8 +3208,7 @@ class Analyzer(override val catalogManager: CatalogManager) } private def fail(schema: StructType, maxOrdinal: Int): Unit = { - throw new AnalysisException(s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}" + - ", but failed as the number of fields does not line up.") + throw QueryCompilationErrors.fieldNumberMismatchForDeserializerError(schema, maxOrdinal) } /** @@ -3287,10 +3267,7 @@ class Analyzer(override val catalogManager: CatalogManager) case n: NewInstance if n.childrenResolved && !n.resolved => val outer = OuterScopes.getOuterScope(n.cls) if (outer == null) { - throw new AnalysisException( - s"Unable to generate an encoder for inner class `${n.cls.getName}` without " + - "access to the scope that this class was defined in.\n" + - "Try moving this class out of its parent class.") + throw QueryCompilationErrors.outerScopeFailureForNewInstanceError(n.cls.getName) } n.copy(outerPointer = Some(outer)) } @@ -3306,11 +3283,7 @@ class Analyzer(override val catalogManager: CatalogManager) case l: LambdaVariable => "array element" case e => e.sql } - throw new AnalysisException(s"Cannot up cast $fromStr from " + - s"${from.dataType.catalogString} to ${to.catalogString}.\n" + - "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + - "You can either add an explicit cast to the input data or choose a higher precision " + - "type of the field in the target object") + throw QueryCompilationErrors.upCastFailureError(fromStr, from, to, walkedTypePath) } def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { @@ -3321,8 +3294,7 @@ class Analyzer(override val catalogManager: CatalogManager) case u @ UpCast(child, _, _) if !child.resolved => u case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] => - throw new AnalysisException( - s"UpCast only support DecimalType as AbstractDataType yet, but got: $target") + throw QueryCompilationErrors.unsupportedAbstractDataTypeForUpCastError(target) case UpCast(child, target, walkedTypePath) if target == DecimalType && child.dataType.isInstanceOf[DecimalType] => @@ -3501,8 +3473,8 @@ class Analyzer(override val catalogManager: CatalogManager) case 
Some(colName) => ColumnPosition.after(colName) case None => - throw new AnalysisException("Couldn't find the reference column for " + - s"$after at $parentName") + throw QueryCompilationErrors.referenceColNotFoundForAlterTableChangesError(after, + parentName) } case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 16e22940495f1..9f92181b34df1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1357,8 +1357,9 @@ object ParseUrl { 1 """, since = "2.0.0") -case class ParseUrl(children: Seq[Expression]) +case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.get.ansiEnabled) extends Expression with ExpectsInputTypes with CodegenFallback { + def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType) @@ -1404,7 +1405,9 @@ case class ParseUrl(children: Seq[Expression]) try { new URI(url.toString) } catch { - case e: URISyntaxException => null + case e: URISyntaxException if failOnError => + throw new IllegalArgumentException(s"Find an invaild url string ${url.toString}", e) + case _: URISyntaxException => null } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index a1b6cec24f23f..730574a4b9846 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -943,6 +943,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { GenerateUnsafeProjection.generate(ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))) :: Nil) } + test("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val msg = intercept[IllegalArgumentException] { + evaluateWithoutCodegen( + ParseUrl(Seq("https://a.b.c/index.php?params1=a|b¶ms2=x", "HOST"))) + }.getMessage + assert(msg.contains("Find an invaild url string")) + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation( + ParseUrl(Seq("https://a.b.c/index.php?params1=a|b¶ms2=x", "HOST")), null) + } + } + test("Sentences") { val nullString = Literal.create(null, StringType) checkEvaluation(Sentences(nullString, nullString, nullString), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala index 1c96bdf3afa20..23987e909aa70 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala @@ -92,4 +92,8 @@ class InMemoryPartitionTable( override def partitionExists(ident: InternalRow): Boolean = memoryTablePartitions.containsKey(ident) + + override protected def addPartitionKey(key: Seq[Any]): Unit = { + memoryTablePartitions.put(InternalRow.fromSeq(key), Map.empty[String, 
String].asJava) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 3b47271a114e2..c93053abc550a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -160,12 +160,15 @@ class InMemoryTable( } } + protected def addPartitionKey(key: Seq[Any]): Unit = {} + def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { data.foreach(_.rows.foreach { row => val key = getKey(row) dataMap += dataMap.get(key) .map(key -> _.withRow(row)) .getOrElse(key -> new BufferedRows(key.toArray.mkString("/")).withRow(row)) + addPartitionKey(key) }) this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 21abfc2816ee4..e5c29312b80e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -147,6 +147,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat catalog match { case staging: StagingTableCatalog => AtomicReplaceTableAsSelectExec( + session, staging, ident, parts, @@ -157,6 +158,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat orCreate = orCreate) :: Nil case _ => ReplaceTableAsSelectExec( + session, catalog, ident, parts, @@ -170,9 +172,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case AppendData(r: DataSourceV2Relation, query, writeOptions, _) => r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil + AppendDataExecV1(v1, writeOptions.asOptions, query, r) :: Nil case v2 => - AppendDataExec(v2, writeOptions.asOptions, planLater(query)) :: Nil + AppendDataExec(session, v2, r, writeOptions.asOptions, planLater(query)) :: Nil } case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) => @@ -184,14 +186,15 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat }.toArray r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query) :: Nil + OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query, r) :: Nil case v2 => - OverwriteByExpressionExec(v2, filters, writeOptions.asOptions, planLater(query)) :: Nil + OverwriteByExpressionExec(session, v2, r, filters, + writeOptions.asOptions, planLater(query)) :: Nil } case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) => OverwritePartitionsDynamicExec( - r.table.asWritable, writeOptions.asOptions, planLater(query)) :: Nil + session, r.table.asWritable, r, writeOptions.asOptions, planLater(query)) :: Nil case DeleteFromTable(relation, condition) => relation match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala index 560da39314b36..af7721588edeb 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala @@ -37,10 +37,11 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class AppendDataExecV1( table: SupportsWrite, writeOptions: CaseInsensitiveStringMap, - plan: LogicalPlan) extends V1FallbackWriters { + plan: LogicalPlan, + v2Relation: DataSourceV2Relation) extends V1FallbackWriters { override protected def run(): Seq[InternalRow] = { - writeWithV1(newWriteBuilder().buildForV1Write()) + writeWithV1(newWriteBuilder().buildForV1Write(), Some(v2Relation)) } } @@ -59,7 +60,8 @@ case class OverwriteByExpressionExecV1( table: SupportsWrite, deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, - plan: LogicalPlan) extends V1FallbackWriters { + plan: LogicalPlan, + v2Relation: DataSourceV2Relation) extends V1FallbackWriters { private def isTruncate(filters: Array[Filter]): Boolean = { filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] @@ -68,10 +70,10 @@ case class OverwriteByExpressionExecV1( override protected def run(): Seq[InternalRow] = { newWriteBuilder() match { case builder: SupportsTruncate if isTruncate(deleteWhere) => - writeWithV1(builder.truncate().asV1Builder.buildForV1Write()) + writeWithV1(builder.truncate().asV1Builder.buildForV1Write(), Some(v2Relation)) case builder: SupportsOverwrite => - writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write()) + writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write(), Some(v2Relation)) case _ => throw new SparkException(s"Table does not support overwrite by expression: $table") @@ -112,9 +114,14 @@ sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write { trait SupportsV1Write extends SparkPlan { def plan: LogicalPlan - protected def writeWithV1(relation: InsertableRelation): Seq[InternalRow] = { + protected def writeWithV1( + relation: InsertableRelation, + v2Relation: Option[DataSourceV2Relation] = None): Seq[InternalRow] = { + val session = sqlContext.sparkSession // The `plan` is already optimized, we should not analyze and optimize it again. - relation.insert(AlreadyOptimized.dataFrame(sqlContext.sparkSession, plan), overwrite = false) + relation.insert(AlreadyOptimized.dataFrame(session, plan), overwrite = false) + v2Relation.foreach(r => session.sharedState.cacheManager.recacheByPlan(session, r)) + Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 1421a9315c3a8..1648134d0a1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.expressions.Attribute @@ -127,6 +128,7 @@ case class AtomicCreateTableAsSelectExec( * ReplaceTableAsSelectStagingExec. 
*/ case class ReplaceTableAsSelectExec( + session: SparkSession, catalog: TableCatalog, ident: Identifier, partitioning: Seq[Transform], @@ -146,6 +148,8 @@ case class ReplaceTableAsSelectExec( // 2. Writing to the new table fails, // 3. The table returned by catalog.createTable doesn't support writing. if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + uncacheTable(session, catalog, table, ident) catalog.dropTable(ident) } else if (!orCreate) { throw new CannotReplaceMissingTableException(ident) @@ -169,6 +173,7 @@ case class ReplaceTableAsSelectExec( * is left untouched. */ case class AtomicReplaceTableAsSelectExec( + session: SparkSession, catalog: StagingTableCatalog, ident: Identifier, partitioning: Seq[Transform], @@ -180,6 +185,10 @@ case class AtomicReplaceTableAsSelectExec( override protected def run(): Seq[InternalRow] = { val schema = query.schema.asNullable + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + uncacheTable(session, catalog, table, ident) + } val staged = if (orCreate) { catalog.stageCreateOrReplace( ident, schema, partitioning.toArray, properties.asJava) @@ -204,12 +213,16 @@ case class AtomicReplaceTableAsSelectExec( * Rows in the output data set are appended. */ case class AppendDataExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def run(): Seq[InternalRow] = { - writeWithV2(newWriteBuilder().buildForBatch()) + val writtenRows = writeWithV2(newWriteBuilder().buildForBatch()) + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } @@ -224,7 +237,9 @@ case class AppendDataExec( * AlwaysTrue to delete all rows. */ case class OverwriteByExpressionExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { @@ -234,7 +249,7 @@ case class OverwriteByExpressionExec( } override protected def run(): Seq[InternalRow] = { - newWriteBuilder() match { + val writtenRows = newWriteBuilder() match { case builder: SupportsTruncate if isTruncate(deleteWhere) => writeWithV2(builder.truncate().buildForBatch()) @@ -244,9 +259,12 @@ case class OverwriteByExpressionExec( case _ => throw new SparkException(s"Table does not support overwrite by expression: $table") } + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } + /** * Physical plan node for dynamic partition overwrite into a v2 table. * @@ -257,18 +275,22 @@ case class OverwriteByExpressionExec( * are not modified. 
*/ case class OverwritePartitionsDynamicExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def run(): Seq[InternalRow] = { - newWriteBuilder() match { + val writtenRows = newWriteBuilder() match { case builder: SupportsDynamicOverwrite => writeWithV2(builder.overwriteDynamicPartitions().buildForBatch()) case _ => throw new SparkException(s"Table does not support dynamic partition overwrite: $table") } + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } @@ -370,6 +392,15 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { Nil } + + protected def uncacheTable( + session: SparkSession, + catalog: TableCatalog, + table: Table, + ident: Identifier): Unit = { + val plan = DataSourceV2Relation.create(table, Some(catalog), Some(ident)) + session.sharedState.cacheManager.uncacheQuery(session, plan, cascade = true) + } } object DataWritingSparkTask extends Logging { @@ -484,3 +515,4 @@ private[v2] case class DataWritingSparkTaskResult( * Sink progress information collected after commit. */ private[sql] case class StreamWriterCommitProgress(numOutputRows: Long) + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala index 80c44159e8248..d9636f793476c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala @@ -199,8 +199,8 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) val graphUIDataForNumRowsDroppedByWatermark = new GraphUIData( - "aggregated-num-state-rows-dropped-by-watermark-timeline", - "aggregated-num-state-rows-dropped-by-watermark-histogram", + "aggregated-num-rows-dropped-by-watermark-timeline", + "aggregated-num-rows-dropped-by-watermark-histogram", numRowsDroppedByWatermarkData, minBatchTime, maxBatchTime, @@ -241,11 +241,11 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab)
-            Aggregated Number Of State Rows Dropped By Watermark {SparkUIUtils.tooltip("Aggregated number of state rows dropped by watermark.", "right")}
+            Aggregated Number Of Rows Dropped By Watermark {SparkUIUtils.tooltip("Accumulates all input rows being dropped in stateful operators by watermark. 'Inputs' are relative to operators.", "right")}
- {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} - {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)} + {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} + {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)} // scalastyle:on diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index ddafa1bb5070a..90df4ee08bfc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.SimpleScanSource import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils class DataSourceV2SQLSuite @@ -43,7 +44,6 @@ class DataSourceV2SQLSuite with AlterTableTests with DatasourceV2SQLBase { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ private val v2Source = classOf[FakeV2Provider].getName override protected val v2Format = v2Source @@ -782,6 +782,84 @@ class DataSourceV2SQLSuite } } + test("SPARK-33492: ReplaceTableAsSelect (atomic or non-atomic) should invalidate cache") { + Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t => + val view = "view" + withTable(t) { + withTempView(view) { + sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source") + sql(s"CACHE TABLE $view AS SELECT id FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source")) + checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id")) + + sql(s"REPLACE TABLE $t USING foo AS SELECT id FROM source") + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isEmpty) + } + } + } + } + + test("SPARK-33492: AppendData should refresh cache") { + import testImplicits._ + + val t = "testcat.ns.t" + val view = "view" + withTable(t) { + withTempView(view) { + Seq((1, "a")).toDF("i", "j").write.saveAsTable(t) + sql(s"CACHE TABLE $view AS SELECT i FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + + Seq((2, "b")).toDF("i", "j").write.mode(SaveMode.Append).saveAsTable(t) + + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Row(2, "b") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Row(2) :: Nil) + } + } + } + + test("SPARK-33492: OverwriteByExpression should refresh cache") { + val t = "testcat.ns.t" + val view = "view" + withTable(t) { + withTempView(view) { + sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source") + sql(s"CACHE TABLE $view AS SELECT id FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source")) + checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id")) + + sql(s"INSERT OVERWRITE TABLE $t VALUES (1, 'a')") + + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + } + } + } + + 
test("SPARK-33492: OverwritePartitionsDynamic should refresh cache") { + import testImplicits._ + + val t = "testcat.ns.t" + val view = "view" + withTable(t) { + withTempView(view) { + Seq((1, "a", 1)).toDF("i", "j", "k").write.partitionBy("k") saveAsTable(t) + sql(s"CACHE TABLE $view AS SELECT i FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a", 1) :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + + Seq((2, "b", 1)).toDF("i", "j", "k").writeTo(t).overwritePartitions() + + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) + checkAnswer(sql(s"SELECT * FROM $t"), Row(2, "b", 1) :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(2) :: Nil) + } + } + } + test("Relation: basic") { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { @@ -1980,57 +2058,6 @@ class DataSourceV2SQLSuite } } - test("ALTER TABLE RECOVER PARTITIONS") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") - val e = intercept[AnalysisException] { - sql(s"ALTER TABLE $t RECOVER PARTITIONS") - } - assert(e.message.contains("ALTER TABLE RECOVER PARTITIONS is only supported with v1 tables")) - } - } - - test("ALTER TABLE ADD PARTITION") { - val t = "testpart.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)") - spark.sql(s"ALTER TABLE $t ADD PARTITION (id=1) LOCATION 'loc'") - - val partTable = catalog("testpart").asTableCatalog - .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")).asInstanceOf[InMemoryPartitionTable] - assert(partTable.partitionExists(InternalRow.fromSeq(Seq(1)))) - - val partMetadata = partTable.loadPartitionMetadata(InternalRow.fromSeq(Seq(1))) - assert(partMetadata.containsKey("location")) - assert(partMetadata.get("location") == "loc") - } - } - - test("ALTER TABLE RENAME PARTITION") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)") - val e = intercept[AnalysisException] { - sql(s"ALTER TABLE $t PARTITION (id=1) RENAME TO PARTITION (id=2)") - } - assert(e.message.contains("ALTER TABLE RENAME PARTITION is only supported with v1 tables")) - } - } - - test("ALTER TABLE DROP PARTITION") { - val t = "testpart.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)") - spark.sql(s"ALTER TABLE $t ADD PARTITION (id=1) LOCATION 'loc'") - spark.sql(s"ALTER TABLE $t DROP PARTITION (id=1)") - - val partTable = - catalog("testpart").asTableCatalog.loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")) - assert(!partTable.asPartitionable.partitionExists(InternalRow.fromSeq(Seq(1)))) - } - } - test("ALTER TABLE SerDe properties") { val t = "testcat.ns1.ns2.tbl" withTable(t) { @@ -2513,6 +2540,25 @@ class DataSourceV2SQLSuite } } + test("SPARK-33505: insert into partitioned table") { + val t = "testpart.ns1.ns2.tbl" + withTable(t) { + sql(s""" + |CREATE TABLE $t (id bigint, city string, data string) + |USING foo + |PARTITIONED BY (id, city)""".stripMargin) + val partTable = catalog("testpart").asTableCatalog + .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")).asInstanceOf[InMemoryPartitionTable] + val expectedPartitionIdent = InternalRow.fromSeq(Seq(1, UTF8String.fromString("NY"))) + assert(!partTable.partitionExists(expectedPartitionIdent)) + sql(s"INSERT INTO $t PARTITION(id = 1, city = 'NY') SELECT 'abc'") + assert(partTable.partitionExists(expectedPartitionIdent)) + // Insert into 
the existing partition must not fail + sql(s"INSERT INTO $t PARTITION(id = 1, city = 'NY') SELECT 'def'") + assert(partTable.partitionExists(expectedPartitionIdent)) + } + } + private def testNotSupportedV2Command(sqlCommand: String, sqlParams: String): Unit = { val e = intercept[AnalysisException] { sql(s"$sqlCommand $sqlParams") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 4b52a4cbf4116..cba7dd35fb3bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -24,14 +24,17 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder} import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources._ @@ -145,6 +148,52 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before SparkSession.setDefaultSession(spark) } } + + test("SPARK-33492: append fallback should refresh cache") { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + try { + val session = SparkSession.builder() + .master("local[1]") + .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName) + .getOrCreate() + val df = session.createDataFrame(Seq((1, "x"))) + df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test") + session.catalog.cacheTable("test") + checkAnswer(session.read.table("test"), Row(1, "x") :: Nil) + + val df2 = session.createDataFrame(Seq((2, "y"))) + df2.writeTo("test").append() + checkAnswer(session.read.table("test"), Row(1, "x") :: Row(2, "y") :: Nil) + + } finally { + SparkSession.setActiveSession(spark) + SparkSession.setDefaultSession(spark) + } + } + + test("SPARK-33492: overwrite fallback should refresh cache") { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + try { + val session = SparkSession.builder() + .master("local[1]") + .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName) + .getOrCreate() + val df = session.createDataFrame(Seq((1, "x"))) + df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test") + session.catalog.cacheTable("test") + checkAnswer(session.read.table("test"), Row(1, "x") :: Nil) + + val df2 = session.createDataFrame(Seq((2, "y"))) + df2.writeTo("test").overwrite(lit(true)) + 
checkAnswer(session.read.table("test"), Row(2, "y") :: Nil) + + } finally { + SparkSession.setActiveSession(spark) + SparkSession.setDefaultSession(spark) + } + } } class V1WriteFallbackSessionCatalogSuite @@ -177,6 +226,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV properties: util.Map[String, String]): InMemoryTableWithV1Fallback = { val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties) InMemoryV1Provider.tables.put(name, t) + tables.put(Identifier.of(Array("default"), name), t) t } } @@ -272,7 +322,7 @@ class InMemoryTableWithV1Fallback( override val partitioning: Array[Transform], override val properties: util.Map[String, String]) extends Table - with SupportsWrite { + with SupportsWrite with SupportsRead { partitioning.foreach { t => if (!t.isInstanceOf[IdentityTransform]) { @@ -281,6 +331,7 @@ class InMemoryTableWithV1Fallback( } override def capabilities: util.Set[TableCapability] = Set( + TableCapability.BATCH_READ, TableCapability.V1_BATCH_WRITE, TableCapability.OVERWRITE_BY_FILTER, TableCapability.TRUNCATE).asJava @@ -338,6 +389,30 @@ class InMemoryTableWithV1Fallback( } } } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new V1ReadFallbackScanBuilder(schema) + + private class V1ReadFallbackScanBuilder(schema: StructType) extends ScanBuilder { + override def build(): Scan = new V1ReadFallbackScan(schema) + } + + private class V1ReadFallbackScan(schema: StructType) extends V1Scan { + override def readSchema(): StructType = schema + override def toV1TableScan[T <: BaseRelation with TableScan](context: SQLContext): T = + new V1TableScan(context, schema).asInstanceOf[T] + } + + private class V1TableScan( + context: SQLContext, + requiredSchema: StructType) extends BaseRelation with TableScan { + override def sqlContext: SQLContext = context + override def schema: StructType = requiredSchema + override def buildScan(): RDD[Row] = { + val data = InMemoryV1Provider.getTableData(context.sparkSession, name).collect() + context.sparkContext.makeRDD(data) + } + } } /** A rule that fails if a query plan is analyzed twice. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala index 2aaeb67d30538..94844c4e87a84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala @@ -141,7 +141,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B summaryText should contain ("Aggregated Number Of Total State Rows (?)") summaryText should contain ("Aggregated Number Of Updated State Rows (?)") summaryText should contain ("Aggregated State Memory Used In Bytes (?)") - summaryText should contain ("Aggregated Number Of State Rows Dropped By Watermark (?)") + summaryText should contain ("Aggregated Number Of Rows Dropped By Watermark (?)") summaryText should contain ("Aggregated Custom Metric stateOnCurrentVersionSizeBytes" + " (?)") summaryText should not contain ("Aggregated Custom Metric loadedMapCacheHitCount (?)")
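The v2 write nodes touched by SPARK-33492 above all share one shape: run the physical write, then have the shared CacheManager recompute any cached plan built on the written relation. The sketch below is not part of the patch; it only restates that pattern in one place, assumes it is compiled inside the org.apache.spark.sql.execution.datasources.v2 package (so the private[sql] sharedState member is visible), and uses RecacheSketch and writeThenRecache as hypothetical names.

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow

object RecacheSketch {
  // Run the write, then refresh cached plans that read `relation` so cached views
  // observe the newly written data (mirrors AppendDataExec and the overwrite nodes).
  def writeThenRecache(
      session: SparkSession,
      relation: DataSourceV2Relation)(doWrite: => Seq[InternalRow]): Seq[InternalRow] = {
    val writtenRows = doWrite
    session.sharedState.cacheManager.recacheByPlan(session, relation)
    writtenRows
  }
}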