20
20
import traceback
21
21
22
22
23
- __all__ = [ "extract_concise_traceback " , "SparkContext" ]
23
+ CallSite = namedtuple ( "CallSite " , "function file linenum" )
24
24
25
25
26
- def extract_concise_traceback ():
26
+ def first_spark_call ():
27
27
"""
28
- This function returns the traceback info for a callsite, returns a dict
29
- with function name, file name and line number
28
+ Return a CallSite representing the first Spark call in the current call stack.
30
29
"""
31
30
tb = traceback .extract_stack ()
32
- callsite = namedtuple ("Callsite" , "function file linenum" )
33
31
if len (tb ) == 0 :
34
32
return None
35
33
file , line , module , what = tb [len (tb ) - 1 ]
@@ -42,39 +40,39 @@ def extract_concise_traceback():
42
40
break
43
41
if first_spark_frame == 0 :
44
42
file , line , fun , what = tb [0 ]
45
- return callsite (function = fun , file = file , linenum = line )
43
+ return CallSite (function = fun , file = file , linenum = line )
46
44
sfile , sline , sfun , swhat = tb [first_spark_frame ]
47
45
ufile , uline , ufun , uwhat = tb [first_spark_frame - 1 ]
48
- return callsite (function = sfun , file = ufile , linenum = uline )
46
+ return CallSite (function = sfun , file = ufile , linenum = uline )
49
47
50
48
51
- class JavaStackTrace (object ):
49
+ class SCCallSiteSync (object ):
52
50
"""
53
51
Helper for setting the spark context call site.
54
52
55
53
Example usage:
56
- from pyspark.context import JavaStackTrace
57
- with JavaStackTrace (<relevant SparkContext>) as st :
54
+ from pyspark.context import SCCallSiteSync
55
+ with SCCallSiteSync (<relevant SparkContext>) as css :
58
56
<a Spark call>
59
57
"""
60
58
61
59
_spark_stack_depth = 0
62
60
63
61
def __init__ (self , sc ):
64
- tb = extract_concise_traceback ()
65
- if tb is not None :
66
- self ._traceback = "%s at %s:%s" % (
67
- tb .function , tb .file , tb .linenum )
62
+ call_site = first_spark_call ()
63
+ if call_site is not None :
64
+ self ._call_site = "%s at %s:%s" % (
65
+ call_site .function , call_site .file , call_site .linenum )
68
66
else :
69
- self ._traceback = "Error! Could not extract traceback info"
67
+ self ._call_site = "Error! Could not extract traceback info"
70
68
self ._context = sc
71
69
72
70
def __enter__ (self ):
73
- if JavaStackTrace ._spark_stack_depth == 0 :
74
- self ._context ._jsc .setCallSite (self ._traceback )
75
- JavaStackTrace ._spark_stack_depth += 1
71
+ if SCCallSiteSync ._spark_stack_depth == 0 :
72
+ self ._context ._jsc .setCallSite (self ._call_site )
73
+ SCCallSiteSync ._spark_stack_depth += 1
76
74
77
75
def __exit__ (self , type , value , tb ):
78
- JavaStackTrace ._spark_stack_depth -= 1
79
- if JavaStackTrace ._spark_stack_depth == 0 :
76
+ SCCallSiteSync ._spark_stack_depth -= 1
77
+ if SCCallSiteSync ._spark_stack_depth == 0 :
80
78
self ._context ._jsc .setCallSite (None )
0 commit comments