[SPARK-3463] [PySpark] aggregate and show spilled bytes in Python #2336

Closed · wants to merge 5 commits
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -115,6 +115,10 @@ private[spark] class PythonRDD(
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
context.taskMetrics.diskBytesSpilled += diskBytesSpilled
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
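The hunk above is the JVM half of a small wire protocol: right after the timing line, the Python worker must write exactly two big-endian 8-byte longs, in the same order the JVM reads them with stream.readLong(). Below is a minimal sketch of the worker half, assuming write_long packs values with struct.pack("!q", ...) as pyspark.serializers does; report_spilled is a hypothetical helper for illustration, not part of this patch:

import struct

def write_long(value, stream):
    # mirrors pyspark.serializers.write_long: big-endian signed 64-bit
    stream.write(struct.pack("!q", value))

def report_spilled(outfile, memory_bytes_spilled, disk_bytes_spilled):
    # hypothetical helper: the write order must match the two
    # readLong() calls in PythonRDD above
    write_long(memory_bytes_spilled, outfile)
    write_long(disk_bytes_spilled, outfile)
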
19 changes: 16 additions & 3 deletions python/pyspark/shuffle.py
@@ -68,6 +68,11 @@ def _get_local_dirs(sub):
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]


# global stats
MemoryBytesSpilled = 0L
DiskBytesSpilled = 0L


Contributor:
Do these have to be global properties? I'm not familiar with this part of the Python code, but in the Scala version they're part of TaskContext. Is there not an equivalent here?

Contributor Author:
In Python there is no TaskContext or anything like it; I haven't found a better way to do this right now.

Contributor:
This will need to be fixed if we reuse Python workers, but it should be okay for now.

Contributor:
I'd like to merge the worker re-use patch soon, so we should fix this.

Contributor Author:
These two counters are already cleared before each task runs, so this will also work with reused Python workers.
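A minimal sketch of the pattern being discussed: module-level counters that are reset before each task, so a reused worker starts every task at zero. The module name spill_counters and the run_task wrapper are hypothetical; the actual patch keeps the counters in pyspark.shuffle and resets them in worker.main (see worker.py below).

# spill_counters.py: module-level state, shared within one worker process
MemoryBytesSpilled = 0
DiskBytesSpilled = 0

# worker loop (sketch)
import spill_counters

def run_task(task):
    # reset before each task so counts don't leak across reused workers
    spill_counters.MemoryBytesSpilled = 0
    spill_counters.DiskBytesSpilled = 0
    result = task()
    # whatever is in the counters now was spilled by this task alone
    return result, spill_counters.MemoryBytesSpilled, spill_counters.DiskBytesSpilled
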

class Aggregator(object):

"""
@@ -313,10 +318,12 @@ def _spill(self):

It will dump the data in batch for better performance.
"""
global MemoryBytesSpilled, DiskBytesSpilled
path = self._get_spill_dir(self.spills)
if not os.path.exists(path):
os.makedirs(path)

used_memory = get_used_memory()
if not self.pdata:
# The data has not been partitioned, it will iterator the
# dataset once, write them into different files, has no
@@ -334,6 +341,7 @@ def _spill(self):
self.serializer.dump_stream([(k, v)], streams[h])

for s in streams:
DiskBytesSpilled += s.tell()
s.close()

self.data.clear()
@@ -346,9 +354,11 @@ def _spill(self):
# dump items in batch
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
self.pdata[i].clear()
DiskBytesSpilled += os.path.getsize(p)

self.spills += 1
gc.collect() # release the memory as much as possible
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20

def iteritems(self):
""" Return all merged items as iterator """
@@ -462,7 +472,6 @@ def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
self._spilled_bytes = 0

def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -476,6 +485,7 @@ def sorted(self, iterator, key=None, reverse=False):
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
batch = 10
chunks, current_chunk = [], []
iterator = iter(iterator)
@@ -486,15 +496,18 @@ def sorted(self, iterator, key=None, reverse=False):
if len(chunk) < batch:
break

if get_used_memory() > self.memory_limit:
used_memory = get_used_memory()
if used_memory > self.memory_limit:
# sort them inplace will save memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'w') as f:
self.serializer.dump_stream(current_chunk, f)
self._spilled_bytes += os.path.getsize(path)
chunks.append(self.serializer.load_stream(open(path)))
current_chunk = []
gc.collect()
Contributor:
Why do we garbage collect here?

Contributor Author:
It tries to reclaim as much memory as possible, so that later sorts and spills can process as much data as possible. @mateiz did some benchmarking of ExternalMerger; this can improve performance in some cases (such as a list of ints).

Also, after gc.collect(), get_used_memory() will be more accurate.
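For reference, the accounting around this gc.collect() works as sketched below (a sketch, not the patch itself; get_used_memory in pyspark.shuffle reports resident memory in MB, which is why the freed delta is shifted left by 20 bits to convert MB to bytes):

import gc

def measure_spill(spill_fn, get_used_memory):
    before_mb = get_used_memory()  # sample before dumping to disk
    spill_fn()                     # write in-memory data to spill files
    gc.collect()                   # free dumped objects so the next reading
                                   # reflects how much memory was released
    freed_mb = before_mb - get_used_memory()
    return max(freed_mb, 0) << 20  # MB -> bytes (clamp is a sketch choice)

The patch itself adds the raw delta to MemoryBytesSpilled right after each spill, as the diff line below shows.
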

MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
DiskBytesSpilled += os.path.getsize(path)

elif not chunks:
batch = min(batch * 2, 10000)
15 changes: 8 additions & 7 deletions python/pyspark/tests.py
@@ -44,6 +44,7 @@
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType
from pyspark import shuffle

_have_scipy = False
_have_numpy = False
@@ -136,17 +137,17 @@ def test_external_sort(self):
random.shuffle(l)
sorter = ExternalSorter(1)
self.assertEquals(sorted(l), list(sorter.sorted(l)))
self.assertGreater(sorter._spilled_bytes, 0)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, 0)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
self.assertGreater(sorter._spilled_bytes, last)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, last)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
self.assertGreater(sorter._spilled_bytes, last)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, last)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
self.assertGreater(sorter._spilled_bytes, last)
self.assertGreater(shuffle.DiskBytesSpilled, last)

def test_external_sort_in_rdd(self):
conf = SparkConf().set("spark.python.worker.memory", "1m")
14 changes: 10 additions & 4 deletions python/pyspark/worker.py
@@ -23,16 +23,14 @@
import time
import socket
import traceback
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.

from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
Contributor:
A few lines prior to this, there was a comment:

# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.

If this import is no longer necessary (was it ever?), then we should delete that comment, too.

Contributor Author:
cloudpickle is imported by serializers, so it's not needed here. The comments have been removed.

write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
CompressedSerializer

from pyspark import shuffle

pickleSer = PickleSerializer()
utf8_deserializer = UTF8Deserializer()
@@ -52,6 +50,11 @@ def main(infile, outfile):
if split_index == -1: # for unit tests
return

# initialize global state
shuffle.MemoryBytesSpilled = 0
shuffle.DiskBytesSpilled = 0
_accumulatorRegistry.clear()

# fetch name of workdir
spark_files_dir = utf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
@@ -92,6 +95,9 @@ def main(infile, outfile):
exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
write_long(shuffle.MemoryBytesSpilled, outfile)
write_long(shuffle.DiskBytesSpilled, outfile)

# Mark the beginning of the accumulators section of the output
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
write_int(len(_accumulatorRegistry), outfile)