Commit 4e3fbe8

davies authored and JoshRosen committed
[SPARK-3463] [PySpark] aggregate and show spilled bytes in Python
Aggregate the number of bytes spilled to disk during aggregation or sorting, and show them in the Web UI.

![spilled](https://cloud.githubusercontent.com/assets/40902/4209758/4b995562-386d-11e4-97c1-8e838ee1d4e3.png)

This patch is blocked by SPARK-3465 (it includes a fix for that).

Author: Davies Liu <davies.liu@gmail.com>

Closes #2336 from davies/metrics and squashes the following commits:

e37df38 [Davies Liu] remove outdated comments
1245eb7 [Davies Liu] remove the temporary fix
ebd2f43 [Davies Liu] Merge branch 'master' into metrics
7e4ad04 [Davies Liu] Merge branch 'master' into metrics
fbe9029 [Davies Liu] show spilled bytes in Python in web ui
1 parent 2aea0da commit 4e3fbe8

File tree

4 files changed: +38 -14 lines changed


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 0 deletions
@@ -124,6 +124,10 @@ private[spark] class PythonRDD(
             val total = finishTime - startTime
             logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
               init, finish))
+            val memoryBytesSpilled = stream.readLong()
+            val diskBytesSpilled = stream.readLong()
+            context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
+            context.taskMetrics.diskBytesSpilled += diskBytesSpilled
             read()
           case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
             // Signals that an exception has been thrown in python
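
The JVM side is the consumer end of a tiny protocol extension: right after the timing data for a finished task, PythonRDD reads two extra longs from the worker's output stream and folds them into the task's TaskMetrics, which is what the stage page of the Web UI aggregates. A minimal sketch of just that byte-level handoff, assuming write_long in pyspark.serializers frames each value as a big-endian signed 64-bit integer (the layout the JVM's readLong consumes) and using an in-memory buffer in place of the worker socket:

import io
import struct

def write_long(value, stream):
    # Assumed framing of pyspark.serializers.write_long: big-endian signed
    # 64-bit, which is what java.io.DataInputStream.readLong expects.
    stream.write(struct.pack("!q", value))

def read_long(stream):
    # What stream.readLong() sees on the JVM side in PythonRDD.
    return struct.unpack("!q", stream.read(8))[0]

# Worker side: report the two running totals once the task has finished.
sock = io.BytesIO()                 # stand-in for the worker's output socket
write_long(3 << 20, sock)           # MemoryBytesSpilled (pretend: 3 MB)
write_long(7 << 20, sock)           # DiskBytesSpilled   (pretend: 7 MB)

# JVM side: read them back in the same order and add them to TaskMetrics.
sock.seek(0)
memory_bytes_spilled = read_long(sock)
disk_bytes_spilled = read_long(sock)
print(memory_bytes_spilled, disk_bytes_spilled)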

python/pyspark/shuffle.py

Lines changed: 16 additions & 3 deletions
@@ -68,6 +68,11 @@ def _get_local_dirs(sub):
     return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]


+# global stats
+MemoryBytesSpilled = 0L
+DiskBytesSpilled = 0L
+
+
 class Aggregator(object):

     """
@@ -313,10 +318,12 @@ def _spill(self):

         It will dump the data in batch for better performance.
         """
+        global MemoryBytesSpilled, DiskBytesSpilled
         path = self._get_spill_dir(self.spills)
         if not os.path.exists(path):
             os.makedirs(path)

+        used_memory = get_used_memory()
         if not self.pdata:
             # The data has not been partitioned, it will iterator the
             # dataset once, write them into different files, has no
@@ -334,6 +341,7 @@ def _spill(self):
                 self.serializer.dump_stream([(k, v)], streams[h])

             for s in streams:
+                DiskBytesSpilled += s.tell()
                 s.close()

             self.data.clear()
@@ -346,9 +354,11 @@ def _spill(self):
                     # dump items in batch
                     self.serializer.dump_stream(self.pdata[i].iteritems(), f)
                 self.pdata[i].clear()
+                DiskBytesSpilled += os.path.getsize(p)

         self.spills += 1
         gc.collect()  # release the memory as much as possible
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20

     def iteritems(self):
         """ Return all merged items as iterator """
@@ -462,7 +472,6 @@ def __init__(self, memory_limit, serializer=None):
         self.memory_limit = memory_limit
         self.local_dirs = _get_local_dirs("sort")
         self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
-        self._spilled_bytes = 0

     def _get_path(self, n):
         """ Choose one directory for spill by number n """
@@ -476,6 +485,7 @@ def sorted(self, iterator, key=None, reverse=False):
         Sort the elements in iterator, do external sort when the memory
         goes above the limit.
         """
+        global MemoryBytesSpilled, DiskBytesSpilled
         batch = 10
         chunks, current_chunk = [], []
         iterator = iter(iterator)
@@ -486,15 +496,18 @@ def sorted(self, iterator, key=None, reverse=False):
             if len(chunk) < batch:
                 break

-            if get_used_memory() > self.memory_limit:
+            used_memory = get_used_memory()
+            if used_memory > self.memory_limit:
                 # sort them inplace will save memory
                 current_chunk.sort(key=key, reverse=reverse)
                 path = self._get_path(len(chunks))
                 with open(path, 'w') as f:
                     self.serializer.dump_stream(current_chunk, f)
-                self._spilled_bytes += os.path.getsize(path)
                 chunks.append(self.serializer.load_stream(open(path)))
                 current_chunk = []
+                gc.collect()
+                MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+                DiskBytesSpilled += os.path.getsize(path)

             elif not chunks:
                 batch = min(batch * 2, 10000)
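
Every spill path in ExternalMerger and ExternalSorter now does the same bookkeeping: DiskBytesSpilled grows by the size of whatever was just written (s.tell() while the stream is still open, os.path.getsize() otherwise), and MemoryBytesSpilled grows by how much resident memory the spill freed. get_used_memory() reports megabytes, hence the << 20 shift to convert the freed amount to bytes. A minimal sketch of that pattern, with pickle standing in for PySpark's serializers and the memory probe passed in as a stand-in for get_used_memory():

import gc
import os
import pickle

# Module-level totals, in the spirit of pyspark.shuffle.MemoryBytesSpilled and
# DiskBytesSpilled; everything else in this sketch is illustrative.
MemoryBytesSpilled = 0
DiskBytesSpilled = 0

def spill_chunk(items, path, get_used_memory_mb):
    """Dump one in-memory chunk (a list) to `path` and update both counters.

    `get_used_memory_mb` stands in for pyspark.shuffle.get_used_memory(),
    which reports resident memory in megabytes.
    """
    global MemoryBytesSpilled, DiskBytesSpilled
    used_before = get_used_memory_mb()

    with open(path, "wb") as f:
        pickle.dump(items, f)                   # real code uses PySpark serializers
    DiskBytesSpilled += os.path.getsize(path)   # bytes that actually landed on disk

    del items[:]                                # drop the in-memory copy
    gc.collect()                                # release the memory as much as possible
    freed_mb = max(used_before - get_used_memory_mb(), 0)
    MemoryBytesSpilled += freed_mb << 20        # MB -> bytes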

python/pyspark/tests.py

Lines changed: 8 additions & 7 deletions
@@ -46,6 +46,7 @@
     CloudPickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType
+from pyspark import shuffle

 _have_scipy = False
 _have_numpy = False
@@ -138,17 +139,17 @@ def test_external_sort(self):
         random.shuffle(l)
         sorter = ExternalSorter(1)
         self.assertEquals(sorted(l), list(sorter.sorted(l)))
-        self.assertGreater(sorter._spilled_bytes, 0)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, 0)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
-        self.assertGreater(sorter._spilled_bytes, last)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
-        self.assertGreater(sorter._spilled_bytes, last)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
                           list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
-        self.assertGreater(sorter._spilled_bytes, last)
+        self.assertGreater(shuffle.DiskBytesSpilled, last)

     def test_external_sort_in_rdd(self):
         conf = SparkConf().set("spark.python.worker.memory", "1m")
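
test_external_sort now asserts against the shared module-level counter instead of the per-sorter attribute that was removed. The neighbouring test_external_sort_in_rdd also doubles as the easiest way to see the new Web UI columns: a tiny spark.python.worker.memory forces ExternalSorter to spill, and the totals reported by the Python workers show up in the stage's task metrics. A sketch along the lines of that test (local master assumed; the Web UI listens on http://localhost:4040 by default):

import random
from pyspark import SparkConf, SparkContext

# A tiny per-worker memory budget so the Python-side sort has to spill to disk.
conf = SparkConf().set("spark.python.worker.memory", "1m")
sc = SparkContext("local[2]", "spill-demo", conf=conf)

l = list(range(100000))
random.shuffle(l)
rdd = sc.parallelize(l, 2)
assert sorted(l) == rdd.sortBy(lambda x: x).collect()

# While the application is running, the spilled bytes reported by each worker
# are visible on the stage page of the Web UI.
sc.stop()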

python/pyspark/worker.py

Lines changed: 10 additions & 4 deletions
@@ -23,16 +23,14 @@
 import time
 import socket
 import traceback
-# CloudPickler needs to be imported so that depicklers are registered using the
-# copy_reg module.
+
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     CompressedSerializer
-
+from pyspark import shuffle

 pickleSer = PickleSerializer()
 utf8_deserializer = UTF8Deserializer()
@@ -52,6 +50,11 @@ def main(infile, outfile):
     if split_index == -1:  # for unit tests
         return

+    # initialize global state
+    shuffle.MemoryBytesSpilled = 0
+    shuffle.DiskBytesSpilled = 0
+    _accumulatorRegistry.clear()
+
     # fetch name of workdir
     spark_files_dir = utf8_deserializer.loads(infile)
     SparkFiles._root_directory = spark_files_dir
@@ -97,6 +100,9 @@ def main(infile, outfile):
         exit(-1)
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
+    write_long(shuffle.MemoryBytesSpilled, outfile)
+    write_long(shuffle.DiskBytesSpilled, outfile)
+
     # Mark the beginning of the accumulators section of the output
     write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
     write_int(len(_accumulatorRegistry), outfile)
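
Two details in main() are easy to miss: the counters are zeroed at the top (alongside the accumulator registry) so that a worker process which ends up handling more than one task never re-reports spill bytes from an earlier one, and the two write_long calls sit between report_times and END_OF_DATA_SECTION because that is exactly where PythonRDD issues its two readLong calls. A toy model of why the per-task reset matters (plain Python, not PySpark code):

# A worker-like process that serves several tasks in a row.
DiskBytesSpilled = 0

def run_task(bytes_spilled_by_this_task, reset):
    """Return what the task would report via write_long(DiskBytesSpilled, ...)."""
    global DiskBytesSpilled
    if reset:                      # what the new code at the top of main() does
        DiskBytesSpilled = 0
    DiskBytesSpilled += bytes_spilled_by_this_task
    return DiskBytesSpilled

print(run_task(1 << 20, reset=False))   # first task reports 1 MiB
print(run_task(2 << 20, reset=False))   # second task wrongly reports 3 MiB
print(run_task(2 << 20, reset=True))    # with the reset, it reports its own 2 MiB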
