Skip to content

Commit

Permalink
[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Browse files Browse the repository at this point in the history
Because circular reference between JavaObject and JavaMember, an Java object can not be released until Python GC kick in, then it will cause memory leak in collect(), which may consume lots of memory in JVM.

This PR change the way we sending collected data back into Python from local file to socket, which could avoid any disk IO during collect, also avoid any referrers of Java object in Python.

cc JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes apache#4923 from davies/fix_collect and squashes the following commits:

d730286 [Davies Liu] address comments
24c92a4 [Davies Liu] fix style
ba54614 [Davies Liu] use socket to transfer data from JVM
9517c8f [Davies Liu] fix memory leak in collect()
  • Loading branch information
Davies Liu authored and JoshRosen committed Mar 9, 2015
1 parent 3cac199 commit 8767565
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 51 deletions.
76 changes: 59 additions & 17 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,27 @@ package org.apache.spark.api.python

import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}

import org.apache.spark.input.PortableDataStream
import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials

import com.google.common.base.Charsets.UTF_8

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}

import org.apache.spark._
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

import scala.util.control.NonFatal

private[spark] class PythonRDD(
@transient parent: RDD[_],
command: Array[Byte],
Expand Down Expand Up @@ -341,21 +342,33 @@ private[spark] object PythonRDD extends Logging {
/**
* Adapter for calling SparkContext#runJob from Python.
*
* This method will return an iterator of an array that contains all elements in the RDD
* This method will serve an iterator of an array that contains all elements in the RDD
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*
* @return the port number of a local socket which serves the data collected from this job.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int],
allowLocal: Boolean): Iterator[Array[Byte]] = {
allowLocal: Boolean): Int = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
flattenedPartition.iterator
serveIterator(flattenedPartition.iterator,
s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
}

/**
* A helper function to collect an RDD as an iterator, then serve it via socket.
*
* @return the port number of a local socket which serves the data collected from this job.
*/
def collectAndServe[T](rdd: RDD[T]): Int = {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}

def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
Expand Down Expand Up @@ -575,15 +588,44 @@ private[spark] object PythonRDD extends Logging {
dataOut.write(bytes)
}

def writeToFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
writeToFile(items.asScala, filename)
}
/**
* Create a socket server and a background thread to serve the data in `items`,
*
* The socket server can only accept one connection, or close if no connection
* in 3 seconds.
*
* Once a connection comes in, it tries to serialize all the data in `items`
* and send them into this connection.
*
* The thread will terminate after all the data are sent or any exceptions happen.
*/
private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
val serverSocket = new ServerSocket(0, 1)
serverSocket.setReuseAddress(true)
// Close the socket if no connection in 3 seconds
serverSocket.setSoTimeout(3000)

new Thread(threadName) {
setDaemon(true)
override def run() {
try {
val sock = serverSocket.accept()
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
try {
writeIteratorToStream(items, out)
} finally {
out.close()
}
} catch {
case NonFatal(e) =>
logError(s"Error while sending iterator", e)
} finally {
serverSocket.close()
}
}
}.start()

def writeToFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
writeIteratorToStream(items, file)
file.close()
serverSocket.getLocalPort
}

private def getMergedConf(confAsMap: java.util.HashMap[String, String],
Expand Down
13 changes: 6 additions & 7 deletions python/pyspark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from threading import Lock
from tempfile import NamedTemporaryFile

from py4j.java_collections import ListConverter

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
Expand All @@ -30,13 +32,11 @@
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.rdd import RDD, _load_from_socket
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler

from py4j.java_collections import ListConverter


__all__ = ['SparkContext']

Expand All @@ -59,7 +59,6 @@ class SparkContext(object):

_gateway = None
_jvm = None
_writeToFile = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
Expand Down Expand Up @@ -221,7 +220,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile

if instance:
if (SparkContext._active_spark_context and
Expand Down Expand Up @@ -840,8 +838,9 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))
port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
allowLocal)
return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))

def show_profiles(self):
""" Print the profile stats to stdout """
Expand Down
30 changes: 14 additions & 16 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from collections import defaultdict
from itertools import chain, ifilter, imap
import operator
import os
import sys
import shlex
from subprocess import Popen, PIPE
Expand All @@ -29,6 +28,7 @@
import heapq
import bisect
import random
import socket
from math import sqrt, log, isinf, isnan, pow, ceil

from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
Expand Down Expand Up @@ -111,6 +111,17 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])


def _load_from_socket(port, serializer):
sock = socket.socket()
try:
sock.connect(("localhost", port))
rf = sock.makefile("rb", 65536)
for item in serializer.load_stream(rf):
yield item
finally:
sock.close()


class Partitioner(object):
def __init__(self, numPartitions, partitionFunc):
self.numPartitions = numPartitions
Expand Down Expand Up @@ -698,21 +709,8 @@ def collect(self):
Return a list that contains all of the elements in this RDD.
"""
with SCCallSiteSync(self.context) as css:
bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))

def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
self.ctx._writeToFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
for item in self._jrdd_deserializer.load_stream(tempFile):
yield item
os.unlink(tempFile.name)
port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
return list(_load_from_socket(port, self._jrdd_deserializer))

def reduce(self, f):
"""
Expand Down
14 changes: 3 additions & 11 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@
import itertools
import warnings
import random
import os
from tempfile import NamedTemporaryFile

from py4j.java_collections import ListConverter, MapConverter

from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.rdd import RDD, _load_from_socket
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
Expand Down Expand Up @@ -310,14 +308,8 @@ def collect(self):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
bytesInJava = self._jdf.javaToPython().collect().iterator()
tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
tempFile.close()
self._sc._writeToFile(bytesInJava, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
os.unlink(tempFile.name)
port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
cls = _create_cls(self.schema)
return [cls(r) for r in rs]

Expand Down

0 comments on commit 8767565

Please sign in to comment.