
Commit a9a8fa5

HyukjinKwon authored and Vinoo Ganesh committed
[SPARK-26762][SQL][R] Arrow optimization for conversion from Spark DataFrame to R DataFrame
## What changes were proposed in this pull request? This PR targets to support Arrow optimization for conversion from Spark DataFrame to R DataFrame. Like PySpark side, it falls back to non-optimization code path when it's unable to use Arrow optimization. This can be tested as below: ```bash $ ./bin/sparkR --conf spark.sql.execution.arrow.enabled=true ``` ```r collect(createDataFrame(mtcars)) ``` ### Requirements - R 3.5.x - Arrow package 0.12+ ```bash Rscript -e 'remotes::install_github("apache/arrowapache-arrow-0.12.0", subdir = "r")' ``` **Note:** currently, Arrow R package is not in CRAN. Please take a look at ARROW-3204. **Note:** currently, Arrow R package seems not supporting Windows. Please take a look at ARROW-3204. ### Benchmarks **Shall** ```bash sync && sudo purge ./bin/sparkR --conf spark.sql.execution.arrow.enabled=false --driver-memory 4g ``` ```bash sync && sudo purge ./bin/sparkR --conf spark.sql.execution.arrow.enabled=true --driver-memory 4g ``` **R code** ```r df <- cache(createDataFrame(read.csv("500000.csv"))) count(df) test <- function() { options(digits.secs = 6) # milliseconds start.time <- Sys.time() collect(df) end.time <- Sys.time() time.taken <- end.time - start.time print(time.taken) } test() ``` **Data (350 MB):** ```r object.size(read.csv("500000.csv")) 350379504 bytes ``` "500000 Records" http://eforexcel.com/wp/downloads-16-sample-csv-files-data-sets-for-testing/ **Results** ``` Time difference of 221.32014 secs ``` ``` Time difference of 15.51145 secs ``` The performance improvement was around **1426%**. ### Limitations: - For now, Arrow optimization with R does not support when the data is `raw`, and when user explicitly gives float type in the schema. They produce corrupt values. In this case, we decide to fall back to non-optimization code path. - Due to ARROW-4512, it cannot send and receive batch by batch. It has to send all batches in Arrow stream format at once. It needs improvement later. ## How was this patch tested? Existing tests related with Arrow optimization cover this change. Also, manually tested. Closes apache#23760 from HyukjinKwon/SPARK-26762. Authored-by: Hyukjin Kwon <gurwls223@apache.org> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent d210c0a commit a9a8fa5

File tree

5 files changed: +153, -11 lines

R/pkg/R/DataFrame.R

Lines changed: 56 additions & 0 deletions
```diff
@@ -1177,11 +1177,67 @@ setMethod("dim",
 setMethod("collect",
           signature(x = "SparkDataFrame"),
           function(x, stringsAsFactors = FALSE) {
+            connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
+            useArrow <- FALSE
+            arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true"
+            if (arrowEnabled) {
+              useArrow <- tryCatch({
+                requireNamespace1 <- requireNamespace
+                if (!requireNamespace1("arrow", quietly = TRUE)) {
+                  stop("'arrow' package should be installed.")
+                }
+                # Currently, Arrow optimization does not support raw type.
+                # Also, it does not support explicit float type set by users.
+                if (inherits(schema(x), "structType")) {
+                  if (any(sapply(schema(x)$fields(),
+                                 function(x) x$dataType.toString() == "FloatType"))) {
+                    stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ",
+                                "DataFrame does not support FloatType yet."))
+                  }
+                  if (any(sapply(schema(x)$fields(),
+                                 function(x) x$dataType.toString() == "BinaryType"))) {
+                    stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ",
+                                "DataFrame does not support BinaryType yet."))
+                  }
+                }
+                TRUE
+              }, error = function(e) {
+                warning(paste0("The conversion from Spark DataFrame to R DataFrame was attempted ",
+                               "with Arrow optimization because ",
+                               "'spark.sql.execution.arrow.enabled' is set to true; however, ",
+                               "failed, attempting non-optimization. Reason: ",
+                               e))
+                FALSE
+              })
+            }
+
             dtypes <- dtypes(x)
             ncol <- length(dtypes)
             if (ncol <= 0) {
               # empty data.frame with 0 columns and 0 rows
               data.frame()
+            } else if (useArrow) {
+              requireNamespace1 <- requireNamespace
+              if (requireNamespace1("arrow", quietly = TRUE)) {
+                read_arrow <- get("read_arrow", envir = asNamespace("arrow"), inherits = FALSE)
+                as_tibble <- get("as_tibble", envir = asNamespace("arrow"))
+
+                portAuth <- callJMethod(x@sdf, "collectAsArrowToR")
+                port <- portAuth[[1]]
+                authSecret <- portAuth[[2]]
+                conn <- socketConnection(
+                  port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+                output <- tryCatch({
+                  doServerAuth(conn, authSecret)
+                  arrowTable <- read_arrow(readRaw(conn))
+                  as.data.frame(as_tibble(arrowTable), stringsAsFactors = stringsAsFactors)
+                }, finally = {
+                  close(conn)
+                })
+                return(output)
+              } else {
+                stop("'arrow' package should be installed.")
+              }
             } else {
               # listCols is a list of columns
               listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
```
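The `collect()` path added above requires the `arrow` R package and refuses `FloatType` and `BinaryType` columns, falling back to the pre-existing non-Arrow path with a warning. A minimal sketch of the user-visible behaviour, assuming an active SparkR session with the `arrow` package installed and `spark.sql.execution.arrow.enabled=true` (the float column is only an illustration of the documented fallback, not code from the commit):

```r
# Plain double columns go through the Arrow-optimized collect() when enabled.
df <- createDataFrame(mtcars)
head(collect(df))

# An explicit FloatType column triggers the fallback: collect() warns that the
# Arrow optimization was attempted but failed, then returns the data via the
# non-optimized path instead.
floatSchema <- structType(structField("x", "float"))
floatDf <- createDataFrame(data.frame(x = c(1.1, 2.2)), schema = floatSchema)
collect(floatDf)
```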

R/pkg/tests/fulltests/test_sparkSQL.R

Lines changed: 19 additions & 2 deletions
```diff
@@ -307,7 +307,7 @@ test_that("create DataFrame from RDD", {
   unsetHiveContext()
 })
 
-test_that("createDataFrame Arrow optimization", {
+test_that("createDataFrame/collect Arrow optimization", {
   skip_if_not_installed("arrow")
 
   conf <- callJMethod(sparkSession, "conf")
@@ -332,7 +332,24 @@ test_that("createDataFrame Arrow optimization", {
   })
 })
 
-test_that("createDataFrame Arrow optimization - type specification", {
+test_that("createDataFrame/collect Arrow optimization - many partitions (partition order test)", {
+  skip_if_not_installed("arrow")
+
+  conf <- callJMethod(sparkSession, "conf")
+  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
+  tryCatch({
+    expect_equal(collect(createDataFrame(mtcars, numPartitions = 32)),
+                 collect(createDataFrame(mtcars, numPartitions = 1)))
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+})
+
+test_that("createDataFrame/collect Arrow optimization - type specification", {
   skip_if_not_installed("arrow")
   rdf <- data.frame(list(list(a = 1,
                               b = "a",
```

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 6 additions & 0 deletions
```diff
@@ -432,6 +432,12 @@ private[spark] object PythonRDD extends Logging {
    */
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
+    serveToStream(threadName, authHelper)(writeFunc)
+  }
+
+  private[spark] def serveToStream(
+      threadName: String, authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit)
+    : Array[Any] = {
     val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
       val out = new BufferedOutputStream(s.getOutputStream())
       Utils.tryWithSafeFinally {
```

core/src/main/scala/org/apache/spark/api/r/RRDD.scala

Lines changed: 7 additions & 2 deletions
```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.api.r
 
-import java.io.{DataInputStream, File}
+import java.io.{DataInputStream, File, OutputStream}
 import java.net.Socket
 import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Map => JMap}
@@ -113,7 +113,7 @@ private class StringRRDD[T: ClassTag](
   lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
 }
 
-private[r] object RRDD {
+private[spark] object RRDD {
   def createSparkContext(
       master: String,
       appName: String,
@@ -174,6 +174,11 @@ private[r] object RRDD {
       JavaRDD[Array[Byte]] = {
     PythonRDD.readRDDFromFile(jsc, fileName, parallelism)
   }
+
+  private[spark] def serveToStream(
+      threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
+    PythonRDD.serveToStream(threadName, new RSocketAuthHelper())(writeFunc)
+  }
 }
 
 /**
```

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 65 additions & 7 deletions
```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.io.{CharArrayWriter, DataOutputStream}
+import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
 import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
+import org.apache.spark.api.r.RRDD
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
@@ -3198,9 +3199,66 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Collect a Dataset as Arrow batches and serve stream to PySpark.
+   * Collect a Dataset as Arrow batches and serve stream to SparkR. It sends
+   * arrow batches in an ordered manner with buffering. This is inevitable
+   * due to missing R API that reads batches from socket directly. See ARROW-4512.
+   * Eventually, this code should be deduplicated by `collectAsArrowToPython`.
    */
-  private[sql] def collectAsArrowToPython(): Array[Any] = {
+  private[sql] def collectAsArrowToR(): Array[Any] = {
+    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+
+    withAction("collectAsArrowToR", queryExecution) { plan =>
+      RRDD.serveToStream("serve-Arrow") { outputStream =>
+        val buffer = new ByteArrayOutputStream()
+        val out = new DataOutputStream(outputStream)
+        val batchWriter = new ArrowBatchStreamWriter(schema, buffer, timeZoneId)
+        val arrowBatchRdd = toArrowBatchRdd(plan)
+        val numPartitions = arrowBatchRdd.partitions.length
+
+        // Store collection results for worst case of 1 to N-1 partitions
+        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
+        var lastIndex = -1  // index of last partition written
+
+        // Handler to eagerly write partitions to Python in order
+        def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
+          // If result is from next partition in order
+          if (index - 1 == lastIndex) {
+            batchWriter.writeBatches(arrowBatches.iterator)
+            lastIndex += 1
+            // Write stored partitions that come next in order
+            while (lastIndex < results.length && results(lastIndex) != null) {
+              batchWriter.writeBatches(results(lastIndex).iterator)
+              results(lastIndex) = null
+              lastIndex += 1
+            }
+            // After last batch, end the stream
+            if (lastIndex == results.length) {
+              batchWriter.end()
+              val batches = buffer.toByteArray
+              out.writeInt(batches.length)
+              out.write(batches)
+            }
+          } else {
+            // Store partitions received out of order
+            results(index - 1) = arrowBatches
+          }
+        }
+
+        sparkSession.sparkContext.runJob(
+          arrowBatchRdd,
+          (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
+          0 until numPartitions,
+          handlePartitionBatches)
+      }
+    }
+  }
+
+  /**
+   * Collect a Dataset as Arrow batches and serve stream to PySpark. It sends
+   * arrow batches in an un-ordered manner without buffering, and then batch order
+   * information at the end. The batches should be reordered at Python side.
+   */
+  private[sql] def collectAsArrowToPython: Array[Any] = {
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
 
     withAction("collectAsArrowToPython", queryExecution) { plan =>
@@ -3211,7 +3269,7 @@ class Dataset[T] private[sql](
       val numPartitions = arrowBatchRdd.partitions.length
 
       // Batches ordered by (index of partition, batch index in that partition) tuple
-      val batchOrder = new ArrayBuffer[(Int, Int)]()
+      val batchOrder = ArrayBuffer.empty[(Int, Int)]
       var partitionCount = 0
 
       // Handler to eagerly write batches to Python as they arrive, un-ordered
@@ -3220,7 +3278,7 @@ class Dataset[T] private[sql](
         // Write all batches (can be more than 1) in the partition, store the batch order tuple
         batchWriter.writeBatches(arrowBatches.iterator)
         arrowBatches.indices.foreach {
-          partition_batch_index => batchOrder.append((index, partition_batch_index))
+          partitionBatchIndex => batchOrder.append((index, partitionBatchIndex))
         }
       }
       partitionCount += 1
@@ -3232,8 +3290,8 @@ class Dataset[T] private[sql](
       // Sort by (index of partition, batch index in that partition) tuple to get the
       // overall_batch_index from 0 to N-1 batches, which can be used to put the
       // transferred batches in the correct order
-      batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
-        out.writeInt(overall_batch_index)
+      batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+        out.writeInt(overallBatchIndex)
       }
       out.flush()
     }
```
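On the JVM side, `collectAsArrowToR` buffers every Arrow batch and then writes one length-prefixed payload: a 4-byte length via `writeInt`, followed by the whole Arrow stream. The R side reads it back through SparkR's internal `readRaw()` helper over the authenticated socket, which boils down to something like the following sketch (`readArrowPayload` is an illustrative name, not a SparkR API):

```r
# Illustrative only: read one length-prefixed payload, mirroring the JVM's
# out.writeInt(batches.length); out.write(batches) in collectAsArrowToR.
readArrowPayload <- function(conn) {
  # DataOutputStream.writeInt emits a 4-byte big-endian integer.
  len <- readBin(conn, what = integer(), n = 1L, size = 4L, endian = "big")
  # The remaining bytes are the complete Arrow stream for all partitions.
  readBin(conn, what = raw(), n = len)
}

# The diff above then hands the raw vector to arrow::read_arrow() and
# converts the resulting table to a data.frame.
```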
