[SPARK-3463] [PySpark] aggregate and show spilled bytes in Python #2336
python/pyspark/shuffle.py
@@ -68,6 +68,11 @@ def _get_local_dirs(sub):
     return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]


+# global stats
+MemoryBytesSpilled = 0L
+DiskBytesSpilled = 0L
+
+
 class Aggregator(object):

     """
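The hunks below declare these counters with `global` before assigning to them; without that, Python would bind new function-local names instead of updating the module-level stats. A minimal sketch of the pattern, with a hypothetical record_spill helper that is not part of the PR (0L in the diff is simply a Python 2 long literal):

    import os

    # module-level counters, mirroring the PR
    MemoryBytesSpilled = 0
    DiskBytesSpilled = 0

    def record_spill(path, freed_bytes):
        # Assigning to module globals inside a function requires a `global`
        # declaration; otherwise these would become local variables.
        global MemoryBytesSpilled, DiskBytesSpilled
        DiskBytesSpilled += os.path.getsize(path)
        MemoryBytesSpilled += freed_bytes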
@@ -313,10 +318,12 @@ def _spill(self):

         It will dump the data in batch for better performance.
         """
+        global MemoryBytesSpilled, DiskBytesSpilled
         path = self._get_spill_dir(self.spills)
         if not os.path.exists(path):
             os.makedirs(path)

+        used_memory = get_used_memory()
         if not self.pdata:
             # The data has not been partitioned, it will iterator the
             # dataset once, write them into different files, has no
@@ -334,6 +341,7 @@ def _spill(self):
                 self.serializer.dump_stream([(k, v)], streams[h])

             for s in streams:
+                DiskBytesSpilled += s.tell()
                 s.close()

             self.data.clear()
@@ -346,9 +354,11 @@ def _spill(self):
                     # dump items in batch
                     self.serializer.dump_stream(self.pdata[i].iteritems(), f)
                 self.pdata[i].clear()
+                DiskBytesSpilled += os.path.getsize(p)

         self.spills += 1
         gc.collect()  # release the memory as much as possible
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20

     def iteritems(self):
         """ Return all merged items as iterator """
@@ -462,7 +472,6 @@ def __init__(self, memory_limit, serializer=None):
         self.memory_limit = memory_limit
         self.local_dirs = _get_local_dirs("sort")
         self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
-        self._spilled_bytes = 0

     def _get_path(self, n):
         """ Choose one directory for spill by number n """
@@ -476,6 +485,7 @@ def sorted(self, iterator, key=None, reverse=False):
         Sort the elements in iterator, do external sort when the memory
         goes above the limit.
         """
+        global MemoryBytesSpilled, DiskBytesSpilled
         batch = 10
         chunks, current_chunk = [], []
         iterator = iter(iterator)
@@ -486,15 +496,18 @@ def sorted(self, iterator, key=None, reverse=False):
             if len(chunk) < batch:
                 break

-            if get_used_memory() > self.memory_limit:
+            used_memory = get_used_memory()
+            if used_memory > self.memory_limit:
                 # sort them inplace will save memory
                 current_chunk.sort(key=key, reverse=reverse)
                 path = self._get_path(len(chunks))
                 with open(path, 'w') as f:
                     self.serializer.dump_stream(current_chunk, f)
-                self._spilled_bytes += os.path.getsize(path)
                 chunks.append(self.serializer.load_stream(open(path)))
                 current_chunk = []
                 gc.collect()
+                MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+                DiskBytesSpilled += os.path.getsize(path)

             elif not chunks:
                 batch = min(batch * 2, 10000)

Review comment on the gc.collect() call: Why do we garbage collect here?

Author's reply: To reclaim as much memory as possible, so that later sorts and spills can process as much data as possible. @mateiz had done some benchmarks for ExternalMerger; this can improve performance in some cases (such as a list of ints). Also, after gc.collect(), get_used_memory() will be more accurate.
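A rough, self-contained sketch of the accounting the new lines perform, assuming psutil is available for reading process memory (PySpark prefers psutil and falls back to the resource module); spill_chunk and its arguments are illustrative only, not PySpark API:

    import gc
    import os
    import psutil

    MemoryBytesSpilled = 0
    DiskBytesSpilled = 0

    def get_used_memory():
        # Current resident set size of this process, in MB.
        return psutil.Process(os.getpid()).memory_info().rss >> 20

    def spill_chunk(chunk, path, serializer):
        global MemoryBytesSpilled, DiskBytesSpilled
        used_memory = get_used_memory()        # MB before the spill
        with open(path, 'wb') as f:
            serializer.dump_stream(chunk, f)
        del chunk[:]                           # drop the in-memory copy
        gc.collect()                           # reclaim it before re-measuring
        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20   # MB -> bytes
        DiskBytesSpilled += os.path.getsize(path)

The shift by 20 converts the megabyte delta reported by get_used_memory() into bytes, the unit the MemoryBytesSpilled counter is expected to carry.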
python/pyspark/worker.py
@@ -23,16 +23,14 @@
 import time
 import socket
 import traceback
-# CloudPickler needs to be imported so that depicklers are registered using the
-# copy_reg module.
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     CompressedSerializer
+from pyspark import shuffle

 pickleSer = PickleSerializer()
 utf8_deserializer = UTF8Deserializer()

Review comment on the removed import: A few lines prior to this, there was a comment explaining that CloudPickler needs to be imported so that depicklers are registered via copy_reg. If this import is no longer necessary (was it ever?), then we should delete that comment, too.

Author's reply: cloudpickle is imported by serializers, so it's not needed here. The comments are removed.
@@ -52,6 +50,11 @@ def main(infile, outfile):
         if split_index == -1:  # for unit tests
             return

+        # initialize global state
+        shuffle.MemoryBytesSpilled = 0
+        shuffle.DiskBytesSpilled = 0
+        _accumulatorRegistry.clear()
+
         # fetch name of workdir
         spark_files_dir = utf8_deserializer.loads(infile)
         SparkFiles._root_directory = spark_files_dir
@@ -92,6 +95,9 @@ def main(infile, outfile):
         exit(-1)
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
+    write_long(shuffle.MemoryBytesSpilled, outfile)
+    write_long(shuffle.DiskBytesSpilled, outfile)
+
     # Mark the beginning of the accumulators section of the output
     write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
     write_int(len(_accumulatorRegistry), outfile)
Review comment on the new globals: Do these have to be global properties? I'm not familiar with this part of the Python code, but in the Scala version they're part of TaskContext. Is there not an equivalent here?

Author's reply: In Python there is no TaskContext or anything like that; I have not found a better way to do it right now.

Reviewer: This will need to be fixed if we reuse Python workers, but it should be okay for now.

Reviewer: I'd like to merge the worker re-use patch soon, so we should fix this.

Author's reply: I already clear these two before running a task, so it will also work with reused Python workers.
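A simplified sketch of the lifecycle described in that last reply, showing why resetting the module-level counters at the start of main() keeps the metrics correct even when a Python worker is reused across tasks; reset_metrics and report_metrics are illustrative helpers, not functions in worker.py:

    import struct

    from pyspark import shuffle

    def reset_metrics():
        # Run at the top of each task, so a reused worker never carries
        # spill counts over from the previous task.
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0

    def report_metrics(outfile):
        # Run after the task body; write_long in pyspark.serializers packs a
        # big-endian signed 64-bit integer, equivalent to struct.pack("!q", ...).
        outfile.write(struct.pack("!q", shuffle.MemoryBytesSpilled))
        outfile.write(struct.pack("!q", shuffle.DiskBytesSpilled))
        outfile.flush()

With this pattern each task's metrics start from a clean slate and are streamed back to the JVM right after the task output, before the accumulator section.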