From 3bcda6d1c193dcf52b2b637a79497f43a7658809 Mon Sep 17 00:00:00 2001
From: Paddy Xu
Date: Fri, 11 Oct 2024 17:28:12 +0200
Subject: [PATCH] Optimize ArtifactManager: cache class loaders per thread and
 re-register SparkContext resources when cloning a session

---
 .../spark/sql/artifact/ArtifactManager.scala  | 134 +++++++++++++-----
 .../execution/streaming/StreamExecution.scala |   4 +-
 2 files changed, 101 insertions(+), 37 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
index 8dc1cac0f26bb..10ab1e02e249f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
@@ -23,6 +23,7 @@
 import java.nio.ByteBuffer
 import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption}
 import java.util.concurrent.CopyOnWriteArrayList
 
+import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
@@ -72,9 +73,11 @@ class ArtifactManager(session: SparkSession) extends Logging {
     (ArtifactUtils.concatenatePaths(artifactPath, "classes"),
       s"$artifactURI${File.separator}classes${File.separator}")
 
-  protected[sql] val state: JobArtifactState = {
+  protected[sql] lazy val state: JobArtifactState = {
     val sessionIsolated = session.conf.get("spark.session.isolate.artifacts", "true")
     val replIsolated = session.conf.get("spark.repl.isolate.artifacts", "false")
+    logInfo(s"Session isolation: $sessionIsolated, REPL isolation: $replIsolated." +
+      s" Session UUID: ${session.sessionUUID}, REPL class URI: $replClassURI.")
     (sessionIsolated, replIsolated) match {
       case ("true", "true") => JobArtifactState(session.sessionUUID, Some(replClassURI))
       case ("true", "false") => JobArtifactState(session.sessionUUID, None)
@@ -86,29 +89,23 @@ class ArtifactManager(session: SparkSession) extends Logging {
 
   private var initialContextResourcesCopied = false
 
-  /** The number of JARs added at the time of the last [[withResources]] call. */
-  private val lastNumberOfJarsInScope = new ThreadLocal[Int] {
-    override def initialValue(): Int = -1
-  }
+  /**
+   * The number of JARs at the time of the last [[classloader]] call, together with every class
+   * loader this thread has generated. The WeakHashMap serves as a weak set; values are ignored.
+   */
+  private val numberOfJarsAndCachedClassLoaders =
+    new ThreadLocal[(Int, mutable.WeakHashMap[ClassLoader, Boolean])] {
+      override def initialValue(): (Int, mutable.WeakHashMap[ClassLoader, Boolean]) =
+        (-1, new mutable.WeakHashMap[ClassLoader, Boolean]())
+    }
 
   def withResources[T](f: => T): T = {
-    // Generating class loader is slow, and often withResources is layered multiple times since
-    // SparkSession.withActive is often layered. We have to do some check here to avoid layering
-    // too many class loaders. Layering too much heavily impacts streaming performance.
-    // jarsList is append-only, so we can use its size to decide whether we need a new layer.
-    val prevNumOfJars = lastNumberOfJarsInScope.get()
-    if (prevNumOfJars == jarsList.size()) {
-      try f finally { lastNumberOfJarsInScope.set(prevNumOfJars) }
-    } else {
-      lastNumberOfJarsInScope.set(jarsList.size())
-
-      Utils.withContextClassLoader(classloader, retainChange = true) {
-        JobArtifactSet.withActiveJobArtifactState(state) {
-          // Copy over global initial resources to this session. Often used by spark-submit.
-          copyInitialContextResourcesIfNeeded()
+    Utils.withContextClassLoader(classloader, retainChange = true) {
+      JobArtifactSet.withActiveJobArtifactState(state) {
+        // Copy over global initial resources to this session. Often used by spark-submit.
+        copyInitialContextResourcesIfNeeded()
 
-          try f finally { lastNumberOfJarsInScope.set(prevNumOfJars) }
-        }
+        f
+      }
     }
   }
@@ -133,6 +130,8 @@ class ArtifactManager(session: SparkSession) extends Logging {
   protected val cachedBlockIdList = new CopyOnWriteArrayList[CacheId]
   protected val jarsList = new CopyOnWriteArrayList[Path]
   protected val pythonIncludeList = new CopyOnWriteArrayList[String]
+  protected val sparkContextRelativePaths =
+    new CopyOnWriteArrayList[(SparkContextResourceType.ResourceType, Path, Option[String])]
 
   /**
    * Get the URLs of all jar artifacts.
@@ -241,9 +240,13 @@ class ArtifactManager(session: SparkSession) extends Logging {
 
       if (normalizedRemoteRelativePath.startsWith(s"jars${File.separator}")) {
         session.sparkContext.addJar(uri)
+        sparkContextRelativePaths.add(
+          (SparkContextResourceType.JAR, normalizedRemoteRelativePath, fragment))
         jarsList.add(normalizedRemoteRelativePath)
       } else if (normalizedRemoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
         session.sparkContext.addFile(uri)
+        sparkContextRelativePaths.add(
+          (SparkContextResourceType.FILE, normalizedRemoteRelativePath, fragment))
         val stringRemotePath = normalizedRemoteRelativePath.toString
         if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith(
           ".egg") || stringRemotePath.endsWith(".jar")) {
@@ -253,8 +256,12 @@ class ArtifactManager(session: SparkSession) extends Logging {
         val canonicalUri =
           fragment.map(Utils.getUriBuilder(new URI(uri)).fragment).getOrElse(new URI(uri))
         session.sparkContext.addArchive(canonicalUri.toString)
+        sparkContextRelativePaths.add(
+          (SparkContextResourceType.ARCHIVE, normalizedRemoteRelativePath, fragment))
       } else if (normalizedRemoteRelativePath.startsWith(s"files${File.separator}")) {
         session.sparkContext.addFile(uri)
+        sparkContextRelativePaths.add(
+          (SparkContextResourceType.FILE, normalizedRemoteRelativePath, fragment))
       }
     }
   }
@@ -293,12 +300,34 @@ class ArtifactManager(session: SparkSession) extends Logging {
 
   /**
    * Returns a [[ClassLoader]] for session-specific jar/class file resources.
+   *
+   * Generating a class loader is slow, and [[withResources]] is often layered multiple times
+   * because [[SparkSession.withActive]] is often layered. We check here to avoid stacking too
+   * many class loaders, since excessive layering heavily impacts streaming performance.
    */
   def classloader: ClassLoader = {
+    val fallbackClassLoader = Utils.getContextOrSparkClassLoader
     val urls = getAddedJars :+ classDir.toUri.toURL
+    val (lastNumberOfJars, classLoaderMap) = numberOfJarsAndCachedClassLoaders.get()
+    if (lastNumberOfJars != jarsList.size()) {
+      // The number of JARs has changed, so a new class loader must be generated.
+      numberOfJarsAndCachedClassLoaders.remove()
+      buildAndCacheClassLoader(fallbackClassLoader, urls)
+    } else {
+      // If the fallback class loader was generated by us, we can reuse it to avoid layering.
+      if (classLoaderMap.contains(fallbackClassLoader)) {
+        fallbackClassLoader
+      } else {
+        // Otherwise we need to layer a new class loader.
+        buildAndCacheClassLoader(fallbackClassLoader, urls)
+      }
+    }
+  }
+
+  private def buildAndCacheClassLoader(
+      fallbackClassLoader: ClassLoader, urls: Seq[URL]): ClassLoader = {
     val prefixes = SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES)
     val userClasspathFirst = SparkEnv.get.conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)
-    val fallbackClassLoader = Utils.getContextOrSparkClassLoader
     val loader = if (prefixes.nonEmpty) {
       // Two things you need to know about classloader for all of this to make sense:
       // 1. A classloader needs to be able to fully define a class.
@@ -331,23 +360,53 @@ class ArtifactManager(session: SparkSession) extends Logging {
     }
     logDebug(s"Using class loader: $loader, containing urls: $urls")
+
+    val cache = numberOfJarsAndCachedClassLoaders.get()._2
+    cache += (loader -> true)
+    numberOfJarsAndCachedClassLoaders.set((jarsList.size(), cache))
     loader
   }
 
   private[sql] def clone(newSession: SparkSession): ArtifactManager = {
-    val newArtifactManager = new ArtifactManager(newSession)
-    if (artifactPath.toFile.exists()) {
-      FileUtils.copyDirectory(artifactPath.toFile, newArtifactManager.artifactPath.toFile)
-    }
-    val blockManager = session.sparkContext.env.blockManager
-    val newBlockIds = cachedBlockIdList.asScala.map { blockId =>
-      val newBlockId = blockId.copy(sessionUUID = newSession.sessionUUID)
-      copyBlock(blockId, newBlockId, blockManager)
+    val sparkContext = session.sparkContext
+    sparkContext.synchronized {
+      val newArtifactManager = new ArtifactManager(newSession)
+      if (artifactPath.toFile.exists()) {
+        FileUtils.copyDirectory(artifactPath.toFile, newArtifactManager.artifactPath.toFile)
+      }
+      val blockManager = sparkContext.env.blockManager
+      val newBlockIds = cachedBlockIdList.asScala.map { blockId =>
+        val newBlockId = blockId.copy(sessionUUID = newSession.sessionUUID)
+        copyBlock(blockId, newBlockId, blockManager)
+      }
+
+      // Re-register resources with the SparkContext
+      JobArtifactSet.withActiveJobArtifactState(newArtifactManager.state) {
+        sparkContextRelativePaths.forEach { case (resourceType, relativePath, fragment) =>
+          val uri = s"${newArtifactManager.artifactURI}/${
+            Utils.encodeRelativeUnixPathToURIRawPath(
+              FilenameUtils.separatorsToUnix(relativePath.toString))
+          }"
+          resourceType match {
+            case SparkContextResourceType.JAR =>
+              sparkContext.addJar(uri)
+            case SparkContextResourceType.FILE =>
+              sparkContext.addFile(uri)
+            case SparkContextResourceType.ARCHIVE =>
+              val canonicalUri =
+                fragment.map(Utils.getUriBuilder(new URI(uri)).fragment).getOrElse(new URI(uri))
+              sparkContext.addArchive(canonicalUri.toString)
+            case _ =>
+              throw SparkException.internalError(s"Unsupported resource type: $resourceType")
+          }
+        }
+      }
+
+      newArtifactManager.cachedBlockIdList.addAll(newBlockIds.asJava)
+      newArtifactManager.jarsList.addAll(jarsList)
+      newArtifactManager.pythonIncludeList.addAll(pythonIncludeList)
+      newArtifactManager
     }
-    newArtifactManager.cachedBlockIdList.addAll(newBlockIds.asJava)
-    newArtifactManager.jarsList.addAll(jarsList)
-    newArtifactManager.pythonIncludeList.addAll(pythonIncludeList)
-    newArtifactManager
   }
 
   /**
@@ -379,6 +438,8 @@ class ArtifactManager(session: SparkSession) extends Logging {
 
     // Clean up artifacts folder
     FileUtils.deleteDirectory(artifactPath.toFile)
 
+    // Clean up internal trackers
+    numberOfJarsAndCachedClassLoaders.remove()
     jarsList.clear()
     pythonIncludeList.clear()
     cachedBlockIdList.clear()
@@ -430,6 +491,11 @@ object ArtifactManager extends Logging {
   private[artifact] lazy val artifactRootDirectory =
     Utils.createTempDir(ARTIFACT_DIRECTORY_PREFIX).toPath
 
+  private[artifact] object SparkContextResourceType extends Enumeration {
+    type ResourceType = Value
+    val JAR, FILE, ARCHIVE = Value
+  }
+
   private[artifact] def copyBlock(
       fromId: CacheId,
       toId: CacheId,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 8f030884ad33b..aea8651dcbafe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -223,9 +223,7 @@ abstract class StreamExecution(
       // To fix call site like "run at <unknown>:0", we bridge the call site from the caller
       // thread to this micro batch thread
       sparkSession.sparkContext.setCallSite(callSite)
-      JobArtifactSet.withActiveJobArtifactState(jobArtifactState) {
-        runStream()
-      }
+      runStream()
     }
   }
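
The heart of the ArtifactManager change is the per-thread class-loader cache: a thread reuses the loader it built previously as long as no JAR has been added since, instead of wrapping the context class loader again on every call. The following is a minimal standalone sketch of that pattern; CachedLoaders, jars, addJar, and build are illustrative names rather than Spark APIs, and a plain URLClassLoader stands in for Spark's stub- and isolation-aware loaders.

import java.net.{URL, URLClassLoader}
import java.util.concurrent.CopyOnWriteArrayList

import scala.collection.mutable
import scala.jdk.CollectionConverters._

object CachedLoaders {
  // Append-only, like ArtifactManager.jarsList: its size alone tells us
  // whether anything changed since this thread last built a loader.
  private val jars = new CopyOnWriteArrayList[URL]

  // Per-thread state: the JAR count at the last build, plus a weak set of all
  // loaders this thread has built. Weak keys let unused loaders be collected.
  private val cache =
    new ThreadLocal[(Int, mutable.WeakHashMap[ClassLoader, Boolean])] {
      override def initialValue(): (Int, mutable.WeakHashMap[ClassLoader, Boolean]) =
        (-1, new mutable.WeakHashMap[ClassLoader, Boolean]())
    }

  def addJar(url: URL): Unit = jars.add(url)

  def classLoader: ClassLoader = {
    val fallback = Thread.currentThread().getContextClassLoader
    val (lastCount, seen) = cache.get()
    if (lastCount == jars.size() && seen.contains(fallback)) {
      // Nothing changed and the context loader is one we built:
      // reuse it instead of stacking another loader on top.
      fallback
    } else {
      // If the JAR set changed, start from a fresh per-thread cache.
      if (lastCount != jars.size()) cache.remove()
      build(fallback)
    }
  }

  private def build(parent: ClassLoader): ClassLoader = {
    val loader = new URLClassLoader(jars.asScala.toArray, parent)
    val (_, seen) = cache.get()
    seen += (loader -> true) // used as a weak set; the Boolean is ignored
    cache.set((jars.size(), seen))
    loader
  }
}

Calling classLoader repeatedly on the same thread with no new JARs now returns the same instance, which is what keeps deeply nested withResources calls cheap for streaming; and because the map holds its loaders weakly, discarded loaders drop out of the cache on their own rather than leaking across long-lived thread pools.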
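The clone fix rests on similar bookkeeping: every resource handed to the SparkContext is also recorded as a (type, relative path, fragment) tuple, so a cloned session can replay the registrations under its own artifact URI. The sketch below isolates that record-and-replay idea; ResourceReplay, record, and replay are hypothetical names, and the addJar/addFile/addArchive callbacks stand in for the corresponding SparkContext methods.

import java.util.concurrent.CopyOnWriteArrayList

object ResourceReplay {
  object ResourceType extends Enumeration {
    type ResourceType = Value
    val JAR, FILE, ARCHIVE = Value
  }
  import ResourceType._

  // (resource type, session-relative path, optional "#alias" URI fragment)
  type Entry = (ResourceType, String, Option[String])

  // Append-only and thread-safe, so it can be iterated while others add.
  private val recorded = new CopyOnWriteArrayList[Entry]

  def record(tpe: ResourceType, relativePath: String, fragment: Option[String]): Unit =
    recorded.add((tpe, relativePath, fragment))

  // Re-register every recorded resource under a new base URI, dispatching on
  // resource type, e.g. when cloning a session.
  def replay(baseUri: String)(
      addJar: String => Unit, addFile: String => Unit, addArchive: String => Unit): Unit =
    recorded.forEach { case (tpe, path, fragment) =>
      val uri = s"$baseUri/$path"
      tpe match {
        case JAR => addJar(uri)
        case FILE => addFile(uri)
        // Archives keep their fragment so the extraction alias survives the clone.
        case ARCHIVE => addArchive(fragment.fold(uri)(f => s"$uri#$f"))
      }
    }
}

Recording the fragment separately matters for archives: SparkContext.addArchive uses a "#alias" fragment to name the directory the archive is unpacked into, so dropping it during a clone would change where user code finds the extracted files.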