Skip to content

Commit 4d535d1

Browse files
yanboliang authored and shivaram committed
[SPARK-13389][SPARKR] SparkR support first/last with ignore NAs
## What changes were proposed in this pull request? SparkR support first/last with ignore NAs cc sun-rui felixcheung shivaram ## How was this patch tested? unit tests Author: Yanbo Liang <ybliang8@gmail.com> Closes #11267 from yanboliang/spark-13389.
1 parent c3a6269 commit 4d535d1

File tree

3 files changed

+45
-10
lines changed

3 files changed

+45
-10
lines changed

R/pkg/R/functions.R

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -536,15 +536,27 @@ setMethod("factorial",
536536
#'
537537
#' Aggregate function: returns the first value in a group.
538538
#'
539+
#' The function by default returns the first value it sees. It will return the first non-missing
540+
#' value it sees when na.rm is set to TRUE. If all values are missing, then NA is returned.
541+
#'
539542
#' @rdname first
540543
#' @name first
541544
#' @family agg_funcs
542545
#' @export
543-
#' @examples \dontrun{first(df$c)}
546+
#' @examples
547+
#' \dontrun{
548+
#' first(df$c)
549+
#' first(df$c, TRUE)
550+
#' }
544551
setMethod("first",
545-
signature(x = "Column"),
546-
function(x) {
547-
jc <- callJStatic("org.apache.spark.sql.functions", "first", x@jc)
552+
signature(x = "characterOrColumn"),
553+
function(x, na.rm = FALSE) {
554+
col <- if (class(x) == "Column") {
555+
x@jc
556+
} else {
557+
x
558+
}
559+
jc <- callJStatic("org.apache.spark.sql.functions", "first", col, na.rm)
548560
column(jc)
549561
})
550562

@@ -663,15 +675,27 @@ setMethod("kurtosis",
663675
#'
664676
#' Aggregate function: returns the last value in a group.
665677
#'
678+
#' The function by default returns the last value it sees. It will return the last non-missing
679+
#' value it sees when na.rm is set to TRUE. If all values are missing, then NA is returned.
680+
#'
666681
#' @rdname last
667682
#' @name last
668683
#' @family agg_funcs
669684
#' @export
670-
#' @examples \dontrun{last(df$c)}
685+
#' @examples
686+
#' \dontrun{
687+
#' last(df$c)
688+
#' last(df$c, TRUE)
689+
#' }
671690
setMethod("last",
672-
signature(x = "Column"),
673-
function(x) {
674-
jc <- callJStatic("org.apache.spark.sql.functions", "last", x@jc)
691+
signature(x = "characterOrColumn"),
692+
function(x, na.rm = FALSE) {
693+
col <- if (class(x) == "Column") {
694+
x@jc
695+
} else {
696+
x
697+
}
698+
jc <- callJStatic("org.apache.spark.sql.functions", "last", col, na.rm)
675699
column(jc)
676700
})
677701

R/pkg/R/generics.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") })
8484

8585
# @rdname first
8686
# @export
87-
setGeneric("first", function(x) { standardGeneric("first") })
87+
setGeneric("first", function(x, ...) { standardGeneric("first") })
8888

8989
# @rdname flatMap
9090
# @export
@@ -889,7 +889,7 @@ setGeneric("lag", function(x, ...) { standardGeneric("lag") })
889889

890890
#' @rdname last
891891
#' @export
892-
setGeneric("last", function(x) { standardGeneric("last") })
892+
setGeneric("last", function(x, ...) { standardGeneric("last") })
893893

894894
#' @rdname last_day
895895
#' @export

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,17 @@ test_that("column functions", {
10761076
result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8")))
10771077
expect_equal(result[[1]][[1]], bytes)
10781078
expect_equal(result[[2]], markUtf8("大千世界"))
1079+
1080+
# Test first(), last()
1081+
df <- read.json(sqlContext, jsonPath)
1082+
expect_equal(collect(select(df, first(df$age)))[[1]], NA)
1083+
expect_equal(collect(select(df, first(df$age, TRUE)))[[1]], 30)
1084+
expect_equal(collect(select(df, first("age")))[[1]], NA)
1085+
expect_equal(collect(select(df, first("age", TRUE)))[[1]], 30)
1086+
expect_equal(collect(select(df, last(df$age)))[[1]], 19)
1087+
expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19)
1088+
expect_equal(collect(select(df, last("age")))[[1]], 19)
1089+
expect_equal(collect(select(df, last("age", TRUE)))[[1]], 19)
10791090
})
10801091

10811092
test_that("column binary mathfunctions", {

0 commit comments

Comments
 (0)