Commit 7b3bb13

Address review comments, cosmetic cleanups.
1 parent 10ba6e1 commit 7b3bb13

3 files changed: +23, -30 lines


python/pyspark/context.py

+2 -7
@@ -20,7 +20,6 @@
 import sys
 from threading import Lock
 from tempfile import NamedTemporaryFile
-from collections import namedtuple

 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -33,7 +32,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark import rdd
 from pyspark.rdd import RDD
-from pyspark.traceback_utils import extract_concise_traceback
+from pyspark.traceback_utils import CallSite, first_spark_call

 from py4j.java_collections import ListConverter

@@ -100,11 +99,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         ...
         ValueError:...
         """
-        if extract_concise_traceback() is not None:
-            self._callsite = extract_concise_traceback()
-        else:
-            tempNamedTuple = namedtuple("Callsite", "function file linenum")
-            self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
+        self._callsite = first_spark_call() or CallSite(None, None, None)
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
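The five-line if/else collapses into one expression because first_spark_call() now returns either a CallSite namedtuple or None, and CallSite(None, None, None) is a truthy 3-tuple, so Python's short-circuiting `or` only falls through to the default when the lookup fails. A minimal, self-contained sketch of the idiom (maybe_call_site is a hypothetical stand-in, for illustration only):

from collections import namedtuple

CallSite = namedtuple("CallSite", "function file linenum")

def maybe_call_site(found):
    # Hypothetical stand-in for first_spark_call(): a CallSite, or None.
    return CallSite("collect", "app.py", 42) if found else None

# `x or default` evaluates to `default` whenever x is None (or otherwise
# falsy), collapsing the old five-line if/else into a single expression.
print(maybe_call_site(True) or CallSite(None, None, None))
print(maybe_call_site(False) or CallSite(None, None, None))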

python/pyspark/rdd.py

+3 -3
@@ -43,7 +43,7 @@
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
     get_used_memory, ExternalSorter
-from pyspark.traceback_utils import JavaStackTrace
+from pyspark.traceback_utils import SCCallSiteSync

 from py4j.java_collections import ListConverter, MapConverter

@@ -652,7 +652,7 @@ def collect(self):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        with JavaStackTrace(self.context) as st:
+        with SCCallSiteSync(self.context) as css:
             bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))

@@ -1463,7 +1463,7 @@ def add_shuffle_key(split, iterator):

         keyed = self.mapPartitionsWithIndex(add_shuffle_key)
         keyed._bypass_serializer = True
-        with JavaStackTrace(self.context) as st:
+        with SCCallSiteSync(self.context) as css:
             pairRDD = self.ctx._jvm.PairwiseRDD(
                 keyed._jrdd.rdd()).asJavaPairRDD()
             partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
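Both hunks are pure renames around the same pattern: wrap the JVM-side call in a context manager so Scala reports the user's call site rather than PySpark internals. The `with` statement needs only the __enter__/__exit__ protocol defined in the third file; a toy sketch of that shape (ToyCallSiteSync is a stand-in, not the real class):

# Toy stand-in for SCCallSiteSync, showing only the context-manager
# protocol that `with SCCallSiteSync(self.context) as css:` relies on.
class ToyCallSiteSync(object):
    def __init__(self, call_site):
        self._call_site = call_site

    def __enter__(self):
        # The real class forwards this to the JVM via _jsc.setCallSite().
        print("setCallSite(%r)" % self._call_site)
        return self  # bound to `css`; the diff never actually uses the name

    def __exit__(self, exc_type, exc_value, tb):
        # Cleanup runs even if the wrapped Spark call raises.
        print("setCallSite(None)")

with ToyCallSiteSync("collect at app.py:42") as css:
    print("collect() would run here")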

python/pyspark/traceback_utils.py

+18 -20
@@ -20,16 +20,14 @@
 import traceback


-__all__ = ["extract_concise_traceback", "SparkContext"]
+CallSite = namedtuple("CallSite", "function file linenum")


-def extract_concise_traceback():
+def first_spark_call():
     """
-    This function returns the traceback info for a callsite, returns a dict
-    with function name, file name and line number
+    Return a CallSite representing the first Spark call in the current call stack.
     """
     tb = traceback.extract_stack()
-    callsite = namedtuple("Callsite", "function file linenum")
     if len(tb) == 0:
         return None
     file, line, module, what = tb[len(tb) - 1]
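Hoisting CallSite to module level gives context.py and this module one shared namedtuple type instead of an anonymous one re-created on every call. The function is built on traceback.extract_stack(), which returns (filename, line number, function name, source text) tuples with the outermost frame first; a small runnable sketch of that primitive (names here are illustrative, not from the patch):

import traceback

def whoami():
    # extract_stack() lists frames outermost-first; [-1] is this frame,
    # [-2] is whichever frame called whoami().
    file, line, fun, what = traceback.extract_stack()[-2]
    return "%s at %s:%s" % (fun, file, line)

def user_function():
    return whoami()

print(user_function())  # e.g. "user_function at sketch.py:10"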
@@ -42,39 +40,39 @@ def extract_concise_traceback():
             break
     if first_spark_frame == 0:
         file, line, fun, what = tb[0]
-        return callsite(function=fun, file=file, linenum=line)
+        return CallSite(function=fun, file=file, linenum=line)
     sfile, sline, sfun, swhat = tb[first_spark_frame]
     ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
-    return callsite(function=sfun, file=ufile, linenum=uline)
+    return CallSite(function=sfun, file=ufile, linenum=uline)


-class JavaStackTrace(object):
+class SCCallSiteSync(object):
     """
     Helper for setting the spark context call site.

     Example usage:
-    from pyspark.context import JavaStackTrace
-    with JavaStackTrace(<relevant SparkContext>) as st:
+    from pyspark.context import SCCallSiteSync
+    with SCCallSiteSync(<relevant SparkContext>) as css:
         <a Spark call>
     """

     _spark_stack_depth = 0

     def __init__(self, sc):
-        tb = extract_concise_traceback()
-        if tb is not None:
-            self._traceback = "%s at %s:%s" % (
-                tb.function, tb.file, tb.linenum)
+        call_site = first_spark_call()
+        if call_site is not None:
+            self._call_site = "%s at %s:%s" % (
+                call_site.function, call_site.file, call_site.linenum)
         else:
-            self._traceback = "Error! Could not extract traceback info"
+            self._call_site = "Error! Could not extract traceback info"
         self._context = sc

     def __enter__(self):
-        if JavaStackTrace._spark_stack_depth == 0:
-            self._context._jsc.setCallSite(self._traceback)
-        JavaStackTrace._spark_stack_depth += 1
+        if SCCallSiteSync._spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._call_site)
+        SCCallSiteSync._spark_stack_depth += 1

     def __exit__(self, type, value, tb):
-        JavaStackTrace._spark_stack_depth -= 1
-        if JavaStackTrace._spark_stack_depth == 0:
+        SCCallSiteSync._spark_stack_depth -= 1
+        if SCCallSiteSync._spark_stack_depth == 0:
             self._context._jsc.setCallSite(None)
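The _spark_stack_depth counter is what makes nesting safe: only the outermost `with` pushes a call site to the JVM, and only the outermost exit clears it, so Spark-internal calls never overwrite the user-facing call site. A runnable sketch of that bookkeeping (set_call_site is a hypothetical stand-in for self._context._jsc.setCallSite):

# Hypothetical stand-in for self._context._jsc.setCallSite(...).
def set_call_site(site):
    print("setCallSite(%r)" % site)

class DepthCounted(object):
    _depth = 0  # shared across instances, like _spark_stack_depth

    def __init__(self, site):
        self._site = site

    def __enter__(self):
        if DepthCounted._depth == 0:  # only the outermost entry wins
            set_call_site(self._site)
        DepthCounted._depth += 1

    def __exit__(self, type, value, tb):
        DepthCounted._depth -= 1
        if DepthCounted._depth == 0:  # only the outermost exit clears
            set_call_site(None)

with DepthCounted("outer at app.py:1"):
    with DepthCounted("inner at rdd.py:655"):  # nested: no JVM round-trip
        pass
# Output: setCallSite('outer at app.py:1'), then setCallSite(None).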
