Support arrow zerocopy for reader and writer in object store #341

Open · wants to merge 1 commit into base: master
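The commit replaces the byte-array round trip between Spark and the Ray object store with direct puts and gets of Arrow `VectorSchemaRoot` objects: `ObjectStoreWriter` no longer serializes each record batch through an `ArrowStreamWriter`, and `RayDatasetRDD`/`ObjectStoreReader` now yield `VectorSchemaRoot` instead of `Array[Byte]`. The sketch below condenses the old and new write paths for comparison; the helper names are illustrative only, not part of the diff, and whether a copy is truly avoided end to end depends on how Ray serializes the Arrow object.

```scala
import java.io.ByteArrayOutputStream

import io.ray.api.{ObjectRef, Ray}
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

object ZeroCopySketch {
  // Old path: serialize the batch into the Arrow IPC stream format and put
  // the resulting byte array into the Ray object store (an extra copy).
  def putAsBytes(root: VectorSchemaRoot): ObjectRef[Array[Byte]] = {
    val byteOut = new ByteArrayOutputStream()
    val writer = new ArrowStreamWriter(root, null, byteOut)
    writer.start()
    writer.writeBatch()
    writer.end()
    Ray.put(byteOut.toByteArray)
  }

  // New path: put the populated VectorSchemaRoot itself, letting Ray handle
  // the Arrow buffers without the explicit serialize-to-bytes step.
  def putZeroCopy(root: VectorSchemaRoot): ObjectRef[VectorSchemaRoot] =
    Ray.put(root)
}
```
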
@@ -21,11 +21,12 @@
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.call.ActorCreator;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.runtime.object.ObjectRefImpl;

import java.util.Map;
import java.util.List;

import io.ray.api.placementgroup.PlacementGroup;
import io.ray.runtime.object.ObjectRefImpl;
import org.apache.spark.executor.RayDPExecutor;

public class RayExecutorUtils {

@@ -22,6 +22,7 @@ import java.util.List;
import scala.collection.JavaConverters._

import io.ray.runtime.generated.Common.Address
import org.apache.arrow.vector.VectorSchemaRoot

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.api.java.JavaSparkContext
@@ -37,15 +38,15 @@ class RayDatasetRDD(
jsc: JavaSparkContext,
@transient val objectIds: List[Array[Byte]],
locations: List[Array[Byte]])
extends RDD[Array[Byte]](jsc.sc, Nil) {
extends RDD[VectorSchemaRoot](jsc.sc, Nil) {

override def getPartitions: Array[Partition] = {
objectIds.asScala.zipWithIndex.map { case (k, i) =>
new RayDatasetRDDPartition(k, i).asInstanceOf[Partition]
}.toArray
}

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
override def compute(split: Partition, context: TaskContext): Iterator[VectorSchemaRoot] = {
val ref = split.asInstanceOf[RayDatasetRDDPartition].ref
ObjectStoreReader.getBatchesFromStream(ref)
}

@@ -18,18 +18,26 @@
package org.apache.spark.sql.raydp

import java.io.ByteArrayInputStream
import java.nio.ByteBuffer
import java.nio.channels.{Channels, ReadableByteChannel}
import java.util.List

import scala.collection.JavaConverters._

import com.intel.raydp.shims.SparkShimLoader
import org.apache.arrow.vector.VectorSchemaRoot

import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.raydp.RayDPUtils
import org.apache.spark.rdd.{RayDatasetRDD, RayObjectRefRDD}
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}

object ObjectStoreReader {
def createRayObjectRefDF(
@@ -40,17 +48,56 @@ object ObjectStoreReader {
spark.createDataFrame(rdd, schema)
}

def fromRootIterator(
arrowRootIter: Iterator[VectorSchemaRoot],
schema: StructType,
timeZoneId: String): Iterator[InternalRow] = {
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)

new Iterator[InternalRow] {
private var rowIter = if (arrowRootIter.hasNext) nextBatch() else Iterator.empty

override def hasNext: Boolean = rowIter.hasNext || {
if (arrowRootIter.hasNext) {
rowIter = nextBatch()
true
} else {
false
}
}

override def next(): InternalRow = rowIter.next()

private def nextBatch(): Iterator[InternalRow] = {
val root = arrowRootIter.next()
val columns = root.getFieldVectors.asScala.map { vector =>
new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
}.toArray

val batch = new ColumnarBatch(columns)
batch.setNumRows(root.getRowCount)
root.close()
batch.rowIterator().asScala
}
}
}

def RayDatasetToDataFrame(
sparkSession: SparkSession,
rdd: RayDatasetRDD,
schema: String): DataFrame = {
SparkShimLoader.getSparkShims.toDataFrame(JavaRDD.fromRDD(rdd), schema, sparkSession)
schemaString: String): DataFrame = {
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
val sqlContext = new SQLContext(sparkSession)
val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
val resultRDD = JavaRDD.fromRDD(rdd).rdd.mapPartitions { it =>
fromRootIterator(it, schema, timeZoneId)
}
sqlContext.internalCreateDataFrame(resultRDD.setName("arrow"), schema)
}

def getBatchesFromStream(
ref: Array[Byte]): Iterator[Array[Byte]] = {
val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]])
ArrowConverters.getBatchesFromStream(
Channels.newChannel(new ByteArrayInputStream(objectRef.get)))
ref: Array[Byte]): Iterator[VectorSchemaRoot] = {
val objectRef = RayDPUtils.readBinary(ref, classOf[VectorSchemaRoot])
Iterator[VectorSchemaRoot](objectRef.get)
}
}
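On the read side, `getBatchesFromStream` now returns the stored `VectorSchemaRoot` directly, and `fromRootIterator` wraps its field vectors in Spark's columnar classes without copying the underlying Arrow buffers. A minimal self-contained sketch of that wrapping step (the `rootToRows` name is illustrative; closing the root and error handling are omitted):

```scala
import scala.collection.JavaConverters._

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch}

object RootToRowsSketch {
  // Wrap each Arrow field vector in an ArrowColumnVector (no buffer copy),
  // assemble them into a ColumnarBatch, and expose the rows to Spark.
  def rootToRows(root: VectorSchemaRoot): Iterator[InternalRow] = {
    val columns = root.getFieldVectors.asScala.map { vector =>
      new ArrowColumnVector(vector): ColumnVector
    }.toArray
    val batch = new ColumnarBatch(columns)
    batch.setNumRows(root.getRowCount)
    batch.rowIterator().asScala
  }
}
```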

@@ -17,7 +17,6 @@

package org.apache.spark.sql.raydp


import java.io.ByteArrayOutputStream
import java.util.{List, UUID}
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
@@ -61,17 +60,16 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
val uuid: UUID = ObjectStoreWriter.dfToId.getOrElseUpdate(df, UUID.randomUUID())

def writeToRay(
data: Array[Byte],
root: VectorSchemaRoot,
numRecords: Int,
queue: ObjectRefHolder.Queue,
ownerName: String): RecordBatch = {

var objectRef: ObjectRef[Array[Byte]] = null
var objectRef: ObjectRef[VectorSchemaRoot] = null
if (ownerName == "") {
objectRef = Ray.put(data)
objectRef = Ray.put(root)
} else {
var dataOwner: PyActorHandle = Ray.getActor(ownerName).get()
objectRef = Ray.put(data, dataOwner)
objectRef = Ray.put(root, dataOwner)
}

// add the objectRef to the objectRefHolder to avoid reference GC
@@ -111,21 +109,15 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
val root = VectorSchemaRoot.create(arrowSchema, allocator)
val results = new ArrayBuffer[RecordBatch]()

val byteOut = new ByteArrayOutputStream()
val arrowWriter = ArrowWriter.create(root)
var numRecords: Int = 0

Utils.tryWithSafeFinally {
while (batchIter.hasNext) {
// reset the state
numRecords = 0
byteOut.reset()
arrowWriter.reset()

// write out the schema meta data
val writer = new ArrowStreamWriter(root, null, byteOut)
writer.start()

// get the next record batch
val nextBatch = batchIter.next()

@@ -136,19 +128,11 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {

// set the write record count
arrowWriter.finish()
// write out the record batch to the underlying out
writer.writeBatch()

// get the wrote ByteArray and save to Ray ObjectStore
val byteArray = byteOut.toByteArray
results += writeToRay(byteArray, numRecords, queue, ownerName)
// end writes footer to the output stream and doesn't clean any resources.
// It could throw exception if the output stream is closed, so it should be
// in the try block.
writer.end()

// write the schema root directly and save to Ray ObjectStore
results += writeToRay(root, numRecords, queue, ownerName)
}
arrowWriter.reset()
byteOut.close()
} {
// If we close root and allocator in TaskCompletionListener, there could be a race
// condition where the writer thread keeps writing to the VectorSchemaRoot while
@@ -173,7 +157,7 @@
/**
* For test.
*/
def getRandomRef(): List[Array[Byte]] = {
def getRandomRef(): List[VectorSchemaRoot] = {

df.queryExecution.toRdd.mapPartitions { _ =>
Iterator(ObjectRefHolder.getRandom(uuid))
@@ -233,7 +217,7 @@ object ObjectStoreWriter {
var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray
val numExecutors = executorIds.length
val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME)
.get.asInstanceOf[ActorHandle[RayAppMaster]]
.get.asInstanceOf[ActorHandle[RayAppMaster]]
val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle)
// Check if there is any restarted executors
if (!restartedExecutors.isEmpty) {
@@ -251,8 +235,8 @@
val refs = new Array[ObjectRef[Array[Byte]]](numPartitions)
val handles = executorIds.map {id =>
Ray.getActor("raydp-executor-" + id)
.get
.asInstanceOf[ActorHandle[RayDPExecutor]]
.get
.asInstanceOf[ActorHandle[RayDPExecutor]]
}
val handlesMap = (executorIds zip handles).toMap
val locations = RayExecutorUtils.getBlockLocations(
@@ -261,18 +245,15 @@
// TODO use getPreferredLocs, but we don't have a host ip to actor table now
refs(i) = RayExecutorUtils.getRDDPartition(
handlesMap(locations(i)), rdd.id, i, schema, driverAgentUrl)
queue.add(refs(i))
}
for (i <- 0 until numPartitions) {
queue.add(RayDPUtils.readBinary(refs(i).get(), classOf[VectorSchemaRoot]))
results(i) = RayDPUtils.convert(refs(i)).getId.getBytes
}
results
}

}

object ObjectRefHolder {
type Queue = ConcurrentLinkedQueue[ObjectRef[Array[Byte]]]
type Queue = ConcurrentLinkedQueue[ObjectRef[VectorSchemaRoot]]
private val dfToQueue = new ConcurrentHashMap[UUID, Queue]()

def getQueue(df: UUID): Queue = {
@@ -297,7 +278,7 @@ object ObjectRefHolder {
queue.size()
}

def getRandom(df: UUID): Array[Byte] = {
def getRandom(df: UUID): VectorSchemaRoot = {
val queue = checkQueueExists(df)
val ref = RayDPUtils.convert(queue.peek())
ref.get()

11 changes: 2 additions & 9 deletions python/raydp/spark/dataset.py
@@ -237,7 +237,7 @@ def _convert_blocks_to_dataframe(blocks):
return df

def _convert_by_rdd(spark: sql.SparkSession,
blocks: Dataset,
blocks: List[ObjectRef],
locations: List[bytes],
schema: StructType) -> DataFrame:
object_ids = [block.binary() for block in blocks]
@@ -269,14 +269,7 @@ def ray_dataset_to_spark_dataframe(spark: sql.SparkSession,
schema = StructType()
for field in arrow_schema:
schema.add(field.name, from_arrow_type(field.type), nullable=field.nullable)
#TODO how to branch on type of block?
sample = ray.get(blocks[0])
if isinstance(sample, bytes):
return _convert_by_rdd(spark, blocks, locations, schema)
elif isinstance(sample, pa.Table):
return _convert_by_udf(spark, blocks, locations, schema)
else:
raise RuntimeError("ray.to_spark only supports arrow type blocks")
return _convert_by_rdd(spark, blocks, locations, schema)

if HAS_MLDATASET:
class RecordBatch(_SourceShard):