[SPARK-1087] Move python traceback utilities into new traceback_utils.py file. #2385

Closed · wants to merge 2 commits

8 changes: 2 additions & 6 deletions python/pyspark/context.py
@@ -20,7 +20,6 @@
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
from collections import namedtuple

from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -33,6 +32,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call

from py4j.java_collections import ListConverter

@@ -99,11 +99,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
...
ValueError:...
"""
if rdd._extract_concise_traceback() is not None:
self._callsite = rdd._extract_concise_traceback()
else:
tempNamedTuple = namedtuple("Callsite", "function file linenum")
self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
self._callsite = first_spark_call() or CallSite(None, None, None)
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
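
The replacement above collapses the old namedtuple fallback into one expression. A minimal sketch of that idiom, using only the standard library (the stub `first_spark_call` below is illustrative, not the real pyspark function): when the lookup finds no Spark frame it returns None, and the `or` then substitutes an empty CallSite. A populated namedtuple is always truthy, so the fallback fires only in the None case.

```python
from collections import namedtuple

CallSite = namedtuple("CallSite", "function file linenum")


def first_spark_call():
    # Stub standing in for pyspark.traceback_utils.first_spark_call;
    # pretend no Spark frame was found on the stack.
    return None


# Falls back to an empty CallSite whenever the lookup returns None.
callsite = first_spark_call() or CallSite(None, None, None)
print(callsite)  # CallSite(function=None, file=None, linenum=None)
```
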
58 changes: 3 additions & 55 deletions python/pyspark/rdd.py
@@ -18,13 +18,11 @@
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from collections import namedtuple
from itertools import chain, ifilter, imap
import operator
import os
import sys
import shlex
import traceback
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
@@ -45,6 +43,7 @@
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
get_used_memory, ExternalSorter
from pyspark.traceback_utils import SCCallSiteSync

from py4j.java_collections import ListConverter, MapConverter

@@ -81,57 +80,6 @@ def portable_hash(x):
return hash(x)


def _extract_concise_traceback():
"""
This function returns the traceback info for a callsite; it returns a
namedtuple with function name, file name and line number
"""
tb = traceback.extract_stack()
callsite = namedtuple("Callsite", "function file linenum")
if len(tb) == 0:
return None
file, line, module, what = tb[len(tb) - 1]
sparkpath = os.path.dirname(file)
first_spark_frame = len(tb) - 1
for i in range(0, len(tb)):
file, line, fun, what = tb[i]
if file.startswith(sparkpath):
first_spark_frame = i
break
if first_spark_frame == 0:
file, line, fun, what = tb[0]
return callsite(function=fun, file=file, linenum=line)
sfile, sline, sfun, swhat = tb[first_spark_frame]
ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
return callsite(function=sfun, file=ufile, linenum=uline)

_spark_stack_depth = 0


class _JavaStackTrace(object):

def __init__(self, sc):
tb = _extract_concise_traceback()
if tb is not None:
self._traceback = "%s at %s:%s" % (
tb.function, tb.file, tb.linenum)
else:
self._traceback = "Error! Could not extract traceback info"
self._context = sc

def __enter__(self):
global _spark_stack_depth
if _spark_stack_depth == 0:
self._context._jsc.setCallSite(self._traceback)
_spark_stack_depth += 1

def __exit__(self, type, value, tb):
global _spark_stack_depth
_spark_stack_depth -= 1
if _spark_stack_depth == 0:
self._context._jsc.setCallSite(None)


class BoundedFloat(float):
"""
Bounded value is generated by approximate job, with confidence and low
@@ -704,7 +652,7 @@ def collect(self):
"""
Return a list that contains all of the elements in this RDD.
"""
with _JavaStackTrace(self.context) as st:
with SCCallSiteSync(self.context) as css:
bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))

@@ -1515,7 +1463,7 @@ def add_shuffle_key(split, iterator):

keyed = self.mapPartitionsWithIndex(add_shuffle_key)
keyed._bypass_serializer = True
with _JavaStackTrace(self.context) as st:
with SCCallSiteSync(self.context) as css:
pairRDD = self.ctx._jvm.PairwiseRDD(
keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
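
For context on what the moved logic computes: it reports the user frame that invoked the first pyspark function, not the frame inside pyspark itself. A hedged illustration, assuming a user script named user_script.py (the filename and line number are invented for the example; the real `file` value is a full path):

```python
# user_script.py -- assumed filename for this illustration
from pyspark import SparkContext

sc = SparkContext("local", "callsite-demo")  # say this is line 4

# SparkContext.__init__ records first_spark_call(). The first stack frame whose
# file lives under the pyspark package directory is __init__ itself, so the
# CallSite takes that function name but the file and line of the user frame one
# level above it, giving roughly:
#   CallSite(function='__init__', file='user_script.py', linenum=4)
print(sc._callsite)
```
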
78 changes: 78 additions & 0 deletions python/pyspark/traceback_utils.py
@@ -0,0 +1,78 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import namedtuple
import os
import traceback


CallSite = namedtuple("CallSite", "function file linenum")


def first_spark_call():
"""
Return a CallSite representing the first Spark call in the current call stack.
"""
tb = traceback.extract_stack()
if len(tb) == 0:
return None
file, line, module, what = tb[len(tb) - 1]
sparkpath = os.path.dirname(file)
first_spark_frame = len(tb) - 1
for i in range(0, len(tb)):
file, line, fun, what = tb[i]
if file.startswith(sparkpath):
first_spark_frame = i
break
if first_spark_frame == 0:
file, line, fun, what = tb[0]
return CallSite(function=fun, file=file, linenum=line)
sfile, sline, sfun, swhat = tb[first_spark_frame]
ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
return CallSite(function=sfun, file=ufile, linenum=uline)


class SCCallSiteSync(object):
"""
Helper for setting the spark context call site.

Example usage:
from pyspark.context import SCCallSiteSync
with SCCallSiteSync(<relevant SparkContext>) as css:
<a Spark call>
"""

_spark_stack_depth = 0

def __init__(self, sc):
call_site = first_spark_call()
if call_site is not None:
self._call_site = "%s at %s:%s" % (
call_site.function, call_site.file, call_site.linenum)
else:
self._call_site = "Error! Could not extract traceback info"
self._context = sc

def __enter__(self):
if SCCallSiteSync._spark_stack_depth == 0:
self._context._jsc.setCallSite(self._call_site)
SCCallSiteSync._spark_stack_depth += 1

def __exit__(self, type, value, tb):
SCCallSiteSync._spark_stack_depth -= 1
if SCCallSiteSync._spark_stack_depth == 0:
self._context._jsc.setCallSite(None)
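
The class-level `_spark_stack_depth` counter makes the context manager reentrant: only the outermost `__enter__` sets the JVM call site and only the matching outermost `__exit__` clears it, so nested Spark calls keep the call site of the user-facing operation. A self-contained sketch against a fake context object (the Fake* classes are invented for the demo and provide only the one attribute SCCallSiteSync touches):

```python
from pyspark.traceback_utils import SCCallSiteSync


class FakeJsc(object):
    """Stands in for the Py4J JavaSparkContext; it just records calls."""
    def __init__(self):
        self.calls = []

    def setCallSite(self, site):
        self.calls.append(site)


class FakeContext(object):
    def __init__(self):
        self._jsc = FakeJsc()


sc = FakeContext()
with SCCallSiteSync(sc):      # outermost: sets the call site
    with SCCallSiteSync(sc):  # nested: depth > 0, so this is a no-op
        pass

# Exactly two recorded calls: the call-site string on entering the outermost
# block, then None when that block exits.
print(sc._jsc.calls)
```

Because the counter lives on the class rather than on each instance, the guard also holds when different RDD methods nest through one another.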