@@ -19,7 +19,10 @@ library(testthat)
1919
2020context(" SparkSQL Arrow optimization" )
2121
22- sparkSession <- sparkR.session(master = sparkRTestMaster , enableHiveSupport = FALSE )
22+ sparkSession <- sparkR.session(
23+ master = sparkRTestMaster ,
24+ enableHiveSupport = FALSE ,
25+ sparkConfig = list (spark.sql.execution.arrow.sparkr.enabled = " true" ))
2326
2427test_that(" createDataFrame/collect Arrow optimization" , {
2528 skip_if_not_installed(" arrow" )
@@ -35,29 +38,13 @@ test_that("createDataFrame/collect Arrow optimization", {
3538 callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
3639 })
3740
38- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " true" )
39- tryCatch({
40- expect_equal(collect(createDataFrame(mtcars )), expected )
41- },
42- finally = {
43- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
44- })
41+ expect_equal(collect(createDataFrame(mtcars )), expected )
4542})
4643
4744test_that(" createDataFrame/collect Arrow optimization - many partitions (partition order test)" , {
4845 skip_if_not_installed(" arrow" )
49-
50- conf <- callJMethod(sparkSession , " conf" )
51- arrowEnabled <- sparkR.conf(" spark.sql.execution.arrow.sparkr.enabled" )[[1 ]]
52-
53- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " true" )
54- tryCatch({
55- expect_equal(collect(createDataFrame(mtcars , numPartitions = 32 )),
56- collect(createDataFrame(mtcars , numPartitions = 1 )))
57- },
58- finally = {
59- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
60- })
46+ expect_equal(collect(createDataFrame(mtcars , numPartitions = 32 )),
47+ collect(createDataFrame(mtcars , numPartitions = 1 )))
6148})
6249
6350test_that(" createDataFrame/collect Arrow optimization - type specification" , {
@@ -81,13 +68,7 @@ test_that("createDataFrame/collect Arrow optimization - type specification", {
8168 callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
8269 })
8370
84- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " true" )
85- tryCatch({
86- expect_equal(collect(createDataFrame(rdf )), expected )
87- },
88- finally = {
89- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
90- })
71+ expect_equal(collect(createDataFrame(rdf )), expected )
9172})
9273
9374test_that(" dapply() Arrow optimization" , {
@@ -98,36 +79,30 @@ test_that("dapply() Arrow optimization", {
9879 arrowEnabled <- sparkR.conf(" spark.sql.execution.arrow.sparkr.enabled" )[[1 ]]
9980
10081 callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " false" )
101- tryCatch({
102- ret <- dapply(df ,
103- function (rdf ) {
104- stopifnot(is.data.frame(rdf ))
105- rdf
106- },
107- schema(df ))
108- expected <- collect(ret )
109- },
110- finally = {
111- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
112- })
113-
114- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " true" )
11582 tryCatch({
11683 ret <- dapply(df ,
11784 function (rdf ) {
11885 stopifnot(is.data.frame(rdf ))
119- # mtcars' hp is more then 50.
120- stopifnot(all(rdf $ hp > 50 ))
12186 rdf
12287 },
12388 schema(df ))
124- actual <- collect(ret )
125- expect_equal(actual , expected )
126- expect_equal(count(ret ), nrow(mtcars ))
89+ expected <- collect(ret )
12790 },
12891 finally = {
12992 callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
13093 })
94+
95+ ret <- dapply(df ,
96+ function (rdf ) {
97+ stopifnot(is.data.frame(rdf ))
98+ # mtcars' hp is more then 50.
99+ stopifnot(all(rdf $ hp > 50 ))
100+ rdf
101+ },
102+ schema(df ))
103+ actual <- collect(ret )
104+ expect_equal(actual , expected )
105+ expect_equal(count(ret ), nrow(mtcars ))
131106})
132107
133108test_that(" dapply() Arrow optimization - type specification" , {
@@ -154,34 +129,18 @@ test_that("dapply() Arrow optimization - type specification", {
154129 callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
155130 })
156131
157- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " true" )
158- tryCatch({
159- ret <- dapply(df , function (rdf ) { rdf }, schema(df ))
160- actual <- collect(ret )
161- expect_equal(actual , expected )
162- },
163- finally = {
164- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
165- })
132+ ret <- dapply(df , function (rdf ) { rdf }, schema(df ))
133+ actual <- collect(ret )
134+ expect_equal(actual , expected )
166135})
167136
168137test_that(" dapply() Arrow optimization - type specification (date and timestamp)" , {
169138 skip_if_not_installed(" arrow" )
170139 rdf <- data.frame (list (list (a = as.Date(" 1990-02-24" ),
171140 b = as.POSIXct(" 1990-02-24 12:34:56" ))))
172141 df <- createDataFrame(rdf )
173-
174- conf <- callJMethod(sparkSession , " conf" )
175- arrowEnabled <- sparkR.conf(" spark.sql.execution.arrow.sparkr.enabled" )[[1 ]]
176-
177- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " true" )
178- tryCatch({
179- ret <- dapply(df , function (rdf ) { rdf }, schema(df ))
180- expect_equal(collect(ret ), rdf )
181- },
182- finally = {
183- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
184- })
142+ ret <- dapply(df , function (rdf ) { rdf }, schema(df ))
143+ expect_equal(collect(ret ), rdf )
185144})
186145
187146test_that(" gapply() Arrow optimization" , {
@@ -209,28 +168,22 @@ test_that("gapply() Arrow optimization", {
209168 callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
210169 })
211170
212- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " true" )
213- tryCatch({
214- ret <- gapply(df ,
215- " gear" ,
216- function (key , grouped ) {
217- if (length(key ) > 0 ) {
218- stopifnot(is.numeric(key [[1 ]]))
219- }
220- stopifnot(is.data.frame(grouped ))
221- stopifnot(length(colnames(grouped )) == 11 )
222- # mtcars' hp is more then 50.
223- stopifnot(all(grouped $ hp > 50 ))
224- grouped
225- },
226- schema(df ))
227- actual <- collect(ret )
228- expect_equal(actual , expected )
229- expect_equal(count(ret ), nrow(mtcars ))
230- },
231- finally = {
232- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
233- })
171+ ret <- gapply(df ,
172+ " gear" ,
173+ function (key , grouped ) {
174+ if (length(key ) > 0 ) {
175+ stopifnot(is.numeric(key [[1 ]]))
176+ }
177+ stopifnot(is.data.frame(grouped ))
178+ stopifnot(length(colnames(grouped )) == 11 )
179+ # mtcars' hp is more then 50.
180+ stopifnot(all(grouped $ hp > 50 ))
181+ grouped
182+ },
183+ schema(df ))
184+ actual <- collect(ret )
185+ expect_equal(actual , expected )
186+ expect_equal(count(ret ), nrow(mtcars ))
234187})
235188
236189test_that(" gapply() Arrow optimization - type specification" , {
@@ -250,84 +203,50 @@ test_that("gapply() Arrow optimization - type specification", {
250203 callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " false" )
251204 tryCatch({
252205 ret <- gapply(df ,
253- " a" ,
254- function (key , grouped ) { grouped }, schema(df ))
206+ " a" ,
207+ function (key , grouped ) { grouped }, schema(df ))
255208 expected <- collect(ret )
256209 },
257210 finally = {
258211 callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
259212 })
260213
261-
262- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " true" )
263- tryCatch({
264- ret <- gapply(df ,
265- " a" ,
266- function (key , grouped ) { grouped }, schema(df ))
267- actual <- collect(ret )
268- expect_equal(actual , expected )
269- },
270- finally = {
271- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
272- })
214+ ret <- gapply(df ,
215+ " a" ,
216+ function (key , grouped ) { grouped }, schema(df ))
217+ actual <- collect(ret )
218+ expect_equal(actual , expected )
273219})
274220
275221test_that(" gapply() Arrow optimization - type specification (date and timestamp)" , {
276222 skip_if_not_installed(" arrow" )
277223 rdf <- data.frame (list (list (a = as.Date(" 1990-02-24" ),
278224 b = as.POSIXct(" 1990-02-24 12:34:56" ))))
279225 df <- createDataFrame(rdf )
280-
281- conf <- callJMethod(sparkSession , " conf" )
282- arrowEnabled <- sparkR.conf(" spark.sql.execution.arrow.sparkr.enabled" )[[1 ]]
283-
284- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " true" )
285- tryCatch({
286- ret <- gapply(df ,
287- " a" ,
288- function (key , grouped ) { grouped }, schema(df ))
289- expect_equal(collect(ret ), rdf )
290- },
291- finally = {
292- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
293- })
226+ ret <- gapply(df ,
227+ " a" ,
228+ function (key , grouped ) { grouped }, schema(df ))
229+ expect_equal(collect(ret ), rdf )
294230})
295231
296232test_that(" Arrow optimization - unsupported types" , {
297233 skip_if_not_installed(" arrow" )
298234
299- conf <- callJMethod(sparkSession , " conf" )
300- arrowEnabled <- sparkR.conf(" spark.sql.execution.arrow.sparkr.enabled" )[[1 ]]
301- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " true" )
302- tryCatch({
303- expect_error(checkSchemaInArrow(structType(" a FLOAT" )), " not support float type" )
304- expect_error(checkSchemaInArrow(structType(" a BINARY" )), " not support binary type" )
305- expect_error(checkSchemaInArrow(structType(" a ARRAY<INT>" )), " not support array type" )
306- expect_error(checkSchemaInArrow(structType(" a MAP<INT, INT>" )), " not support map type" )
307- expect_error(checkSchemaInArrow(structType(" a STRUCT<a: INT>" )),
308- " not support nested struct type" )
309- },
310- finally = {
311- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
312- })
235+ expect_error(checkSchemaInArrow(structType(" a FLOAT" )), " not support float type" )
236+ expect_error(checkSchemaInArrow(structType(" a BINARY" )), " not support binary type" )
237+ expect_error(checkSchemaInArrow(structType(" a ARRAY<INT>" )), " not support array type" )
238+ expect_error(checkSchemaInArrow(structType(" a MAP<INT, INT>" )), " not support map type" )
239+ expect_error(checkSchemaInArrow(structType(" a STRUCT<a: INT>" )),
240+ " not support nested struct type" )
313241})
314242
315243test_that(" SPARK-32478: gapply() Arrow optimization - error message for schema mismatch" , {
316244 skip_if_not_installed(" arrow" )
317245 df <- createDataFrame(list (list (a = 1L , b = " a" )))
318246
319- conf <- callJMethod(sparkSession , " conf" )
320- arrowEnabled <- sparkR.conf(" spark.sql.execution.arrow.sparkr.enabled" )[[1 ]]
321-
322- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , " true" )
323- tryCatch({
324- expect_error(
247+ expect_error(
325248 count(gapply(df , " a" , function (key , group ) { group }, structType(" a int, b int" ))),
326249 " expected IntegerType, IntegerType, got IntegerType, StringType" )
327- },
328- finally = {
329- callJMethod(conf , " set" , " spark.sql.execution.arrow.sparkr.enabled" , arrowEnabled )
330- })
331250})
332251
333252sparkR.session.stop()
0 commit comments