
Commit 4f7963c

markhamstracsd-jenkins authored and committed
Branch 2.2 merge (#232)
* [SPARK-23816][CORE] Killed tasks should ignore FetchFailures.

SPARK-19276 ensured that FetchFailures do not get swallowed by other layers of exception handling, but it also meant that a killed task could look like a fetch failure. This is particularly a problem with speculative execution, where we expect to kill tasks as they are reading shuffle data. The fix is to ensure that we always check for killed tasks first (a hedged sketch of that ordering follows this commit message).

Added a new unit test which fails before the fix; ran it 1k times to check for flakiness. Full suite of tests on Jenkins.

Author: Imran Rashid <irashid@cloudera.com>
Closes apache#20987 from squito/SPARK-23816.
(cherry picked from commit 10f45bb)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>

* [SPARK-24007][SQL] EqualNullSafe for FloatType and DoubleType might generate a wrong result by codegen.

`EqualNullSafe` for `FloatType` and `DoubleType` might generate a wrong result by codegen.

```scala
scala> val df = Seq((Some(-1.0d), None), (None, Some(-1.0d))).toDF()
df: org.apache.spark.sql.DataFrame = [_1: double, _2: double]

scala> df.show()
+----+----+
|  _1|  _2|
+----+----+
|-1.0|null|
|null|-1.0|
+----+----+

scala> df.filter("_1 <=> _2").show()
+----+----+
|  _1|  _2|
+----+----+
|-1.0|null|
|null|-1.0|
+----+----+
```

The result should be empty, but two rows remain. Added a test.

Author: Takuya UESHIN <ueshin@databricks.com>
Closes apache#21094 from ueshin/issues/SPARK-24007/equalnullsafe.
(cherry picked from commit f09a9e9)
Signed-off-by: gatorsmile <gatorsmile@gmail.com>

* [SPARK-23963][SQL] Properly handle large number of columns in query on text-based Hive table

## What changes were proposed in this pull request?

TableReader would get disproportionately slower as the number of columns in the query increased. This PR fixes the way TableReader looks up metadata for each column in the row: previously it looked up this data in linked lists, accessing each linked list by an index (column number); now it looks it up in arrays, where indexing by column number works better (see the sketch after this commit message).

## How was this patch tested?

Manual testing, all sbt unit tests, python sql tests.

Author: Bruce Robbins <bersprockets@gmail.com>
Closes apache#21043 from bersprockets/tabreadfix.

* [MINOR][DOCS] Fix comments of SQLExecution#withExecutionId

## What changes were proposed in this pull request?

Fix comment: change `BroadcastHashJoin.broadcastFuture` to `BroadcastExchangeExec.relationFuture`:
https://github.com/apache/spark/blob/d28d5732ae205771f1f443b15b10e64dcffb5ff0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala#L66

## How was this patch tested?

N/A

Author: seancxmao <seancxmao@gmail.com>
Closes apache#21113 from seancxmao/SPARK-13136.
(cherry picked from commit c303b1b)
Signed-off-by: hyukjinkwon <gurwls223@apache.org>

* [SPARK-23941][MESOS] Mesos task failed on specific spark app name

## What changes were proposed in this pull request?

Shell-escape the name passed to spark-submit and change how conf attributes are shell-escaped.

## How was this patch tested?

Tested manually with Hive-on-Spark on Mesos, and with the use case described in the issue: the SparkPi application with a custom name that contains illegal shell characters. With this PR, Hive-on-Spark on Mesos works like a charm with Hive 3.0.0-SNAPSHOT.

I state that this contribution is my original work and that I license the work to the project under the project’s open source license.

Author: Bounkong Khamphousone <bounkong.khamphousone@ebiznext.com>
Closes apache#21014 from tiboun/fix/SPARK-23941.
(cherry picked from commit 6782359)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>

* [SPARK-23433][CORE] Late zombie task completions update all tasksets

A fetch failure can lead to multiple tasksets being active for a given stage. While there is only one "active" version of the taskset, the earlier attempts can still have running tasks, which can complete successfully. So a task completion needs to update every taskset so that each one knows the partition is completed. That way the final active taskset does not try to submit another task for the same partition, knows when it is completed, and knows when it should be marked as a "zombie".

Added a regression test.

Author: Imran Rashid <irashid@cloudera.com>
Closes apache#21131 from squito/SPARK-23433.
(cherry picked from commit 94641fe)
Signed-off-by: Imran Rashid <irashid@cloudera.com>

* [SPARK-23489][SQL][TEST][BRANCH-2.2] HiveExternalCatalogVersionsSuite should verify the downloaded file

## What changes were proposed in this pull request?

This is a backport of apache#21210 because `branch-2.2` also faces the same failures.

Although [SPARK-22654](https://issues.apache.org/jira/browse/SPARK-22654) made `HiveExternalCatalogVersionsSuite` download from Apache mirrors three times, it has been flaky because it didn't verify the downloaded file. Some Apache mirrors terminate the download abnormally, and the *corrupted* file produces the following errors:

```
gzip: stdin: not in gzip format
tar: Child returned status 1
tar: Error is not recoverable: exiting now
22:46:32.700 WARN org.apache.spark.sql.hive.HiveExternalCatalogVersionsSuite: ===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.hive.HiveExternalCatalogVersionsSuite, thread names: Keep-Alive-Timer =====
*** RUN ABORTED ***
java.io.IOException: Cannot run program "./bin/spark-submit" (in directory "/tmp/test-spark/spark-2.2.0"): error=2, No such file or directory
```

This has been reported in two confusingly different ways; for example, the above case is reported as Case 2, `no failures`.

- Case 1. [Test Result (1 failure / +1)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/4389/)
- Case 2. [Test Result (no failures)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.6/4811/)

This PR aims to make `HiveExternalCatalogVersionsSuite` more robust by verifying the downloaded `tgz` file: it is extracted and checked for the existence of `bin/spark-submit`. If the file turns out to be empty or corrupted, `HiveExternalCatalogVersionsSuite` retries, just as it does for a download failure (see the sketch after this commit message).

## How was this patch tested?

Pass the Jenkins.

Author: Dongjoon Hyun <dongjoon@apache.org>
Closes apache#21232 from dongjoon-hyun/SPARK-23489-2.

* [SPARK-23697][CORE] LegacyAccumulatorWrapper should define isZero correctly

## What changes were proposed in this pull request?

Accumulators from Spark 1.x may no longer work with Spark 2.x, because `LegacyAccumulatorWrapper.isZero` may return the wrong answer if the `AccumulableParam` doesn't define equals/hashCode. This PR fixes that by using a reference equality check in `LegacyAccumulatorWrapper.isZero` (see the sketch after this commit message).

## How was this patch tested?

A new test.

Author: Wenchen Fan <wenchen@databricks.com>
Closes apache#21229 from cloud-fan/accumulator.
(cherry picked from commit 4d5de4d)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>

* [SPARK-21278][PYSPARK] Upgrade to Py4J 0.10.6

This PR aims to bump Py4J in order to fix the following float/double bug. Py4J 0.10.5 fixes this (py4j/py4j#272) and the latest Py4J is 0.10.6.

**BEFORE**

```
>>> df = spark.range(1)
>>> df.select(df['id'] + 17.133574204226083).show()
+--------------------+
|(id + 17.1335742042)|
+--------------------+
|       17.1335742042|
+--------------------+
```

**AFTER**

```
>>> df = spark.range(1)
>>> df.select(df['id'] + 17.133574204226083).show()
+-------------------------+
|(id + 17.133574204226083)|
+-------------------------+
|       17.133574204226083|
+-------------------------+
```

Manual.

Author: Dongjoon Hyun <dongjoon@apache.org>
Closes apache#18546 from dongjoon-hyun/SPARK-21278.
(cherry picked from commit c8d0aba)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>

* [SPARK-16406][SQL] Improve performance of LogicalPlan.resolve

`LogicalPlan.resolve(...)` uses linear searches to find an attribute matching a name. This is fine in normal cases, but gets problematic when you try to resolve a large number of columns on a plan with a large number of attributes. This PR adds an indexing structure to `resolve(...)` in order to find potential matches quicker. It improves the reference resolution time for the following code by 4x (11.8s -> 2.4s):

```scala
val n = 4000
val values = (1 to n).map(_.toString).mkString(", ")
val columns = (1 to n).map("column" + _).mkString(", ")
val query =
  s"""
     |SELECT $columns
     |FROM VALUES ($values) T($columns)
     |WHERE 1=2 AND 1 IN ($columns)
     |GROUP BY $columns
     |ORDER BY $columns
     |""".stripMargin
spark.time(sql(query))
```

Existing tests.

Author: Herman van Hovell <hvanhovell@databricks.com>
Closes apache#14083 from hvanhovell/SPARK-16406.

* [PYSPARK] Update py4j to version 0.10.7.

(cherry picked from commit cc613b5)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
(cherry picked from commit 323dc3a)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>

* [SPARKR] Match pyspark features in SparkR communication protocol.

(cherry picked from commit 628c7b5)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
(cherry picked from commit 16cd9ac)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>

* Keep old-style messages for AnalysisException with ambiguous references
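The SPARK-23816 item above hinges on ordering: the kill check must run before the fetch-failure check. Below is a minimal, self-contained Scala sketch of that ordering, with assumed stand-in types rather than Spark's actual Executor/TaskRunner code.

```scala
// Hedged sketch only: FetchFailedException and TaskEndReason here are stand-ins,
// not the real Spark classes.
object KilledVsFetchFailedSketch {
  final case class FetchFailedException(msg: String) extends Exception(msg)

  sealed trait TaskEndReason
  final case class TaskKilled(reason: String) extends TaskEndReason
  case object FetchFailed extends TaskEndReason
  final case class ExceptionFailure(t: Throwable) extends TaskEndReason

  // Classify a task failure: the kill flag is consulted before the exception type,
  // so a speculative task killed while reading shuffle data is reported as killed,
  // not as a fetch failure.
  def classify(reasonIfKilled: Option[String], error: Throwable): TaskEndReason =
    (reasonIfKilled, error) match {
      case (Some(reason), _)               => TaskKilled(reason)
      case (None, _: FetchFailedException) => FetchFailed
      case (None, t)                       => ExceptionFailure(t)
    }
}
```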
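For the SPARK-23963 item, the gist is replacing repeated `LinkedList.get(i)` lookups (linear per access) with a one-time copy into an array. An illustrative sketch with hypothetical names, not the actual TableReader code:

```scala
import java.util.{LinkedList => JLinkedList}

object ColumnLookupSketch {
  // Copy per-column metadata out of a linked list once, so the per-row, per-column
  // lookups in the hot loop become O(1) array indexing instead of O(n) list traversal.
  def toLookupArray(columns: JLinkedList[AnyRef]): Array[AnyRef] = {
    val arr = new Array[AnyRef](columns.size())
    val it = columns.iterator()
    var i = 0
    while (it.hasNext) {
      arr(i) = it.next()
      i += 1
    }
    arr
  }
}
```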
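The SPARK-23489 item describes retrying when a downloaded archive does not actually contain `bin/spark-submit`. A hedged sketch of that retry idea, with assumed mirror URL layout and helper names rather than the suite's real code:

```scala
import java.io.File
import scala.sys.process._
import scala.util.Try

object VerifiedDownloadSketch {
  // Try each mirror in turn; only accept a download if the archive unpacks and
  // contains bin/spark-submit, so a truncated tgz triggers the next attempt.
  def downloadSpark(mirrors: Seq[String], version: String, destDir: File): Boolean =
    mirrors.exists { mirror =>
      val url = s"$mirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" // assumed layout
      val tgz = new File(destDir, s"spark-$version.tgz")
      val unpacked = Try {
        Seq("wget", "-q", url, "-O", tgz.getAbsolutePath).! == 0 &&
          Seq("tar", "-xzf", tgz.getAbsolutePath, "-C", destDir.getAbsolutePath).! == 0
      }.getOrElse(false)
      // The verification step: a clean wget/tar exit code alone is not trusted.
      unpacked && new File(destDir, s"spark-$version-bin-hadoop2.7/bin/spark-submit").isFile
    }
}
```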
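For the SPARK-23697 item, the point is that `isZero` must not depend on the value type's `equals`. A simplified sketch of the reference-equality idea, with assumed names rather than the real `LegacyAccumulatorWrapper`:

```scala
object LegacyAccumulatorSketch {
  trait AccumulableParam[R, T] {
    def zero(initialValue: R): R
    def addAccumulator(acc: R, term: T): R
  }

  class LegacyWrapper[R, T](param: AccumulableParam[R, T], initialValue: R) {
    private val zeroValue: R = param.zero(initialValue)
    private var value: R = zeroValue

    // Reference equality: true only while `value` is still the exact zero object
    // created above, which works even if R does not define equals/hashCode.
    def isZero: Boolean = value.asInstanceOf[AnyRef] eq zeroValue.asInstanceOf[AnyRef]

    def add(term: T): Unit = { value = param.addAccumulator(value, term) }
  }
}
```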
1 parent cdcacda · commit 4f7963c


49 files changed (+804, -164 lines)


LICENSE

Lines changed: 1 addition & 1 deletion
```diff
@@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
 (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf)
 (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net)
 (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net)
-(The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/)
+(The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/)
 (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/)
 (BSD licence) sbt and sbt-launch-lib.bash
 (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE)
```

R/pkg/R/client.R

Lines changed: 2 additions & 2 deletions
```diff
@@ -19,7 +19,7 @@
 
 # Creates a SparkR client connection object
 # if one doesn't already exist
-connectBackend <- function(hostname, port, timeout) {
+connectBackend <- function(hostname, port, timeout, authSecret) {
   if (exists(".sparkRcon", envir = .sparkREnv)) {
     if (isOpen(.sparkREnv[[".sparkRCon"]])) {
       cat("SparkRBackend client connection already exists\n")
@@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) {
 
   con <- socketConnection(host = hostname, port = port, server = FALSE,
                           blocking = TRUE, open = "wb", timeout = timeout)
-
+  doServerAuth(con, authSecret)
   assign(".sparkRCon", con, envir = .sparkREnv)
   con
 }
```

R/pkg/R/deserialize.R

Lines changed: 7 additions & 3 deletions
```diff
@@ -60,14 +60,18 @@ readTypedObject <- function(con, type) {
          stop(paste("Unsupported type for deserialization", type)))
 }
 
-readString <- function(con) {
-  stringLen <- readInt(con)
-  raw <- readBin(con, raw(), stringLen, endian = "big")
+readStringData <- function(con, len) {
+  raw <- readBin(con, raw(), len, endian = "big")
   string <- rawToChar(raw)
   Encoding(string) <- "UTF-8"
   string
 }
 
+readString <- function(con) {
+  stringLen <- readInt(con)
+  readStringData(con, stringLen)
+}
+
 readInt <- function(con) {
   readBin(con, integer(), n = 1, endian = "big")
 }
```

R/pkg/R/sparkR.R

Lines changed: 34 additions & 5 deletions
```diff
@@ -161,6 +161,10 @@ sparkR.sparkContext <- function(
                      " please use the --packages commandline instead", sep = ","))
     }
     backendPort <- existingPort
+    authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET")
+    if (nchar(authSecret) == 0) {
+      stop("Auth secret not provided in environment.")
+    }
   } else {
     path <- tempfile(pattern = "backend_port")
     submitOps <- getClientModeSparkSubmitOpts(
@@ -189,16 +193,27 @@ sparkR.sparkContext <- function(
     monitorPort <- readInt(f)
     rLibPath <- readString(f)
     connectionTimeout <- readInt(f)
+
+    # Don't use readString() so that we can provide a useful
+    # error message if the R and Java versions are mismatched.
+    authSecretLen = readInt(f)
+    if (length(authSecretLen) == 0 || authSecretLen == 0) {
+      stop("Unexpected EOF in JVM connection data. Mismatched versions?")
+    }
+    authSecret <- readStringData(f, authSecretLen)
     close(f)
     file.remove(path)
     if (length(backendPort) == 0 || backendPort == 0 ||
         length(monitorPort) == 0 || monitorPort == 0 ||
-        length(rLibPath) != 1) {
+        length(rLibPath) != 1 || length(authSecret) == 0) {
       stop("JVM failed to launch")
     }
-    assign(".monitorConn",
-           socketConnection(port = monitorPort, timeout = connectionTimeout),
-           envir = .sparkREnv)
+
+    monitorConn <- socketConnection(port = monitorPort, blocking = TRUE,
+                                    timeout = connectionTimeout, open = "wb")
+    doServerAuth(monitorConn, authSecret)
+
+    assign(".monitorConn", monitorConn, envir = .sparkREnv)
     assign(".backendLaunched", 1, envir = .sparkREnv)
     if (rLibPath != "") {
       assign(".libPath", rLibPath, envir = .sparkREnv)
@@ -208,7 +223,7 @@ sparkR.sparkContext <- function(
 
   .sparkREnv$backendPort <- backendPort
   tryCatch({
-    connectBackend("localhost", backendPort, timeout = connectionTimeout)
+    connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret)
   },
   error = function(err) {
     stop("Failed to connect JVM\n")
@@ -632,3 +647,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) {
     NULL
   }
 }
+
+# Utility function for sending auth data over a socket and checking the server's reply.
+doServerAuth <- function(con, authSecret) {
+  if (nchar(authSecret) == 0) {
+    stop("Auth secret not provided.")
+  }
+  writeString(con, authSecret)
+  flush(con)
+  reply <- readString(con)
+  if (reply != "ok") {
+    close(con)
+    stop("Unexpected reply from server.")
+  }
+}
```

R/pkg/inst/worker/daemon.R

Lines changed: 3 additions & 1 deletion
```diff
@@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR))
 
 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
 inputCon <- socketConnection(
-  port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
+  port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout)
+
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
 
 while (TRUE) {
   ready <- socketSelect(list(inputCon))
```

R/pkg/inst/worker/worker.R

Lines changed: 4 additions & 1 deletion
```diff
@@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR))
 
 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
 inputCon <- socketConnection(
-  port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
+  port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
+
 outputCon <- socketConnection(
   port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
 
 # read the index of the current partition inside the RDD
 partition <- SparkR:::readInt(inputCon)
```

bin/pyspark

Lines changed: 3 additions & 3 deletions
```diff
@@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh
 export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]"
 
 # In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option
-# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
+# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
 # to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver
 # (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython
 # and executor Python executables.
 
 # Fail noisily if removed options are set
 if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then
-  echo "Error in pyspark startup:"
+  echo "Error in pyspark startup:"
   echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead."
   exit 1
 fi
@@ -57,7 +57,7 @@ export PYSPARK_PYTHON
 
 # Add the PySpark classes to the Python path:
 export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
-export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH"
+export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH"
 
 # Load the PySpark shell.py script when ./pyspark is used interactively:
 export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
```

bin/pyspark2.cmd

Lines changed: 1 addition & 1 deletion
```diff
@@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
 )
 
 set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH%
-set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH%
+set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH%
 
 set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
 set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py
```

core/pom.xml

Lines changed: 1 addition & 1 deletion
```diff
@@ -335,7 +335,7 @@
     <dependency>
       <groupId>net.sf.py4j</groupId>
       <artifactId>py4j</artifactId>
-      <version>0.10.4</version>
+      <version>0.10.7</version>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
```

core/src/main/scala/org/apache/spark/SecurityManager.scala

Lines changed: 2 additions & 9 deletions
```diff
@@ -17,13 +17,11 @@
 
 package org.apache.spark
 
-import java.lang.{Byte => JByte}
 import java.net.{Authenticator, PasswordAuthentication}
-import java.security.{KeyStore, SecureRandom}
+import java.security.KeyStore
 import java.security.cert.X509Certificate
 import javax.net.ssl._
 
-import com.google.common.hash.HashCodes
 import com.google.common.io.Files
 import org.apache.hadoop.io.Text
 
@@ -435,12 +433,7 @@ private[spark] class SecurityManager(
       val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(SECRET_LOOKUP_KEY)
       if (secretKey == null || secretKey.length == 0) {
         logDebug("generateSecretKey: yarn mode, secret key from credentials is null")
-        val rnd = new SecureRandom()
-        val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE
-        val secret = new Array[Byte](length)
-        rnd.nextBytes(secret)
-
-        val cookie = HashCodes.fromBytes(secret).toString()
+        val cookie = Utils.createSecret(sparkConf)
         SparkHadoopUtil.get.addSecretKeyToUserCredentials(SECRET_LOOKUP_KEY, cookie)
         cookie
       } else {
```

core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala

Lines changed: 36 additions & 14 deletions
```diff
@@ -17,26 +17,39 @@
 
 package org.apache.spark.api.python
 
-import java.io.DataOutputStream
-import java.net.Socket
+import java.io.{DataOutputStream, File, FileOutputStream}
+import java.net.InetAddress
+import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.Files
 
 import py4j.GatewayServer
 
+import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.util.Utils
 
 /**
- * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port
- * back to its caller via a callback port specified by the caller.
+ * Process that starts a Py4J GatewayServer on an ephemeral port.
  *
  * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
  */
 private[spark] object PythonGatewayServer extends Logging {
   initializeLogIfNecessary(true)
 
-  def main(args: Array[String]): Unit = Utils.tryOrExit {
-    // Start a GatewayServer on an ephemeral port
-    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
+  def main(args: Array[String]): Unit = {
+    val secret = Utils.createSecret(new SparkConf())
+
+    // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured
+    // with the same secret, in case the app needs callbacks from the JVM to the underlying
+    // python processes.
+    val localhost = InetAddress.getLoopbackAddress()
+    val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
+      .authToken(secret)
+      .javaPort(0)
+      .javaAddress(localhost)
+      .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
+      .build()
+
     gatewayServer.start()
     val boundPort: Int = gatewayServer.getListeningPort
     if (boundPort == -1) {
@@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging {
       logDebug(s"Started PythonGatewayServer on port $boundPort")
     }
 
-    // Communicate the bound port back to the caller via the caller-specified callback port
-    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
-    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
-    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
-    val callbackSocket = new Socket(callbackHost, callbackPort)
-    val dos = new DataOutputStream(callbackSocket.getOutputStream)
+    // Communicate the connection information back to the python process by writing the
+    // information in the requested file. This needs to match the read side in java_gateway.py.
+    val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH"))
+    val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(),
+      "connection", ".info").toFile()
+
+    val dos = new DataOutputStream(new FileOutputStream(tmpPath))
     dos.writeInt(boundPort)
+
+    val secretBytes = secret.getBytes(UTF_8)
+    dos.writeInt(secretBytes.length)
+    dos.write(secretBytes, 0, secretBytes.length)
     dos.close()
-    callbackSocket.close()
+
+    if (!tmpPath.renameTo(connectionInfoPath)) {
+      logError(s"Unable to write connection information to $connectionInfoPath.")
+      System.exit(1)
+    }
 
     // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
     while (System.in.read() != -1) {
```

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 22 additions & 7 deletions
```diff
@@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util._
 
 
@@ -421,6 +422,12 @@ private[spark] object PythonRDD extends Logging {
   // remember the broadcasts sent to each worker
   private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
 
+  // Authentication helper used when serving iterator data.
+  private lazy val authHelper = {
+    val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+    new SocketAuthHelper(conf)
+  }
+
   def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
     synchronized {
       workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
@@ -443,12 +450,13 @@
    * (effectively a collect()), but allows you to run on a certain subset of partitions,
    * or to enable local execution.
    *
-   * @return the port number of a local socket which serves the data collected from this job.
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, and the secret for authentication.
    */
   def runJob(
       sc: SparkContext,
       rdd: JavaRDD[Array[Byte]],
-      partitions: JArrayList[Int]): Int = {
+      partitions: JArrayList[Int]): Array[Any] = {
     type ByteArray = Array[Byte]
     type UnrolledPartition = Array[ByteArray]
     val allPartitions: Array[UnrolledPartition] =
@@ -461,13 +469,14 @@
   /**
    * A helper function to collect an RDD as an iterator, then serve it via socket.
    *
-   * @return the port number of a local socket which serves the data collected from this job.
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, and the secret for authentication.
    */
-  def collectAndServe[T](rdd: RDD[T]): Int = {
+  def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
-  def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
+  def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
     serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
   }
 
@@ -698,8 +707,11 @@
    * and send them into this connection.
    *
    * The thread will terminate after all the data are sent or any exceptions happen.
+   *
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, and the secret for authentication.
    */
-  def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+  def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
     // Close the socket if no connection in 15 seconds
     serverSocket.setSoTimeout(15000)
@@ -709,11 +721,14 @@
       override def run() {
         try {
           val sock = serverSocket.accept()
+          authHelper.authClient(sock)
+
           val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
           Utils.tryWithSafeFinally {
             writeIteratorToStream(items, out)
           } {
             out.close()
+            sock.close()
           }
         } catch {
           case NonFatal(e) =>
@@ -724,7 +739,7 @@
       }
     }.start()
 
-    serverSocket.getLocalPort
+    Array(serverSocket.getLocalPort, authHelper.secret)
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
```

core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -32,7 +32,7 @@ private[spark] object PythonUtils {
     val pythonPath = new ArrayBuffer[String]
     for (sparkHome <- sys.env.get("SPARK_HOME")) {
       pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator)
-      pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.4-src.zip").mkString(File.separator)
+      pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator)
     }
     pythonPath ++= SparkContext.jarOfObject(this)
     pythonPath.mkString(File.pathSeparator)
```
