Skip to content

Commit 5a9befa

Browse files
committed
Disallow to create SparkContext in executors.
1 parent 42f01e3 commit 5a9befa

File tree

5 files changed

+50
-1
lines changed

5 files changed

+50
-1
lines changed

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ class SparkContext(config: SparkConf) extends Logging {
8383
// The call site where this SparkContext was constructed.
8484
private val creationSite: CallSite = Utils.getCallSite()
8585

86+
// In order to prevent SparkContext from being created in executors.
87+
SparkContext.assertOnDriver()
88+
8689
// In order to prevent multiple SparkContexts from being active at the same time, mark this
8790
// context as having started construction.
8891
// NOTE: this must be placed at the beginning of the SparkContext constructor.
@@ -2554,6 +2557,19 @@ object SparkContext extends Logging {
25542557
}
25552558
}
25562559

2560+
/**
2561+
* Called to ensure that SparkContext is created or accessed only on the Driver.
2562+
*
2563+
* Throws an exception if a SparkContext is about to be created in executors.
2564+
*/
2565+
private[spark] def assertOnDriver(): Unit = {
2566+
if (TaskContext.get != null) {
2567+
// we're accessing it during task execution, fail.
2568+
throw new IllegalStateException(
2569+
"SparkContext should only be created and accessed on the driver.")
2570+
}
2571+
}
2572+
25572573
/**
25582574
* This function may be used to get or instantiate a SparkContext and register it as a
25592575
* singleton object. Because we can only have one active SparkContext per JVM,

core/src/test/scala/org/apache/spark/SparkContextSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,18 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
934934
}
935935
}
936936
}
937+
938+
test("Disallow to create SparkContext in executors") {
939+
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]"))
940+
941+
val error = intercept[SparkException] {
942+
sc.range(0, 1).foreach { _ =>
943+
new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
944+
}
945+
}.getMessage()
946+
947+
assert(error.contains("SparkContext should only be created and accessed on the driver."))
948+
}
937949
}
938950

939951
object SparkContextSuite {

python/pyspark/context.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from pyspark.storagelevel import StorageLevel
3939
from pyspark.resource.information import ResourceInformation
4040
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
41+
from pyspark.taskcontext import TaskContext
4142
from pyspark.traceback_utils import CallSite, first_spark_call
4243
from pyspark.status import StatusTracker
4344
from pyspark.profiler import ProfilerCollector, BasicProfiler
@@ -118,6 +119,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
118119
...
119120
ValueError:...
120121
"""
122+
# In order to prevent SparkContext from being created in executors.
123+
SparkContext._assert_on_driver()
124+
121125
self._callsite = first_spark_call() or CallSite(None, None, None)
122126
if gateway is not None and gateway.gateway_parameters.auth_token is None:
123127
raise ValueError(
@@ -1145,6 +1149,16 @@ def resources(self):
11451149
resources[name] = ResourceInformation(name, addrs)
11461150
return resources
11471151

1152+
@staticmethod
1153+
def _assert_on_driver():
1154+
"""
1155+
Called to ensure that SparkContext is created only on the Driver.
1156+
1157+
Throws an exception if a SparkContext is about to be created in executors.
1158+
"""
1159+
if TaskContext.get() is not None:
1160+
raise Exception("SparkContext should only be created and accessed on the driver.")
1161+
11481162

11491163
def _test():
11501164
import atexit

python/pyspark/tests/test_context.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,13 @@ def test_resources(self):
267267
resources = sc.resources
268268
self.assertEqual(len(resources), 0)
269269

270+
def test_disallow_to_create_spark_context_in_executors(self):
271+
with SparkContext("local-cluster[3, 1, 1024]") as sc:
272+
with self.assertRaises(Exception) as context:
273+
sc.range(2).foreach(lambda _: SparkContext())
274+
self.assertIn("SparkContext should only be created and accessed on the driver.",
275+
str(context.exception))
276+
270277

271278
class ContextTestsWithResources(unittest.TestCase):
272279

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1087,7 +1087,7 @@ object SparkSession extends Logging {
10871087
}
10881088

10891089
private def assertOnDriver(): Unit = {
1090-
if (Utils.isTesting && TaskContext.get != null) {
1090+
if (TaskContext.get != null) {
10911091
// we're accessing it during task execution, fail.
10921092
throw new IllegalStateException(
10931093
"SparkSession should only be created and accessed on the driver.")

0 commit comments

Comments
 (0)