Skip to content

Commit

Permalink
[SPARK-49418][CONNECT][SQL] Shared Session Thread Locals
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR adds interfaces for SparkSession Thread Locals.

### Why are the changes needed?
We are creating a unified Spark SQL Scala interface. This is part of that effort.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48374 from hvanhovell/SPARK-49418.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
  • Loading branch information
hvanhovell committed Oct 9, 2024
1 parent 52538f0 commit b565a8d
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 214 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql
import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import java.util.concurrent.atomic.AtomicLong

import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
Expand Down Expand Up @@ -525,6 +525,8 @@ class SparkSession private[sql] (
}
}

override private[sql] def isUsable: Boolean = client.isSessionValid

implicit class RichColumn(c: Column) {
def expr: proto.Expression = toExpr(c)
def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e)
Expand All @@ -533,7 +535,9 @@ class SparkSession private[sql] (

// The minimal builder needed to create a spark session.
// TODO: implements all methods mentioned in the scaladoc of [[SparkSession]]
object SparkSession extends api.SparkSessionCompanion with Logging {
object SparkSession extends api.BaseSparkSessionCompanion with Logging {
override private[sql] type Session = SparkSession

private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
Expand All @@ -549,29 +553,6 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
override def load(c: Configuration): SparkSession = create(c)
})

/** The active SparkSession for the current thread. */
private val activeThreadSession = new InheritableThreadLocal[SparkSession]

/** Reference to the root SparkSession. */
private val defaultSession = new AtomicReference[SparkSession]

/**
* Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
* they are not set yet or the associated [[SparkConnectClient]] is unusable.
*/
private def setDefaultAndActiveSession(session: SparkSession): Unit = {
val currentDefault = defaultSession.getAcquire
if (currentDefault == null || !currentDefault.client.isSessionValid) {
// Update `defaultSession` if it is null or the contained session is not valid. There is a
// chance that the following `compareAndSet` fails if a new default session has just been set,
// but that does not matter since that event has happened after this method was invoked.
defaultSession.compareAndSet(currentDefault, session)
}
if (getActiveSession.isEmpty) {
setActiveSession(session)
}
}

/**
* Create a new Spark Connect server to connect locally.
*/
Expand Down Expand Up @@ -624,17 +605,6 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
new SparkSession(configuration.toSparkConnectClient, planIdGenerator)
}

/**
* Hook called when a session is closed.
*/
private[sql] def onSessionClose(session: SparkSession): Unit = {
sessions.invalidate(session.client.configuration)
defaultSession.compareAndSet(session, null)
if (getActiveSession.contains(session)) {
clearActiveSession()
}
}

/**
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
*
Expand Down Expand Up @@ -781,71 +751,12 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
}
}

/**
* Returns the default SparkSession. If the previously set default SparkSession becomes
* unusable, returns None.
*
* @since 3.5.0
*/
def getDefaultSession: Option[SparkSession] =
Option(defaultSession.get()).filter(_.client.isSessionValid)

/**
* Sets the default SparkSession.
*
* @since 3.5.0
*/
def setDefaultSession(session: SparkSession): Unit = {
defaultSession.set(session)
}

/**
* Clears the default SparkSession.
*
* @since 3.5.0
*/
def clearDefaultSession(): Unit = {
defaultSession.set(null)
}

/**
* Returns the active SparkSession for the current thread. If the previously set active
* SparkSession becomes unusable, returns None.
*
* @since 3.5.0
*/
def getActiveSession: Option[SparkSession] =
Option(activeThreadSession.get()).filter(_.client.isSessionValid)

/**
* Changes the SparkSession that will be returned in this thread and its children when
* SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
* an isolated SparkSession.
*
* @since 3.5.0
*/
def setActiveSession(session: SparkSession): Unit = {
activeThreadSession.set(session)
}
/** @inheritdoc */
override def getActiveSession: Option[SparkSession] = super.getActiveSession

/**
* Clears the active SparkSession for current thread.
*
* @since 3.5.0
*/
def clearActiveSession(): Unit = {
activeThreadSession.remove()
}
/** @inheritdoc */
override def getDefaultSession: Option[SparkSession] = super.getDefaultSession

/**
* Returns the currently active SparkSession, otherwise the default one. If there is no default
* SparkSession, throws an exception.
*
* @since 3.5.0
*/
def active: SparkSession = {
getActiveSession
.orElse(getDefaultSession)
.getOrElse(throw new IllegalStateException("No active or default Spark session found"))
}
/** @inheritdoc */
override def active: SparkSession = super.active
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.util.control.NonFatal

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor}

import org.apache.spark.SparkException
import org.apache.spark.sql.test.ConnectFunSuite
import org.apache.spark.util.SparkSerDeUtils

Expand Down Expand Up @@ -113,7 +114,7 @@ class SparkSessionSuite extends ConnectFunSuite {
SparkSession.clearActiveSession()
assert(SparkSession.getDefaultSession.isEmpty)
assert(SparkSession.getActiveSession.isEmpty)
intercept[IllegalStateException](SparkSession.active)
intercept[SparkException](SparkSession.active)

// Create a session
val session1 = SparkSession.builder().remote(connectionString1).getOrCreate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.SparkSession.baseRelationToDataFrame"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.canUseSession"),

// SparkSession#implicits
ProblemFilters.exclude[DirectMissingMethodProblem](
Expand Down
6 changes: 6 additions & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.javalang.typed"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed$"),

// SPARK-49418: Consolidate thread local handling in sql/api
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setActiveSession"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setDefaultSession"),
ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearActiveSession"),
ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearDefaultSession"),
) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++
loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++
loggingExcludes("org.apache.spark.sql.SparkSession#Builder")
Expand Down
Loading

0 comments on commit b565a8d

Please sign in to comment.