Commit 2491cf1
[SPARK-32747][R][TESTS] Deduplicate configuration set/unset in test_sparkSQL_arrow.R
### What changes were proposed in this pull request?

This PR proposes to deduplicate configuration set/unset in `test_sparkSQL_arrow.R`. Setting `spark.sql.execution.arrow.sparkr.enabled` can be done globally once, instead of in each test case.

### Why are the changes needed?

To deduplicate the code.

### Does this PR introduce _any_ user-facing change?

No, dev-only.

### How was this patch tested?

Manually ran the tests.

Closes #29592 from HyukjinKwon/SPARK-32747.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
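In short: rather than each test case saving, setting, and restoring `spark.sql.execution.arrow.sparkr.enabled` around its own body, the flag is now enabled once for the whole suite when the session is created. A condensed before/after, abridged from the diff below:

```r
# Before: every test case saved, flipped, and restored the Arrow flag itself.
test_that("createDataFrame/collect Arrow optimization", {
  conf <- callJMethod(sparkSession, "conf")
  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.sparkr.enabled")[[1]]
  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
  tryCatch({
    # `expected` is computed earlier with Arrow disabled (omitted here).
    expect_equal(collect(createDataFrame(mtcars)), expected)
  },
  finally = {
    # Restore the caller's value even when the expectation fails.
    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
  })
})

# After: the flag is enabled once at session creation, so each test body
# reduces to its assertions.
sparkSession <- sparkR.session(
  master = sparkRTestMaster,
  enableHiveSupport = FALSE,
  sparkConfig = list(spark.sql.execution.arrow.sparkr.enabled = "true"))

test_that("createDataFrame/collect Arrow optimization", {
  expect_equal(collect(createDataFrame(mtcars)), expected)
})
```

Only the tests that exercise the non-Arrow path still touch the config, and they keep their `tryCatch(..., finally = ...)` so the global setting is restored afterwards.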
1 parent 5574734 commit 2491cf1

1 file changed: +60 -141 lines

R/pkg/tests/fulltests/test_sparkSQL_arrow.R

Lines changed: 60 additions & 141 deletions
@@ -19,7 +19,10 @@ library(testthat)
 
 context("SparkSQL Arrow optimization")
 
-sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(
+  master = sparkRTestMaster,
+  enableHiveSupport = FALSE,
+  sparkConfig = list(spark.sql.execution.arrow.sparkr.enabled = "true"))
 
 test_that("createDataFrame/collect Arrow optimization", {
   skip_if_not_installed("arrow")
@@ -35,29 +38,13 @@ test_that("createDataFrame/collect Arrow optimization", {
     callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
   })
 
-  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
-  tryCatch({
-    expect_equal(collect(createDataFrame(mtcars)), expected)
-  },
-  finally = {
-    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
-  })
+  expect_equal(collect(createDataFrame(mtcars)), expected)
 })
 
 test_that("createDataFrame/collect Arrow optimization - many partitions (partition order test)", {
   skip_if_not_installed("arrow")
-
-  conf <- callJMethod(sparkSession, "conf")
-  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.sparkr.enabled")[[1]]
-
-  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
-  tryCatch({
-    expect_equal(collect(createDataFrame(mtcars, numPartitions = 32)),
-                 collect(createDataFrame(mtcars, numPartitions = 1)))
-  },
-  finally = {
-    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
-  })
+  expect_equal(collect(createDataFrame(mtcars, numPartitions = 32)),
+               collect(createDataFrame(mtcars, numPartitions = 1)))
 })
 
 test_that("createDataFrame/collect Arrow optimization - type specification", {
@@ -81,13 +68,7 @@ test_that("createDataFrame/collect Arrow optimization - type specification", {
     callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
   })
 
-  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
-  tryCatch({
-    expect_equal(collect(createDataFrame(rdf)), expected)
-  },
-  finally = {
-    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
-  })
+  expect_equal(collect(createDataFrame(rdf)), expected)
 })
 
 test_that("dapply() Arrow optimization", {
@@ -98,36 +79,30 @@ test_that("dapply() Arrow optimization", {
   arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.sparkr.enabled")[[1]]
 
   callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "false")
-  tryCatch({
-    ret <- dapply(df,
-                  function(rdf) {
-                    stopifnot(is.data.frame(rdf))
-                    rdf
-                  },
-                  schema(df))
-    expected <- collect(ret)
-  },
-  finally = {
-    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
-  })
-
-  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
   tryCatch({
     ret <- dapply(df,
                   function(rdf) {
                     stopifnot(is.data.frame(rdf))
-                    # mtcars' hp is more then 50.
-                    stopifnot(all(rdf$hp > 50))
                     rdf
                   },
                   schema(df))
-    actual <- collect(ret)
-    expect_equal(actual, expected)
-    expect_equal(count(ret), nrow(mtcars))
+    expected <- collect(ret)
   },
   finally = {
     callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
   })
+
+  ret <- dapply(df,
+                function(rdf) {
+                  stopifnot(is.data.frame(rdf))
+                  # mtcars' hp is more then 50.
+                  stopifnot(all(rdf$hp > 50))
+                  rdf
+                },
+                schema(df))
+  actual <- collect(ret)
+  expect_equal(actual, expected)
+  expect_equal(count(ret), nrow(mtcars))
 })
 
 test_that("dapply() Arrow optimization - type specification", {
@@ -154,34 +129,18 @@ test_that("dapply() Arrow optimization - type specification", {
     callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
   })
 
-  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
-  tryCatch({
-    ret <- dapply(df, function(rdf) { rdf }, schema(df))
-    actual <- collect(ret)
-    expect_equal(actual, expected)
-  },
-  finally = {
-    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
-  })
+  ret <- dapply(df, function(rdf) { rdf }, schema(df))
+  actual <- collect(ret)
+  expect_equal(actual, expected)
 })
 
 test_that("dapply() Arrow optimization - type specification (date and timestamp)", {
   skip_if_not_installed("arrow")
   rdf <- data.frame(list(list(a = as.Date("1990-02-24"),
                               b = as.POSIXct("1990-02-24 12:34:56"))))
   df <- createDataFrame(rdf)
-
-  conf <- callJMethod(sparkSession, "conf")
-  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.sparkr.enabled")[[1]]
-
-  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
-  tryCatch({
-    ret <- dapply(df, function(rdf) { rdf }, schema(df))
-    expect_equal(collect(ret), rdf)
-  },
-  finally = {
-    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
-  })
+  ret <- dapply(df, function(rdf) { rdf }, schema(df))
+  expect_equal(collect(ret), rdf)
 })
 
 test_that("gapply() Arrow optimization", {
@@ -209,28 +168,22 @@ test_that("gapply() Arrow optimization", {
     callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
   })
 
-  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
-  tryCatch({
-    ret <- gapply(df,
-                  "gear",
-                  function(key, grouped) {
-                    if (length(key) > 0) {
-                      stopifnot(is.numeric(key[[1]]))
-                    }
-                    stopifnot(is.data.frame(grouped))
-                    stopifnot(length(colnames(grouped)) == 11)
-                    # mtcars' hp is more then 50.
-                    stopifnot(all(grouped$hp > 50))
-                    grouped
-                  },
-                  schema(df))
-    actual <- collect(ret)
-    expect_equal(actual, expected)
-    expect_equal(count(ret), nrow(mtcars))
-  },
-  finally = {
-    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
-  })
+  ret <- gapply(df,
+                "gear",
+                function(key, grouped) {
+                  if (length(key) > 0) {
+                    stopifnot(is.numeric(key[[1]]))
+                  }
+                  stopifnot(is.data.frame(grouped))
+                  stopifnot(length(colnames(grouped)) == 11)
+                  # mtcars' hp is more then 50.
+                  stopifnot(all(grouped$hp > 50))
+                  grouped
+                },
+                schema(df))
+  actual <- collect(ret)
+  expect_equal(actual, expected)
+  expect_equal(count(ret), nrow(mtcars))
 })
 
 test_that("gapply() Arrow optimization - type specification", {
@@ -250,84 +203,50 @@ test_that("gapply() Arrow optimization - type specification", {
   callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "false")
   tryCatch({
     ret <- gapply(df,
-    "a",
-    function(key, grouped) { grouped }, schema(df))
+                  "a",
+                  function(key, grouped) { grouped }, schema(df))
     expected <- collect(ret)
   },
   finally = {
     callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
   })
 
-
-  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
-  tryCatch({
-    ret <- gapply(df,
-    "a",
-    function(key, grouped) { grouped }, schema(df))
-    actual <- collect(ret)
-    expect_equal(actual, expected)
-  },
-  finally = {
-    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
-  })
+  ret <- gapply(df,
+                "a",
+                function(key, grouped) { grouped }, schema(df))
+  actual <- collect(ret)
+  expect_equal(actual, expected)
 })
 
 test_that("gapply() Arrow optimization - type specification (date and timestamp)", {
   skip_if_not_installed("arrow")
   rdf <- data.frame(list(list(a = as.Date("1990-02-24"),
                               b = as.POSIXct("1990-02-24 12:34:56"))))
   df <- createDataFrame(rdf)
-
-  conf <- callJMethod(sparkSession, "conf")
-  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.sparkr.enabled")[[1]]
-
-  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
-  tryCatch({
-    ret <- gapply(df,
-    "a",
-    function(key, grouped) { grouped }, schema(df))
-    expect_equal(collect(ret), rdf)
-  },
-  finally = {
-    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
-  })
+  ret <- gapply(df,
+                "a",
+                function(key, grouped) { grouped }, schema(df))
+  expect_equal(collect(ret), rdf)
 })
 
 test_that("Arrow optimization - unsupported types", {
   skip_if_not_installed("arrow")
 
-  conf <- callJMethod(sparkSession, "conf")
-  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.sparkr.enabled")[[1]]
-  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
-  tryCatch({
-    expect_error(checkSchemaInArrow(structType("a FLOAT")), "not support float type")
-    expect_error(checkSchemaInArrow(structType("a BINARY")), "not support binary type")
-    expect_error(checkSchemaInArrow(structType("a ARRAY<INT>")), "not support array type")
-    expect_error(checkSchemaInArrow(structType("a MAP<INT, INT>")), "not support map type")
-    expect_error(checkSchemaInArrow(structType("a STRUCT<a: INT>")),
-                 "not support nested struct type")
-  },
-  finally = {
-    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
-  })
+  expect_error(checkSchemaInArrow(structType("a FLOAT")), "not support float type")
+  expect_error(checkSchemaInArrow(structType("a BINARY")), "not support binary type")
+  expect_error(checkSchemaInArrow(structType("a ARRAY<INT>")), "not support array type")
+  expect_error(checkSchemaInArrow(structType("a MAP<INT, INT>")), "not support map type")
+  expect_error(checkSchemaInArrow(structType("a STRUCT<a: INT>")),
+               "not support nested struct type")
 })
 
 test_that("SPARK-32478: gapply() Arrow optimization - error message for schema mismatch", {
   skip_if_not_installed("arrow")
   df <- createDataFrame(list(list(a = 1L, b = "a")))
 
-  conf <- callJMethod(sparkSession, "conf")
-  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.sparkr.enabled")[[1]]
-
-  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
-  tryCatch({
-    expect_error(
+  expect_error(
     count(gapply(df, "a", function(key, group) { group }, structType("a int, b int"))),
     "expected IntegerType, IntegerType, got IntegerType, StringType")
-  },
-  finally = {
-    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
-  })
 })
 
 sparkR.session.stop()
