Skip to content

Commit 320694c

Browse files
Raghu Angadi authored and HyukjinKwon committed
[SPARK-45245][PYTHON][CONNECT] PythonWorkerFactory: Timeout if worker does not connect back
### What changes were proposed in this pull request?

The `createSimpleWorker()` method in `PythonWorkerFactory` waits forever if the worker fails to connect back to the server. This is because it calls `accept()` on the server socket channel without an effective timeout: if the worker does not connect back, `accept()` waits forever. There was supposed to be a 10-second timeout, but it was not implemented correctly (it was set with `setSoTimeout()` on the channel's socket, which does not apply to `ServerSocketChannel.accept()`). This PR adds a working 10-second timeout.

### Why are the changes needed?

Otherwise the create method could be stuck forever.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

- Unit test
- Manual

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: ChatGPT 4.0
Asked ChatGPT to generate sample code to do a non-blocking accept() on a socket channel in Java.

Closes #43023 from rangadi/fix-py-accept.

Authored-by: Raghu Angadi <raghu.angadi@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 2ad79e1 commit 320694c

File tree

2 files changed

+73
-3
lines changed

2 files changed

+73
-3
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ package org.apache.spark.api.python
1919

2020
import java.io.{DataInputStream, DataOutputStream, EOFException, File, InputStream}
2121
import java.net.{InetAddress, InetSocketAddress, SocketException}
22+
import java.net.SocketTimeoutException
2223
import java.nio.channels._
2324
import java.util.Arrays
2425
import java.util.concurrent.TimeUnit
@@ -184,10 +185,18 @@ private[spark] class PythonWorkerFactory(
184185
redirectStreamsToStderr(workerProcess.getInputStream, workerProcess.getErrorStream)
185186

186187
// Wait for it to connect to our socket, and validate the auth secret.
187-
serverSocketChannel.socket().setSoTimeout(10000)
188-
189188
try {
190-
val socketChannel = serverSocketChannel.accept()
189+
// Wait up to 10 seconds for client to connect.
190+
serverSocketChannel.configureBlocking(false)
191+
val serverSelector = Selector.open()
192+
serverSocketChannel.register(serverSelector, SelectionKey.OP_ACCEPT)
193+
val socketChannel =
194+
if (serverSelector.select(10 * 1000) > 0) { // Wait up to 10 seconds.
195+
serverSocketChannel.accept()
196+
} else {
197+
throw new SocketTimeoutException(
198+
"Timed out while waiting for the Python worker to connect back")
199+
}
191200
authHelper.authClient(socketChannel.socket())
192201
val pid = workerProcess.toHandle.pid()
193202
if (pid < 0) {
Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.api.python
19+
20+
import java.net.SocketTimeoutException
21+
22+
// scalastyle:off executioncontextglobal
23+
import scala.concurrent.ExecutionContext.Implicits.global
24+
// scalastyle:on executioncontextglobal
25+
import scala.concurrent.Future
26+
import scala.concurrent.duration._
27+
28+
import org.scalatest.matchers.must.Matchers
29+
30+
import org.apache.spark.SharedSparkContext
31+
import org.apache.spark.SparkException
32+
import org.apache.spark.SparkFunSuite
33+
import org.apache.spark.util.ThreadUtils
34+
35+
/**
 * Tests for [[PythonWorkerFactory]].
 */
class PythonWorkerFactorySuite extends SparkFunSuite with Matchers with SharedSparkContext {

  test("createSimpleWorker() fails with a timeout error if worker does not connect back") {
    // It verifies that server side times out in accept(), if the worker does not connect back.
    // E.g. the worker might fail at the beginning before it tries to connect back.

    val workerFactory = new PythonWorkerFactory(
      "python3", "pyspark.testing.non_existing_worker_module", Map.empty
    )

    // Create the worker in a separate thread so that if there is a bug where it does not
    // return (accept() used to be blocking), the test doesn't hang for a long time.
    val createFuture = Future {
      val ex = intercept[SparkException] {
        workerFactory.createSimpleWorker(blockingMode = true) // blockingMode doesn't matter.
        // NOTE: This takes 10 seconds (which is the accept timeout in PythonWorkerFactory).
        // That makes this a bit longish test.
      }
      assert(ex.getMessage.contains("Python worker failed to connect back"))
      assert(ex.getCause.isInstanceOf[SocketTimeoutException])
    }

    // Timeout ensures that the test fails in 5 minutes if createSimpleWorker() doesn't return.
    ThreadUtils.awaitReady(createFuture, 5.minutes)

    // awaitReady() only waits for completion; it does not surface a failure of the future.
    // Rethrow the outcome explicitly, otherwise a failed `intercept` or `assert` inside the
    // future body would be silently dropped and this test would pass regardless.
    createFuture.value.foreach(_.get)
  }
}

0 commit comments

Comments
 (0)