Commit e8982ca

[SPARK-25981][R] Enables Arrow optimization from R DataFrame to Spark DataFrame
## What changes were proposed in this pull request?

This PR targets to support Arrow optimization for conversion from an R DataFrame to a Spark DataFrame. Like the PySpark side, it falls back to the non-optimized code path when Arrow optimization cannot be used.

This can be tested as below:

```bash
$ ./bin/sparkR --conf spark.sql.execution.arrow.enabled=true
```

```r
collect(createDataFrame(mtcars))
```

### Requirements

- R 3.5.x
- Arrow package 0.12+

```bash
Rscript -e 'remotes::install_github("apache/arrow@apache-arrow-0.12.0", subdir = "r")'
```

**Note:** currently, the Arrow R package is not in CRAN. Please take a look at ARROW-3204.

**Note:** currently, the Arrow R package does not seem to support Windows. Please take a look at ARROW-3204.

### Benchmarks

**Shell**

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=false
```

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=true
```

**R code**

```r
createDataFrame(mtcars) # Initializes
rdf <- read.csv("500000.csv")

test <- function() {
  options(digits.secs = 6) # milliseconds
  start.time <- Sys.time()
  createDataFrame(rdf)
  end.time <- Sys.time()
  time.taken <- end.time - start.time
  print(time.taken)
}

test()
```

**Data (350 MB):**

```r
object.size(read.csv("500000.csv"))
350379504 bytes
```

"500000 Records" http://eforexcel.com/wp/downloads-16-sample-csv-files-data-sets-for-testing/

**Results**

With `spark.sql.execution.arrow.enabled=false`:

```
Time difference of 29.9468 secs
```

With `spark.sql.execution.arrow.enabled=true`:

```
Time difference of 3.222129 secs
```

The performance improvement was around **950%**. In practice this PR improves around **1200%**+, because it also includes a small optimization for the regular R DataFrame -> Spark DataFrame conversion. See #22954 (comment).

### Limitations

For now, Arrow optimization with R is not supported when the data contains `raw` columns, or when the user explicitly gives a float type in the schema; both produce corrupt values. In these cases, we fall back to the non-optimized code path.

## How was this patch tested?

A small test was added. I also manually forced this optimization to `true` for _all_ R tests and they _all_ passed (with a few fallback warnings).

**TODOs:**

- [x] Draft codes
- [x] Make the tests pass
- [x] Make the CRAN check pass
- [x] Performance measurement
- [x] Supportability investigation (for instance types)
- [x] Wait for Arrow 0.12.0 release
- [x] Fix and match it to Arrow 0.12.0

Closes #22954 from HyukjinKwon/r-arrow-createdataframe.

Lead-authored-by: hyukjinkwon <gurwls223@apache.org>
Co-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
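To make the fallback described in the limitations above concrete, here is a minimal sketch, assuming an active SparkR session launched with `spark.sql.execution.arrow.enabled=true`; the data frame, column name, and schema below are illustrative only:

```r
# A minimal sketch of the fallback behaviour, assuming an active SparkR
# session launched with --conf spark.sql.execution.arrow.enabled=true.
library(SparkR)

# Supported types go through the Arrow path.
head(collect(createDataFrame(mtcars)))

# An explicit float type in the schema is not supported by the Arrow path yet:
# createDataFrame is expected to warn ("createDataFrame attempted Arrow
# optimization ... failed, attempting non-optimization") and fall back to the
# regular conversion, still returning a correct result.
rdf <- data.frame(a = c(1.1, 2.2, 3.3))  # illustrative data
schema <- structType(structField("a", "float"))
collect(createDataFrame(rdf, schema))
```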
1 parent e71acd9 commit e8982ca

5 files changed: +239 −46 lines

### R/pkg/R/SQLContext.R (136 additions, 25 deletions)

```diff
@@ -147,6 +147,70 @@ getDefaultSqlSource <- function() {
   l[["spark.sql.sources.default"]]
 }
 
+writeToFileInArrow <- function(fileName, rdf, numPartitions) {
+  requireNamespace1 <- requireNamespace
+
+  # R API in Arrow is not yet released in CRAN. CRAN requires to add the
+  # package in requireNamespace at DESCRIPTION. Later, CRAN checks if the package is available
+  # or not. Therefore, it works around by avoiding direct requireNamespace.
+  # Currently, as of Arrow 0.12.0, it can be installed by install_github. See ARROW-3204.
+  if (requireNamespace1("arrow", quietly = TRUE)) {
+    record_batch <- get("record_batch", envir = asNamespace("arrow"), inherits = FALSE)
+    RecordBatchStreamWriter <- get(
+      "RecordBatchStreamWriter", envir = asNamespace("arrow"), inherits = FALSE)
+    FileOutputStream <- get(
+      "FileOutputStream", envir = asNamespace("arrow"), inherits = FALSE)
+
+    numPartitions <- if (!is.null(numPartitions)) {
+      numToInt(numPartitions)
+    } else {
+      1
+    }
+
+    rdf_slices <- if (numPartitions > 1) {
+      split(rdf, makeSplits(numPartitions, nrow(rdf)))
+    } else {
+      list(rdf)
+    }
+
+    stream_writer <- NULL
+    tryCatch({
+      for (rdf_slice in rdf_slices) {
+        batch <- record_batch(rdf_slice)
+        if (is.null(stream_writer)) {
+          stream <- FileOutputStream(fileName)
+          schema <- batch$schema
+          stream_writer <- RecordBatchStreamWriter(stream, schema)
+        }
+
+        stream_writer$write_batch(batch)
+      }
+    },
+    finally = {
+      if (!is.null(stream_writer)) {
+        stream_writer$close()
+      }
+    })
+
+  } else {
+    stop("'arrow' package should be installed.")
+  }
+}
+
+checkTypeRequirementForArrow <- function(dataHead, schema) {
+  # Currently Arrow optimization does not support raw for now.
+  # Also, it does not support explicit float type set by users. It leads to
+  # incorrect conversion. We will fall back to the path without Arrow optimization.
+  if (any(sapply(dataHead, is.raw))) {
+    stop("Arrow optimization with R DataFrame does not support raw type yet.")
+  }
+  if (inherits(schema, "structType")) {
+    if (any(sapply(schema$fields(), function(x) x$dataType.toString() == "FloatType"))) {
+      stop("Arrow optimization with R DataFrame does not support FloatType type yet.")
+    }
+  }
+}
+
 #' Create a SparkDataFrame
 #'
 #' Converts R data.frame or list into SparkDataFrame.
@@ -172,36 +236,76 @@ getDefaultSqlSource <- function() {
 createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0,
                             numPartitions = NULL) {
   sparkSession <- getSparkSession()
+  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true"
+  useArrow <- FALSE
+  firstRow <- NULL
 
   if (is.data.frame(data)) {
-      # Convert data into a list of rows. Each row is a list.
-
-      # get the names of columns, they will be put into RDD
-      if (is.null(schema)) {
-        schema <- names(data)
-      }
+    # get the names of columns, they will be put into RDD
+    if (is.null(schema)) {
+      schema <- names(data)
+    }
 
-      # get rid of factor type
-      cleanCols <- function(x) {
-        if (is.factor(x)) {
-          as.character(x)
-        } else {
-          x
-        }
+    # get rid of factor type
+    cleanCols <- function(x) {
+      if (is.factor(x)) {
+        as.character(x)
+      } else {
+        x
       }
+    }
+    data[] <- lapply(data, cleanCols)
+
+    args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE)
+    if (arrowEnabled) {
+      useArrow <- tryCatch({
+        stopifnot(length(data) > 0)
+        dataHead <- head(data, 1)
+        checkTypeRequirementForArrow(data, schema)
+        fileName <- tempfile(pattern = "spark-arrow", fileext = ".tmp")
+        tryCatch({
+          writeToFileInArrow(fileName, data, numPartitions)
+          jrddInArrow <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+                                     "readArrowStreamFromFile",
+                                     sparkSession,
+                                     fileName)
+        },
+        finally = {
+          # File might not be created.
+          suppressWarnings(file.remove(fileName))
+        })
+
+        firstRow <- do.call(mapply, append(args, dataHead))[[1]]
+        TRUE
+      },
+      error = function(e) {
+        warning(paste0("createDataFrame attempted Arrow optimization because ",
+                       "'spark.sql.execution.arrow.enabled' is set to true; however, ",
+                       "failed, attempting non-optimization. Reason: ",
+                       e))
+        FALSE
+      })
+    }
 
+    if (!useArrow) {
+      # Convert data into a list of rows. Each row is a list.
       # drop factors and wrap lists
-      data <- setNames(lapply(data, cleanCols), NULL)
+      data <- setNames(as.list(data), NULL)
 
       # check if all columns have supported type
       lapply(data, getInternalType)
 
       # convert to rows
-      args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE)
       data <- do.call(mapply, append(args, data))
+      if (length(data) > 0) {
+        firstRow <- data[[1]]
+      }
+    }
   }
 
-  if (is.list(data)) {
+  if (useArrow) {
+    rdd <- jrddInArrow
+  } else if (is.list(data)) {
     sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
     if (!is.null(numPartitions)) {
       rdd <- parallelize(sc, data, numSlices = numToInt(numPartitions))
@@ -215,14 +319,16 @@ createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0,
   }
 
   if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) {
-    row <- firstRDD(rdd)
+    if (is.null(firstRow)) {
+      firstRow <- firstRDD(rdd)
+    }
     names <- if (is.null(schema)) {
-      names(row)
+      names(firstRow)
     } else {
       as.list(schema)
     }
     if (is.null(names)) {
-      names <- lapply(1:length(row), function(x) {
+      names <- lapply(1:length(firstRow), function(x) {
         paste("_", as.character(x), sep = "")
       })
     }
@@ -237,19 +343,24 @@ createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0,
       nn
     })
 
-    types <- lapply(row, infer_type)
-    fields <- lapply(1:length(row), function(i) {
+    types <- lapply(firstRow, infer_type)
+    fields <- lapply(1:length(firstRow), function(i) {
       structField(names[[i]], types[[i]], TRUE)
     })
     schema <- do.call(structType, fields)
   }
 
   stopifnot(class(schema) == "structType")
 
-  jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
-  srdd <- callJMethod(jrdd, "rdd")
-  sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
-                     srdd, schema$jobj, sparkSession)
+  if (useArrow) {
+    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+                       "toDataFrame", rdd, schema$jobj, sparkSession)
+  } else {
+    jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
+    srdd <- callJMethod(jrdd, "rdd")
+    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
+                       srdd, schema$jobj, sparkSession)
+  }
   dataFrame(sdf)
 }
```

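For readers unfamiliar with the Arrow stream format that `writeToFileInArrow` produces, here is a minimal standalone sketch of the same write path, assuming the arrow 0.12-era R API referenced in the diff above (the sample data and temp file are illustrative; newer arrow releases expose a different surface such as `write_ipc_stream()`):

```r
# A standalone sketch of the write path used by writeToFileInArrow above,
# assuming the arrow 0.12-era R API shown in the diff.
library(arrow)

rdf <- mtcars
fileName <- tempfile(fileext = ".tmp")

batch <- record_batch(rdf)                 # one RecordBatch per slice
stream <- FileOutputStream(fileName)
writer <- RecordBatchStreamWriter(stream, batch$schema)
writer$write_batch(batch)                  # repeated once per slice in SparkR
writer$close()

# The resulting file is what SQLUtils.readArrowStreamFromFile reads back on
# the JVM side, one serialized ArrowRecordBatch per RDD partition.
```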
### R/pkg/R/context.R (21 additions, 19 deletions)

```diff
@@ -81,6 +81,26 @@ objectFile <- function(sc, path, minPartitions = NULL) {
   RDD(jrdd, "byte")
 }
 
+makeSplits <- function(numSerializedSlices, length) {
+  # Generate the slice ids to put each row
+  # For instance, for numSerializedSlices of 22, length of 50
+  # [1] 0 0 2 2 4 4 6 6 6 9 9 11 11 13 13 15 15 15 18 18 20 20 22 22 22
+  # [26] 25 25 27 27 29 29 31 31 31 34 34 36 36 38 38 40 40 40 43 43 45 45 47 47 47
+  # Notice the slice group with 3 slices (ie. 6, 15, 22) are roughly evenly spaced.
+  # We are trying to reimplement the calculation in the positions method in ParallelCollectionRDD
+  if (numSerializedSlices > 0) {
+    unlist(lapply(0: (numSerializedSlices - 1), function(x) {
+      # nolint start
+      start <- trunc((as.numeric(x) * length) / numSerializedSlices)
+      end <- trunc(((as.numeric(x) + 1) * length) / numSerializedSlices)
+      # nolint end
+      rep(start, end - start)
+    }))
+  } else {
+    1
+  }
+}
+
 #' Create an RDD from a homogeneous list or vector.
 #'
 #' This function creates an RDD from a local homogeneous list in R. The elements
@@ -143,25 +163,7 @@ parallelize <- function(sc, coll, numSlices = 1) {
   # For large objects we make sure the size of each slice is also smaller than sizeLimit
   numSerializedSlices <- min(len, max(numSlices, ceiling(objectSize / sizeLimit)))
 
-  # Generate the slice ids to put each row
-  # For instance, for numSerializedSlices of 22, length of 50
-  # [1] 0 0 2 2 4 4 6 6 6 9 9 11 11 13 13 15 15 15 18 18 20 20 22 22 22
-  # [26] 25 25 27 27 29 29 31 31 31 34 34 36 36 38 38 40 40 40 43 43 45 45 47 47 47
-  # Notice the slice group with 3 slices (ie. 6, 15, 22) are roughly evenly spaced.
-  # We are trying to reimplement the calculation in the positions method in ParallelCollectionRDD
-  splits <- if (numSerializedSlices > 0) {
-    unlist(lapply(0: (numSerializedSlices - 1), function(x) {
-      # nolint start
-      start <- trunc((as.numeric(x) * len) / numSerializedSlices)
-      end <- trunc(((as.numeric(x) + 1) * len) / numSerializedSlices)
-      # nolint end
-      rep(start, end - start)
-    }))
-  } else {
-    1
-  }
-
-  slices <- split(coll, splits)
+  slices <- split(coll, makeSplits(numSerializedSlices, len))
 
   # Serialize each slice: obtain a list of raws, or a list of lists (slices) of
   # 2-tuples of raws
```

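As a concrete illustration of the slice-id calculation factored out into `makeSplits` above, a small self-contained sketch (the helper is copied here, lightly reformatted, so the example runs on its own):

```r
# A minimal sketch reproducing the slice-id computation of makeSplits() above.
makeSplits <- function(numSerializedSlices, length) {
  if (numSerializedSlices > 0) {
    unlist(lapply(0:(numSerializedSlices - 1), function(x) {
      start <- trunc((as.numeric(x) * length) / numSerializedSlices)
      end <- trunc(((as.numeric(x) + 1) * length) / numSerializedSlices)
      rep(start, end - start)
    }))
  } else {
    1
  }
}

# 50 elements spread over 22 slices: each element is tagged with the id of the
# slice it lands in, matching the example vector in the comment above.
splits <- makeSplits(22, 50)
table(splits)           # slice sizes come out as 2 or 3, roughly evenly spaced
length(unique(splits))  # 22 distinct slice ids

# split() then groups a collection by these ids, exactly as parallelize() does.
groups <- split(seq_len(50), splits)
```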
### R/pkg/tests/fulltests/test_sparkSQL.R (57 additions, 0 deletions)

```diff
@@ -307,6 +307,63 @@ test_that("create DataFrame from RDD", {
   unsetHiveContext()
 })
 
+test_that("createDataFrame Arrow optimization", {
+  skip_if_not_installed("arrow")
+
+  conf <- callJMethod(sparkSession, "conf")
+  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false")
+  tryCatch({
+    expected <- collect(createDataFrame(mtcars))
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
+  tryCatch({
+    expect_equal(collect(createDataFrame(mtcars)), expected)
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+})
+
+test_that("createDataFrame Arrow optimization - type specification", {
+  skip_if_not_installed("arrow")
+  rdf <- data.frame(list(list(a = 1,
+                              b = "a",
+                              c = TRUE,
+                              d = 1.1,
+                              e = 1L,
+                              f = as.Date("1990-02-24"),
+                              g = as.POSIXct("1990-02-24 12:34:56"))))
+
+  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
+  conf <- callJMethod(sparkSession, "conf")
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false")
+  tryCatch({
+    expected <- collect(createDataFrame(rdf))
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
+  tryCatch({
+    expect_equal(collect(createDataFrame(rdf)), expected)
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+})
+
 test_that("read/write csv as DataFrame", {
   if (windows_with_hadoop()) {
     csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv")
```

### sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala (3 additions, 2 deletions)

```diff
@@ -1286,8 +1286,9 @@ object SQLConf {
   val ARROW_EXECUTION_ENABLED =
     buildConf("spark.sql.execution.arrow.enabled")
       .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " +
-        "for use with pyspark.sql.DataFrame.toPandas, and " +
-        "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " +
+        "for use with pyspark.sql.DataFrame.toPandas, " +
+        "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame, " +
+        "and createDataFrame when its input is an R DataFrame. " +
         "The following data types are unsupported: " +
         "BinaryType, MapType, ArrayType of TimestampType, and nested StructType.")
       .booleanConf
```

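For SparkR users, the flag documented above can also be set and read from R; a small sketch using the standard SparkR API (the session settings shown are illustrative):

```r
# A small sketch of toggling the flag documented above from SparkR.
library(SparkR)

# Enable Arrow optimization when starting the session ...
sparkR.session(master = "local[*]",
               sparkConfig = list(spark.sql.execution.arrow.enabled = "true"))

# ... and confirm the effective value, as createDataFrame itself does.
sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
```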
### sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala (22 additions, 0 deletions)

```diff
@@ -32,6 +32,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericRowWithSchema}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.ShowTablesCommand
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.types._
@@ -237,4 +238,25 @@ private[sql] object SQLUtils extends Logging {
   def createArrayType(column: Column): ArrayType = {
     new ArrayType(ExprUtils.evalTypeExpr(column.expr), true)
   }
+
+  /**
+   * R callable function to read a file in Arrow stream format and create an `RDD`
+   * using each serialized ArrowRecordBatch as a partition.
+   */
+  def readArrowStreamFromFile(
+      sparkSession: SparkSession,
+      filename: String): JavaRDD[Array[Byte]] = {
+    ArrowConverters.readArrowStreamFromFile(sparkSession.sqlContext, filename)
+  }
+
+  /**
+   * R callable function to create a `DataFrame` from a `JavaRDD` of serialized
+   * ArrowRecordBatches.
+   */
+  def toDataFrame(
+      arrowBatchRDD: JavaRDD[Array[Byte]],
+      schema: StructType,
+      sparkSession: SparkSession): DataFrame = {
+    ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession.sqlContext)
+  }
 }
```

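Tying the pieces together, the sketch below mirrors how the R side (see the SQLContext.R diff above) drives these two JVM entry points. It relies on SparkR internals (`callJStatic`, `getSparkSession`, `dataFrame`), so it is illustrative rather than a public API; `fileName` and `schema` stand in for the values computed inside `createDataFrame`:

```r
# Illustrative only: mirrors the internal call chain recorded in the diffs above.
# fileName is an Arrow stream file written by writeToFileInArrow(); schema is a
# structType built on the R side. These are SparkR internals, not public API.
sparkSession <- SparkR:::getSparkSession()

# 1. Read the Arrow stream file back as a JavaRDD[Array[Byte]], one
#    serialized ArrowRecordBatch per partition.
jrddInArrow <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                                    "readArrowStreamFromFile",
                                    sparkSession,
                                    fileName)

# 2. Turn the batches into a DataFrame on the JVM side, passing the JVM object
#    of the R-side schema.
sdf <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                            "toDataFrame", jrddInArrow, schema$jobj, sparkSession)
df <- SparkR:::dataFrame(sdf)
```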