Skip to content

Commit fd648bf

Browse files
zero323Felix Cheung
authored and
Felix Cheung
committed
[SPARK-20371][R] Add wrappers for collect_list and collect_set
## What changes were proposed in this pull request? Adds wrappers for `collect_list` and `collect_set`. ## How was this patch tested? Unit tests, `check-cran.sh` Author: zero323 <zero323@users.noreply.github.com> Closes #17672 from zero323/SPARK-20371.
1 parent eb00378 commit fd648bf

File tree

4 files changed

+73
-0
lines changed

4 files changed

+73
-0
lines changed

R/pkg/NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ exportMethods("%in%",
203203
"cbrt",
204204
"ceil",
205205
"ceiling",
206+
"collect_list",
207+
"collect_set",
206208
"column",
207209
"concat",
208210
"concat_ws",

R/pkg/R/functions.R

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3705,3 +3705,43 @@ setMethod("create_map",
37053705
jc <- callJStatic("org.apache.spark.sql.functions", "map", jcols)
37063706
column(jc)
37073707
})
3708+
3709+
#' collect_list
3710+
#'
3711+
#' Creates a list of objects with duplicates.
3712+
#'
3713+
#' @param x Column to compute on
3714+
#'
3715+
#' @rdname collect_list
3716+
#' @name collect_list
3717+
#' @family agg_funcs
3718+
#' @aliases collect_list,Column-method
3719+
#' @export
3720+
#' @examples \dontrun{collect_list(df$x)}
3721+
#' @note collect_list since 2.3.0
3722+
setMethod("collect_list",
3723+
signature(x = "Column"),
3724+
function(x) {
3725+
jc <- callJStatic("org.apache.spark.sql.functions", "collect_list", x@jc)
3726+
column(jc)
3727+
})
3728+
3729+
#' collect_set
3730+
#'
3731+
#' Creates a list of objects with duplicate elements eliminated.
3732+
#'
3733+
#' @param x Column to compute on
3734+
#'
3735+
#' @rdname collect_set
3736+
#' @name collect_set
3737+
#' @family agg_funcs
3738+
#' @aliases collect_set,Column-method
3739+
#' @export
3740+
#' @examples \dontrun{collect_set(df$x)}
3741+
#' @note collect_set since 2.3.0
3742+
setMethod("collect_set",
3743+
signature(x = "Column"),
3744+
function(x) {
3745+
jc <- callJStatic("org.apache.spark.sql.functions", "collect_set", x@jc)
3746+
column(jc)
3747+
})

R/pkg/R/generics.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,14 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") })
918918
#' @export
919919
setGeneric("ceil", function(x) { standardGeneric("ceil") })
920920

921+
#' @rdname collect_list
922+
#' @export
923+
setGeneric("collect_list", function(x) { standardGeneric("collect_list") })
924+
925+
#' @rdname collect_set
926+
#' @export
927+
setGeneric("collect_set", function(x) { standardGeneric("collect_set") })
928+
921929
#' @rdname column
922930
#' @export
923931
setGeneric("column", function(x) { standardGeneric("column") })
@@ -1358,6 +1366,7 @@ setGeneric("window", function(x, ...) { standardGeneric("window") })
13581366
#' @export
13591367
setGeneric("year", function(x) { standardGeneric("year") })
13601368

1369+
13611370
###################### Spark.ML Methods ##########################
13621371

13631372
#' @rdname fitted

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,6 +1731,28 @@ test_that("group by, agg functions", {
17311731
expect_true(abs(sd(1:2) - 0.7071068) < 1e-6)
17321732
expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6)
17331733

1734+
# Test collect_list and collect_set
1735+
gd3_collections_local <- collect(
1736+
agg(gd3, collect_set(df8$age), collect_list(df8$age))
1737+
)
1738+
1739+
expect_equal(
1740+
unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 2]),
1741+
c(30)
1742+
)
1743+
1744+
expect_equal(
1745+
unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 3]),
1746+
c(30, 30)
1747+
)
1748+
1749+
expect_equal(
1750+
sort(unlist(
1751+
gd3_collections_local[gd3_collections_local$name == "Justin", 3]
1752+
)),
1753+
c(1, 19)
1754+
)
1755+
17341756
unlink(jsonPath2)
17351757
unlink(jsonPath3)
17361758
})

0 commit comments

Comments
 (0)