[SPARK-47986][CONNECT][FOLLOW-UP] Unable to create a new session when the default session is closed by the server

### What changes were proposed in this pull request?

This is a Scala port of apache#46221 and apache#46435.

A client is unaware of a server restart, or of the server having closed the session, until it receives an error. At that point, however, the client is unable to create a new session to the same Connect endpoint, since the stale session is still recorded as the active and default session.

With this change, when the server communicates via a gRPC error that the session has changed, the session and its client are marked as stale, so that a new default session can be created via the session builder.

In some cases, particularly when running older versions of the Spark cluster (3.5), the error instead manifests as a mismatch in the observed server-side session ID between calls. This fix also detects that mismatch and marks the session as stale in the same way.
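The mechanism described above can be illustrated with a minimal, self-contained sketch. It does not use Spark: `ToyClient` and `ToySessions` are hypothetical stand-ins for the Connect client and the session builder, and a plain `RuntimeException` stands in for the gRPC error. Only the error tag `[INVALID_HANDLE.SESSION_CHANGED]` and the "mark stale, then replace the default" idea are taken from this change.

```scala
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

// Hypothetical stand-in for a Connect client; not the real SparkConnectClient.
final class ToyClient {
  private val active = new AtomicBoolean(true)
  // Last server-side session ID seen; a plain var is enough for this single-threaded sketch.
  private var serverSideSessionId: Option[String] = None

  def isSessionValid: Boolean = active.get()

  // Evaluates `call`, which returns the server-side session ID of a response.
  // The client is marked stale if the call fails with the session-changed error,
  // or if the returned ID differs from the one seen before.
  def verify(call: => String): String = {
    val observed =
      try {
        call
      } catch {
        case e: RuntimeException
            if Option(e.getMessage).exists(_.contains("[INVALID_HANDLE.SESSION_CHANGED]")) =>
          active.set(false)
          throw e
      }
    serverSideSessionId match {
      case Some(known) if known != observed =>
        active.set(false)
        throw new IllegalStateException(
          s"Server side session ID changed from $known to $observed")
      case _ =>
        serverSideSessionId = Some(observed)
    }
    observed
  }
}

// Hypothetical stand-in for the default-session bookkeeping.
object ToySessions {
  private val default = new AtomicReference[ToyClient](null)

  // A stale default is hidden from callers.
  def getDefault: Option[ToyClient] = Option(default.get()).filter(_.isSessionValid)

  // Reuses the default while it is valid; otherwise swaps in a fresh client.
  def getOrCreate(): ToyClient = {
    val current = default.get()
    if (current != null && current.isSessionValid) {
      current
    } else {
      val fresh = new ToyClient
      // If another thread replaced the default in the meantime, retry.
      if (default.compareAndSet(current, fresh)) fresh else getOrCreate()
    }
  }
}
```

In this sketch, once `verify` has marked a client stale, `getDefault` returns `None` and the next `getOrCreate()` call returns a different, valid client, which is the behavior the commit aims for with `SparkSession.builder().getOrCreate()`.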

### Why are the changes needed?

Being unable to use getOrCreate() after an error is unacceptable and should be fixed.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

./build/sbt testOnly *SparkSessionE2ESuite

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#47008 from changgyoopark-db/SPARK-47986.

Authored-by: Changgyoo Park <changgyoo.park@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
changgyoopark-db authored and HyukjinKwon committed Jun 19, 2024
1 parent 5458763 commit 5d6e9dd
Showing 4 changed files with 105 additions and 9 deletions.
@@ -829,10 +829,16 @@ object SparkSession extends Logging {

   /**
    * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
-   * they are not set yet.
+   * they are not set yet or the associated [[SparkConnectClient]] is unusable.
    */
   private def setDefaultAndActiveSession(session: SparkSession): Unit = {
-    defaultSession.compareAndSet(null, session)
+    val currentDefault = defaultSession.getAcquire
+    if (currentDefault == null || !currentDefault.client.isSessionValid) {
+      // Update `defaultSession` if it is null or the contained session is not valid. There is a
+      // chance that the following `compareAndSet` fails if a new default session has just been set,
+      // but that does not matter since that event has happened after this method was invoked.
+      defaultSession.compareAndSet(currentDefault, session)
+    }
     if (getActiveSession.isEmpty) {
       setActiveSession(session)
     }
@@ -972,7 +978,7 @@ object SparkSession extends Logging {
     def appName(name: String): Builder = this

     private def tryCreateSessionFromClient(): Option[SparkSession] = {
-      if (client != null) {
+      if (client != null && client.isSessionValid) {
         Option(new SparkSession(client, planIdGenerator))
       } else {
         None
@@ -1024,19 +1030,30 @@ object SparkSession extends Logging {
      */
     def getOrCreate(): SparkSession = {
       val session = tryCreateSessionFromClient()
-        .getOrElse(sessions.get(builder.configuration))
+        .getOrElse({
+          var existingSession = sessions.get(builder.configuration)
+          if (!existingSession.client.isSessionValid) {
+            // If the cached session has become invalid, e.g., due to a server restart, the cache
+            // entry is invalidated.
+            sessions.invalidate(builder.configuration)
+            existingSession = sessions.get(builder.configuration)
+          }
+          existingSession
+        })
       setDefaultAndActiveSession(session)
       applyOptions(session)
       session
     }
   }

   /**
-   * Returns the default SparkSession.
+   * Returns the default SparkSession. If the previously set default SparkSession becomes
+   * unusable, returns None.
    *
    * @since 3.5.0
    */
-  def getDefaultSession: Option[SparkSession] = Option(defaultSession.get())
+  def getDefaultSession: Option[SparkSession] =
+    Option(defaultSession.get()).filter(_.client.isSessionValid)

   /**
    * Sets the default SparkSession.
@@ -1057,11 +1074,13 @@ object SparkSession extends Logging {
   }

   /**
-   * Returns the active SparkSession for the current thread.
+   * Returns the active SparkSession for the current thread. If the previously set active
+   * SparkSession becomes unusable, returns None.
    *
    * @since 3.5.0
    */
-  def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get())
+  def getActiveSession: Option[SparkSession] =
+    Option(activeThreadSession.get()).filter(_.client.isSessionValid)

   /**
    * Changes the SparkSession that will be returned in this thread and its children when

@@ -382,4 +382,43 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
         .create()
     }
   }
+
+  test("SPARK-47986: get or create after session changed") {
+    val remote = s"sc://localhost:$serverPort"
+
+    SparkSession.clearDefaultSession()
+    SparkSession.clearActiveSession()
+
+    val session1 = SparkSession
+      .builder()
+      .remote(remote)
+      .getOrCreate()
+
+    assert(session1 eq SparkSession.getActiveSession.get)
+    assert(session1 eq SparkSession.getDefaultSession.get)
+    assert(session1.range(3).collect().length == 3)
+
+    session1.client.hijackServerSideSessionIdForTesting("-testing")
+
+    val e = intercept[SparkException] {
+      session1.range(3).analyze
+    }
+
+    assert(e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]"))
+    assert(!session1.client.isSessionValid)
+    assert(SparkSession.getActiveSession.isEmpty)
+    assert(SparkSession.getDefaultSession.isEmpty)
+
+    val session2 = SparkSession
+      .builder()
+      .remote(remote)
+      .getOrCreate()
+
+    assert(session1 ne session2)
+    assert(session2.client.isSessionValid)
+    assert(session2 eq SparkSession.getActiveSession.get)
+    assert(session2 eq SparkSession.getDefaultSession.get)
+    assert(session2.range(3).collect().length == 3)
+  }
+
 }

@@ -16,7 +16,10 @@
  */
 package org.apache.spark.sql.connect.client

+import java.util.concurrent.atomic.AtomicBoolean
+
 import com.google.protobuf.GeneratedMessageV3
+import io.grpc.{Status, StatusRuntimeException}
 import io.grpc.stub.StreamObserver

 import org.apache.spark.internal.Logging
@@ -30,6 +33,12 @@ class ResponseValidator extends Logging {
   // do not use server-side streaming.
   private var serverSideSessionId: Option[String] = None

+  // Indicates whether the client and the client information on the server correspond to each other
+  // This flag being false means that the server has restarted and lost the client information, or
+  // there is a logic error in the code; both cases, the user should establish a new connection to
+  // the server. Access to the value has to be synchronized since it can be shared.
+  private val isSessionActive: AtomicBoolean = new AtomicBoolean(true)
+
   // Returns the server side session ID, used to send it back to the server in the follow-up
   // requests so the server can validate it session id against the previous requests.
   def getServerSideSessionId: Option[String] = serverSideSessionId
@@ -42,8 +51,25 @@ class ResponseValidator extends Logging {
     serverSideSessionId = Some(serverSideSessionId.getOrElse("") + suffix)
   }

+  /**
+   * Returns true if the session is valid on both the client and the server.
+   */
+  private[sql] def isSessionValid: Boolean = {
+    // An active session is considered valid.
+    isSessionActive.getAcquire
+  }
+
   def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = {
-    val response = fn
+    val response =
+      try {
+        fn
+      } catch {
+        case e: StatusRuntimeException
+            if e.getStatus.getCode == Status.Code.INTERNAL &&
+              e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]") =>
+          isSessionActive.setRelease(false)
+          throw e
+      }
     val field = response.getDescriptorForType.findFieldByName("server_side_session_id")
     // If the field does not exist, we ignore it. New / Old message might not contain it and this
     // behavior allows us to be compatible.
@@ -54,6 +80,7 @@ class ResponseValidator extends Logging {
         serverSideSessionId match {
           case Some(id) =>
             if (value != id) {
+              isSessionActive.setRelease(false)
               throw new IllegalStateException(
                 s"Server side session ID changed from $id to $value")
             }

@@ -71,6 +71,17 @@ private[sql] class SparkConnectClient(
     stubState.responseValidator.hijackServerSideSessionIdForTesting(suffix)
   }

+  /**
+   * Returns true if the session is valid on both the client and the server. A session becomes
+   * invalid if the server side information about the client, e.g., session ID, does not
+   * correspond to the actual client state.
+   */
+  private[sql] def isSessionValid: Boolean = {
+    // The last known state of the session is store in `responseValidator`, because it is where the
+    // client gets responses from the server.
+    stubState.responseValidator.isSessionValid
+  }
+
   private[sql] val artifactManager: ArtifactManager = {
     new ArtifactManager(configuration, sessionId, bstub, stub)
   }