Skip to content

Commit 1e3526c

Browse files
gatorsmileshivaram
authored andcommitted
[SPARK-12158][SPARKR][SQL] Fix 'sample' functions that break R unit test cases
The existing sample functions miss the parameter `seed`, however, the corresponding function interface in `generics` has such a parameter. Thus, although the function caller can call the function with the 'seed', we are not using the value. This could cause SparkR unit tests failed. For example, I hit it in another PR: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/47213/consoleFull Author: gatorsmile <gatorsmile@gmail.com> Closes #10160 from gatorsmile/sampleR.
1 parent 1e799d6 commit 1e3526c

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

R/pkg/R/DataFrame.R

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ setMethod("unique",
662662
#' @param x A SparkSQL DataFrame
663663
#' @param withReplacement Sampling with replacement or not
664664
#' @param fraction The (rough) sample target fraction
665+
#' @param seed Randomness seed value
665666
#'
666667
#' @family DataFrame functions
667668
#' @rdname sample
@@ -677,13 +678,17 @@ setMethod("unique",
677678
#' collect(sample(df, TRUE, 0.5))
678679
#'}
679680
setMethod("sample",
680-
# TODO : Figure out how to send integer as java.lang.Long to JVM so
681-
# we can send seed as an argument through callJMethod
682681
signature(x = "DataFrame", withReplacement = "logical",
683682
fraction = "numeric"),
684-
function(x, withReplacement, fraction) {
683+
function(x, withReplacement, fraction, seed) {
685684
if (fraction < 0.0) stop(cat("Negative fraction value:", fraction))
686-
sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction)
685+
if (!missing(seed)) {
686+
# TODO : Figure out how to send integer as java.lang.Long to JVM so
687+
# we can send seed as an argument through callJMethod
688+
sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed))
689+
} else {
690+
sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction)
691+
}
687692
dataFrame(sdf)
688693
})
689694

@@ -692,8 +697,8 @@ setMethod("sample",
692697
setMethod("sample_frac",
693698
signature(x = "DataFrame", withReplacement = "logical",
694699
fraction = "numeric"),
695-
function(x, withReplacement, fraction) {
696-
sample(x, withReplacement, fraction)
700+
function(x, withReplacement, fraction, seed) {
701+
sample(x, withReplacement, fraction, seed)
697702
})
698703

699704
#' nrow

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,10 @@ test_that("sample on a DataFrame", {
724724
sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result
725725
expect_true(count(sampled2) < 3)
726726

727+
count1 <- count(sample(df, FALSE, 0.1, 0))
728+
count2 <- count(sample(df, FALSE, 0.1, 0))
729+
expect_equal(count1, count2)
730+
727731
# Also test sample_frac
728732
sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result
729733
expect_true(count(sampled3) < 3)

0 commit comments

Comments
 (0)