Commit

optimizzzzze
xupefei committed Oct 11, 2024
1 parent d8ec1d3 commit 3bcda6d
Showing 2 changed files with 101 additions and 37 deletions.
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption}
import java.util.concurrent.CopyOnWriteArrayList

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

@@ -72,9 +73,11 @@ class ArtifactManager(session: SparkSession) extends Logging {
(ArtifactUtils.concatenatePaths(artifactPath, "classes"),
s"$artifactURI${File.separator}classes${File.separator}")

protected[sql] val state: JobArtifactState = {
protected[sql] lazy val state: JobArtifactState = {
val sessionIsolated = session.conf.get("spark.session.isolate.artifacts", "true")
val replIsolated = session.conf.get("spark.repl.isolate.artifacts", "false")
logInfo(s"Session isolation: $sessionIsolated, REPL isolation: $replIsolated." +
s" Session UUID: ${session.sessionUUID}, REPL class URI: $replClassURI.")
(sessionIsolated, replIsolated) match {
case ("true", "true") => JobArtifactState(session.sessionUUID, Some(replClassURI))
case ("true", "false") => JobArtifactState(session.sessionUUID, None)
@@ -86,29 +89,23 @@ class ArtifactManager(session: SparkSession) extends Logging {

private var initialContextResourcesCopied = false

/** The number of JARs added at the time of the last [[withResources]] call. */
private val lastNumberOfJarsInScope = new ThreadLocal[Int] {
override def initialValue(): Int = -1
}
/**
* The number of JARs at the time of the last [[classloader]] call, plus all class loaders
* generated so far. The value side of the WeakHashMap is not used; the map only acts as a
* weak set of loaders.
*/
private val numberOfJarsAndCachedClassLoaders =
new ThreadLocal[(Int, mutable.WeakHashMap[ClassLoader, Boolean])] {
override def initialValue(): (Int, mutable.WeakHashMap[ClassLoader, Boolean]) =
(-1, new mutable.WeakHashMap[ClassLoader, Boolean]())
}

def withResources[T](f: => T): T = {
// Generating a class loader is slow, and withResources is often nested because
// SparkSession.withActive is often nested. We have to check here to avoid stacking too many
// class loaders, since excessive layering heavily impacts streaming performance.
// jarsList is append-only, so we can use its size to decide whether we need a new layer.
val prevNumOfJars = lastNumberOfJarsInScope.get()
if (prevNumOfJars == jarsList.size()) {
try f finally { lastNumberOfJarsInScope.set(prevNumOfJars) }
} else {
lastNumberOfJarsInScope.set(jarsList.size())

Utils.withContextClassLoader(classloader, retainChange = true) {
JobArtifactSet.withActiveJobArtifactState(state) {
// Copy over global initial resources to this session. Often used by spark-submit.
copyInitialContextResourcesIfNeeded()
Utils.withContextClassLoader(classloader, retainChange = true) {
JobArtifactSet.withActiveJobArtifactState(state) {
// Copy over global initial resources to this session. Often used by spark-submit.
copyInitialContextResourcesIfNeeded()

try f finally { lastNumberOfJarsInScope.set(prevNumOfJars) }
}
f
}
}
}
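For context on the layering guard described in the comments above: it boils down to a per-thread cache that remembers how many jars were registered and which class loaders this code itself created, so that re-entrant calls reuse the current context loader instead of stacking a new URLClassLoader on every call. The sketch below is a hypothetical, standalone illustration of that check (SimpleResourceScope and its members are invented names, and it restores the previous loader rather than retaining the new one as the real code does with retainChange = true):

import java.net.{URL, URLClassLoader}
import scala.collection.mutable

// Hypothetical sketch of the layering guard; not Spark code.
object SimpleResourceScope {
  // Per thread: (number of jars seen at the last call, loaders created by this object).
  private val cache =
    new ThreadLocal[(Int, mutable.WeakHashMap[ClassLoader, Boolean])] {
      override def initialValue(): (Int, mutable.WeakHashMap[ClassLoader, Boolean]) =
        (-1, new mutable.WeakHashMap[ClassLoader, Boolean]())
    }

  private def classLoaderFor(jars: Seq[URL]): ClassLoader = {
    val parent = Thread.currentThread().getContextClassLoader
    val (lastCount, created) = cache.get()
    if (lastCount == jars.size && created.contains(parent)) {
      // The context loader is already one of ours and no jars were added: reuse it.
      parent
    } else {
      val loader = new URLClassLoader(jars.toArray, parent)
      created += (loader -> true)
      cache.set((jars.size, created))
      loader
    }
  }

  def withResources[T](jars: Seq[URL])(f: => T): T = {
    val thread = Thread.currentThread()
    val previous = thread.getContextClassLoader
    thread.setContextClassLoader(classLoaderFor(jars))
    try f finally thread.setContextClassLoader(previous)
  }
}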
@@ -133,6 +130,8 @@ class ArtifactManager(session: SparkSession) extends Logging {
protected val cachedBlockIdList = new CopyOnWriteArrayList[CacheId]
protected val jarsList = new CopyOnWriteArrayList[Path]
protected val pythonIncludeList = new CopyOnWriteArrayList[String]
protected val sparkContextRelativePaths =
new CopyOnWriteArrayList[(SparkContextResourceType.ResourceType, Path, Option[String])]

/**
* Get the URLs of all jar artifacts.
@@ -241,9 +240,13 @@ class ArtifactManager(session: SparkSession) extends Logging {

if (normalizedRemoteRelativePath.startsWith(s"jars${File.separator}")) {
session.sparkContext.addJar(uri)
sparkContextRelativePaths.add(
(SparkContextResourceType.JAR, normalizedRemoteRelativePath, fragment))
jarsList.add(normalizedRemoteRelativePath)
} else if (normalizedRemoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
session.sparkContext.addFile(uri)
sparkContextRelativePaths.add(
(SparkContextResourceType.FILE, normalizedRemoteRelativePath, fragment))
val stringRemotePath = normalizedRemoteRelativePath.toString
if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith(
".egg") || stringRemotePath.endsWith(".jar")) {
@@ -253,8 +256,12 @@ class ArtifactManager(session: SparkSession) extends Logging {
val canonicalUri =
fragment.map(Utils.getUriBuilder(new URI(uri)).fragment).getOrElse(new URI(uri))
session.sparkContext.addArchive(canonicalUri.toString)
sparkContextRelativePaths.add(
(SparkContextResourceType.ARCHIVE, normalizedRemoteRelativePath, fragment))
} else if (normalizedRemoteRelativePath.startsWith(s"files${File.separator}")) {
session.sparkContext.addFile(uri)
sparkContextRelativePaths.add(
(SparkContextResourceType.FILE, normalizedRemoteRelativePath, fragment))
}
}
}
@@ -293,12 +300,34 @@ class ArtifactManager(session: SparkSession) extends Logging {

/**
* Returns a [[ClassLoader]] for session-specific jar/class file resources.
*
* Generating a class loader is slow, and [[withResources]] is often nested because
* [[SparkSession.withActive]] is often nested. We have to check here to avoid stacking too
* many class loaders, since excessive layering heavily impacts streaming performance.
*/
def classloader: ClassLoader = {
val fallbackClassLoader = Utils.getContextOrSparkClassLoader
val urls = getAddedJars :+ classDir.toUri.toURL
val thread = Thread.currentThread().toString
val (lastNumberOfJars, classLoaderMap) = numberOfJarsAndCachedClassLoaders.get()
if (lastNumberOfJars != jarsList.size()) {
// If the number of jars has changed, we need to generate new class loaders.
numberOfJarsAndCachedClassLoaders.remove()
buildAndCacheClassLoader(fallbackClassLoader, urls)
} else {
// If the fallback class loader was generated by us, we can reuse it to avoid another layer.
if (classLoaderMap.contains(fallbackClassLoader)) {
fallbackClassLoader
} else {
// Otherwise we need to layer one new class loader.
buildAndCacheClassLoader(fallbackClassLoader, urls)
}
}
}

def buildAndCacheClassLoader(fallbackClassLoader: ClassLoader, urls: Seq[URL]): ClassLoader = {
val prefixes = SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES)
val userClasspathFirst = SparkEnv.get.conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)
val fallbackClassLoader = Utils.getContextOrSparkClassLoader
val loader = if (prefixes.nonEmpty) {
// Two things you need to know about classloader for all of this to make sense:
// 1. A classloader needs to be able to fully define a class.
@@ -331,23 +360,53 @@ class ArtifactManager(session: SparkSession) extends Logging {
}

logDebug(s"Using class loader: $loader, containing urls: $urls")

val cache = numberOfJarsAndCachedClassLoaders.get()._2
cache += (loader -> true)
numberOfJarsAndCachedClassLoaders.set((jarsList.size(), cache))
loader
}

private[sql] def clone(newSession: SparkSession): ArtifactManager = {
val newArtifactManager = new ArtifactManager(newSession)
if (artifactPath.toFile.exists()) {
FileUtils.copyDirectory(artifactPath.toFile, newArtifactManager.artifactPath.toFile)
}
val blockManager = session.sparkContext.env.blockManager
val newBlockIds = cachedBlockIdList.asScala.map { blockId =>
val newBlockId = blockId.copy(sessionUUID = newSession.sessionUUID)
copyBlock(blockId, newBlockId, blockManager)
val sparkContext = session.sparkContext
sparkContext.synchronized {
val newArtifactManager = new ArtifactManager(newSession)
if (artifactPath.toFile.exists()) {
FileUtils.copyDirectory(artifactPath.toFile, newArtifactManager.artifactPath.toFile)
}
val blockManager = sparkContext.env.blockManager
val newBlockIds = cachedBlockIdList.asScala.map { blockId =>
val newBlockId = blockId.copy(sessionUUID = newSession.sessionUUID)
copyBlock(blockId, newBlockId, blockManager)
}

// Re-register resources to SparkContext
JobArtifactSet.withActiveJobArtifactState(newArtifactManager.state) {
sparkContextRelativePaths.forEach { case (resourceType, relativePath, fragment) =>
val uri = s"${newArtifactManager.artifactURI}/${
Utils.encodeRelativeUnixPathToURIRawPath(
FilenameUtils.separatorsToUnix(relativePath.toString))
}"
resourceType match {
case org.apache.spark.sql.artifact.ArtifactManager.SparkContextResourceType.JAR =>
sparkContext.addJar(uri)
case org.apache.spark.sql.artifact.ArtifactManager.SparkContextResourceType.FILE =>
sparkContext.addFile(uri)
case org.apache.spark.sql.artifact.ArtifactManager.SparkContextResourceType.ARCHIVE =>
val canonicalUri =
fragment.map(Utils.getUriBuilder(new URI(uri)).fragment).getOrElse(new URI(uri))
sparkContext.addArchive(canonicalUri.toString)
case _ =>
throw SparkException.internalError(s"Unsupported resource type: $resourceType")
}
}
}

newArtifactManager.cachedBlockIdList.addAll(newBlockIds.asJava)
newArtifactManager.jarsList.addAll(jarsList)
newArtifactManager.pythonIncludeList.addAll(pythonIncludeList)
newArtifactManager
}
newArtifactManager.cachedBlockIdList.addAll(newBlockIds.asJava)
newArtifactManager.jarsList.addAll(jarsList)
newArtifactManager.pythonIncludeList.addAll(pythonIncludeList)
newArtifactManager
}
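The re-registration loop above rebuilds an absolute URI from each tracked relative path before handing it back to the SparkContext add* methods. Below is a minimal, hypothetical sketch of that step; UriRebuild and rebuildUri are invented names, and it skips the percent-encoding the real code applies via Utils.encodeRelativeUnixPathToURIRawPath:

import java.net.URI
import java.nio.file.Path
import org.apache.commons.io.FilenameUtils

// Hypothetical helper; not part of this commit.
object UriRebuild {
  // Turn a tracked relative path plus an optional fragment (e.g. an archive
  // alias) back into an absolute URI of the kind passed to SparkContext.
  def rebuildUri(artifactRootUri: String, relativePath: Path, fragment: Option[String]): URI = {
    // Normalize Windows separators so the relative path forms valid URI segments.
    val unixPath = FilenameUtils.separatorsToUnix(relativePath.toString)
    val base = new URI(s"$artifactRootUri/$unixPath")
    // Re-attach the fragment, if any, producing e.g. "...#myAlias" for archives.
    fragment
      .map(f => new URI(base.getScheme, base.getSchemeSpecificPart, f))
      .getOrElse(base)
  }
}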

/**
@@ -379,6 +438,8 @@ class ArtifactManager(session: SparkSession) extends Logging {
// Clean up artifacts folder
FileUtils.deleteDirectory(artifactPath.toFile)

// Clean up internal trackers
numberOfJarsAndCachedClassLoaders.remove()
jarsList.clear()
pythonIncludeList.clear()
cachedBlockIdList.clear()
@@ -430,6 +491,11 @@ object ArtifactManager extends Logging {
private[artifact] lazy val artifactRootDirectory =
Utils.createTempDir(ARTIFACT_DIRECTORY_PREFIX).toPath

private[artifact] object SparkContextResourceType extends Enumeration {
type ResourceType = Value
val JAR, FILE, ARCHIVE = Value
}
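As a compact, hypothetical illustration (not Spark code) of how such an Enumeration can drive the dispatch used in clone above; ResourceKind and reRegister are invented names, while addJar, addFile and addArchive are the real SparkContext methods:

import org.apache.spark.SparkContext

// Hypothetical analogue of the enumeration above, plus a dispatch helper.
object ResourceKind extends Enumeration {
  type ResourceKind = Value
  val JAR, FILE, ARCHIVE = Value
}

object ResourceRegistration {
  def reRegister(sc: SparkContext, kind: ResourceKind.ResourceKind, uri: String): Unit =
    kind match {
      case ResourceKind.JAR => sc.addJar(uri)
      case ResourceKind.FILE => sc.addFile(uri)
      case ResourceKind.ARCHIVE => sc.addArchive(uri)
    }
}

Enumeration values are singletons, so they can be matched directly as stable identifiers in the case clauses.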

private[artifact] def copyBlock(
fromId: CacheId,
toId: CacheId,
@@ -223,9 +223,7 @@ abstract class StreamExecution(
// To fix call site like "run at <unknown>:0", we bridge the call site from the caller
// thread to this micro batch thread
sparkSession.sparkContext.setCallSite(callSite)
JobArtifactSet.withActiveJobArtifactState(jobArtifactState) {
runStream()
}
runStream()
}
}
