
[SPARK-10714][SPARK-8632][SPARK-10685][SQL] Refactor Python UDF handling #8835


Closed
wants to merge 4 commits
54 changes: 43 additions & 11 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,6 +24,7 @@ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JM
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials
import scala.util.control.NonFatal

import com.google.common.base.Charsets.UTF_8
import org.apache.hadoop.conf.Configuration
@@ -38,7 +39,6 @@ import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{SerializableConfiguration, Utils}

import scala.util.control.NonFatal

private[spark] class PythonRDD(
parent: RDD[_],
@@ -61,11 +61,39 @@ private[spark] class PythonRDD(
if (preservePartitoning) firstParent.partitioner else None
}

val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val runner = new PythonRunner(
command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator,
bufferSize, reuse_worker)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}


/**
* A helper class to run Python UDFs in Spark.
*/
private[spark] class PythonRunner(
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]],
bufferSize: Int,
reuse_worker: Boolean)
extends Logging {

def compute(
inputIterator: Iterator[_],
partitionIndex: Int,
context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
val localdir = env.blockManager.diskBlockManager.localDirs.map(
f => f.getPath()).mkString(",")
val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
if (reuse_worker) {
envVars.put("SPARK_REUSE_WORKER", "1")
@@ -75,7 +103,7 @@ private[spark] class PythonRDD(
@volatile var released = false

// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)

context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
@@ -183,13 +211,16 @@ private[spark] class PythonRDD(
new InterruptibleIterator(context, stdoutIterator)
}

val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

/**
* The thread responsible for writing the data from the PythonRDD's parent iterator to the
* Python process.
*/
class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
class WriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[_],
partitionIndex: Int,
context: TaskContext)
extends Thread(s"stdout writer for $pythonExec") {

@volatile private var _exception: Exception = null
@@ -211,11 +242,11 @@ private[spark] class PythonRDD(
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
dataOut.writeInt(split.index)
dataOut.writeInt(partitionIndex)
// Python version of driver
PythonRDD.writeUTF(pythonVer, dataOut)
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.size())
for (include <- pythonIncludes.asScala) {
@@ -246,7 +277,7 @@ private[spark] class PythonRDD(
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
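
The handshake that WriterThread performs is easier to follow when collected in one place. The snippet below is an editor's sketch reconstructed from the writes shown in this diff, not code from the PR; writeUTF here mirrors PythonRDD.writeUTF (length-prefixed UTF-8 bytes), and the broadcast-variable and accumulator sections are elided.

import java.io.DataOutputStream

// Illustrative only: the order of frames that WriterThread sends to the Python
// worker before the row data, reconstructed from the diff above. Broadcast
// variables (written between the includes and the command) are omitted here.
object WorkerHeaderSketch {
  def writeHeader(
      dataOut: DataOutputStream,
      partitionIndex: Int,
      pythonVer: String,
      sparkFilesDir: String,
      pythonIncludes: Seq[String],
      command: Array[Byte]): Unit = {
    // Mirrors PythonRDD.writeUTF: length-prefixed UTF-8 bytes.
    def writeUTF(s: String): Unit = {
      val bytes = s.getBytes("UTF-8")
      dataOut.writeInt(bytes.length)
      dataOut.write(bytes)
    }
    dataOut.writeInt(partitionIndex)        // partition index
    writeUTF(pythonVer)                     // Python version of the driver
    writeUTF(sparkFilesDir)                 // sparkFilesDir (SparkFiles root)
    dataOut.writeInt(pythonIncludes.size)   // number of *.zip / *.egg includes
    pythonIncludes.foreach(writeUTF)
    dataOut.writeInt(command.length)        // serialized command (the pickled UDF)
    dataOut.write(command)
    // The row data follows, then the END_OF_DATA_SECTION and END_OF_STREAM markers.
    dataOut.flush()
  }
}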
@@ -327,7 +358,8 @@ private[spark] object PythonRDD extends Logging {

// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
private def getWorkerBroadcasts(worker: Socket) = {

def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
synchronized {
workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
}
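
The point of extracting PythonRunner from PythonRDD is that feeding a partition through a Python worker no longer requires building an RDD first: anything that has an input iterator, a partition index, and a TaskContext can drive it, which is what the SQL operator changed below takes advantage of. Here is a minimal sketch of the call pattern, based on the constructor and compute signature in this diff; PythonRunner is private[spark], so this only applies to code inside Spark, and the method name and the literal buffer/reuse values are illustrative.

import java.util.{List => JList, Map => JMap}

import org.apache.spark.{Accumulator, TaskContext}
import org.apache.spark.api.python.{PythonBroadcast, PythonRunner}
import org.apache.spark.broadcast.Broadcast

object PythonRunnerSketch {
  // Sketch only: run one partition's worth of data through a Python worker
  // without wrapping the input in an RDD.
  def runPartition(
      command: Array[Byte],
      envVars: JMap[String, String],
      pythonIncludes: JList[String],
      pythonExec: String,
      pythonVer: String,
      broadcastVars: JList[Broadcast[PythonBroadcast]],
      accumulator: Accumulator[JList[Array[Byte]]],
      input: Iterator[Array[Byte]],
      context: TaskContext): Iterator[Array[Byte]] = {
    val runner = new PythonRunner(
      command, envVars, pythonIncludes, pythonExec, pythonVer,
      broadcastVars, accumulator,
      65536,   // bufferSize: spark.buffer.size default
      true)    // reuse_worker: spark.python.worker.reuse default
    runner.compute(input, context.partitionId(), context)
  }
}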
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import net.razorvine.pickle._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil}
import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.{Accumulator, Logging => SparkLogging}
import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator}

/**
* A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]].
@@ -329,7 +329,13 @@ case class EvaluatePython(
/**
* :: DeveloperApi ::
* Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time.
* The input data is zipped with the result of the udf evaluation.
*
* Python evaluation works by sending the necessary (projected) input data via a socket to an
* external Python process, and combining the result from the Python process with the original row.
*
* For each row we send to Python, we also put it in a queue. For each output row from Python,
* we drain the queue to find the original input row. Note that if the Python process is way too
* slow, this could lead to the queue growing unbounded and eventually running out of memory.
Review comment (Contributor):
Could we mitigate this by using a LinkedBlockingDeque to have the producer-side block on inserts once the queue grows to a certain size?

Review comment (Contributor):
Per discussion offline, the only scenario where the queue can grow really large is when the Python buffer size has been configured to be very large and the UDF result rows are very small. As a result, I think that this comment should be expanded / clarified, but this can take place in a followup PR.
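
To make the suggestion above concrete, here is a minimal, self-contained sketch of the bounded-queue idea; it is not part of this PR, and the element type and the capacity of 1024 are arbitrary illustrations. With a bounded LinkedBlockingDeque, put() blocks once the buffer is full, so the side feeding rows to Python stalls instead of letting the buffer grow without bound.

import java.util.concurrent.LinkedBlockingDeque

object BoundedQueueSketch {
  def main(args: Array[String]): Unit = {
    // Bounded buffer: put() blocks when 1024 elements are already queued.
    val queue = new LinkedBlockingDeque[String](1024)

    // Producer side (the code that sends a row to the Python worker):
    queue.put("original-input-row")

    // Consumer side (the code that joins a Python result with its input):
    val original = queue.take()   // FIFO: oldest un-joined input row
    println(original)
  }
}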

*/
@DeveloperApi
case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
@@ -342,51 +348,57 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
override def canProcessSafeRows: Boolean = true

protected override def doExecute(): RDD[InternalRow] = {
val childResults = child.execute().map(_.copy())
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)

val parent = childResults.mapPartitions { iter =>
inputRDD.mapPartitions { iter =>
EvaluatePython.registerPicklers() // register pickler for Row

// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

val pickle = new Pickler
val currentRow = newMutableProjection(udf.children, child.output)()
val fields = udf.children.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
iter.grouped(100).map { inputRows =>

// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
val inputIterator = iter.grouped(100).map { inputRows =>
val toBePickled = inputRows.map { row =>
queue.add(row)
EvaluatePython.toJava(currentRow(row), schema)
}.toArray
pickle.dumps(toBePickled)
}
}

val pyRDD = new PythonRDD(
parent,
udf.command,
udf.envVars,
udf.pythonIncludes,
false,
udf.pythonExec,
udf.pythonVer,
udf.broadcastVars,
udf.accumulator
).mapPartitions { iter =>
val pickle = new Unpickler
iter.flatMap { pickedResult =>
val unpickledBatch = pickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}
}.mapPartitions { iter =>
val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(
udf.command,
udf.envVars,
udf.pythonIncludes,
udf.pythonExec,
udf.pythonVer,
udf.broadcastVars,
udf.accumulator,
bufferSize,
reuseWorker
).compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
val row = new GenericMutableRow(1)
iter.map { result =>
row(0) = EvaluatePython.fromJava(result, udf.dataType)
row: InternalRow
}
}
val joined = new JoinedRow

childResults.zip(pyRDD).mapPartitions { iter =>
val joinedRow = new JoinedRow()
iter.map {
case (row, udfResult) =>
joinedRow(row, udfResult)
outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
row(0) = EvaluatePython.fromJava(result, udf.dataType)
joined(queue.poll(), row)
}
}
}
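
The new doExecute above thus replaces the old childResults.zip(pyRDD) with a single pass over each partition: every row handed to Python is first added to a queue, and every result that comes back is joined with the row polled from that queue. The following is a stripped-down sketch of that pattern with a plain function standing in for the Python round trip; the names and types are illustrative, not Spark API, and it assumes the evaluator returns exactly one result per input, in order.

import java.util.concurrent.ConcurrentLinkedQueue

object QueueZipSketch {
  // evaluate stands in for the round trip to the Python worker: it must return
  // one output per input, in the same order.
  def zipWithResults[A, B](input: Iterator[A])(evaluate: Seq[A] => Seq[B]): Iterator[(A, B)] = {
    val queue = new ConcurrentLinkedQueue[A]()
    // Producer: batch the input (the real code uses groups of 100) and remember
    // every row handed to the evaluator.
    val results = input.grouped(100).flatMap { batch =>
      batch.foreach(queue.add)
      evaluate(batch)
    }
    // Consumer: pair each result with the oldest un-joined input row.
    results.map(result => (queue.poll(), result))
  }

  def main(args: Array[String]): Unit = {
    val paired = zipWithResults(Iterator(1, 2, 3))(batch => batch.map(_ * 10))
    paired.foreach(println)   // (1,10), (2,20), (3,30)
  }
}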