
Commit bcb0bf5

Merge pull request apache#180 from davies/group
[SPARKR-191] groupBy and agg() API for DataFrame
2 parents 4d0fb56 + 9dd6a5a commit bcb0bf5
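
A minimal usage sketch of the API this merge adds (hedged: it assumes a SQLContext `sqlCtx` and a JSON file path `jsonPath`, as set up in pkg/inst/tests/test_sparkSQL.R):

```r
# Hypothetical walkthrough of the new groupBy()/agg() API; names follow the test suite setup.
df <- jsonFile(sqlCtx, jsonPath)

# Aggregate over the whole DataFrame (no grouping).
df1 <- agg(df, age = "sum")

# Group by a column, then count or aggregate per group.
gd  <- groupBy(df, "name")
df2 <- count(gd)
df3 <- agg(gd, age = "sum")
```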

File tree

8 files changed: +156 -13 lines changed


pkg/DESCRIPTION

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@ Collate:
     'RDD.R'
     'pairRDD.R'
     'column.R'
+    'group.R'
     'DataFrame.R'
     'SQLContext.R'
     'broadcast.R'

pkg/NAMESPACE

Lines changed: 4 additions & 0 deletions

@@ -90,6 +90,7 @@ exportMethods("columns",
               "dtypes",
               "explain",
               "filter",
+              "groupBy",
               "head",
               "isLocal",
               "limit",
@@ -134,6 +135,9 @@ exportMethods("asc",
               "countDistinct",
               "sumDistinct")

+exportClasses("GroupedData")
+exportMethods("agg")
+
 export("cacheTable",
        "clearCache",
        "createExternalTable",

pkg/R/DataFrame.R

Lines changed: 28 additions & 1 deletion

@@ -1,6 +1,6 @@
 # DataFrame.R - DataFrame class and methods implemented in S4 OO classes

-#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R
+#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R
 NULL

 setOldClass("jobj")
@@ -663,6 +663,33 @@ setMethod("toRDD",
            })
   })

+#' GroupBy
+#'
+#' Groups the DataFrame using the specified columns, so we can run aggregation on them.
+#'
+setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
+
+setMethod("groupBy",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            cols <- list(...)
+            if (length(cols) >= 1 && class(cols[[1]]) == "character") {
+              sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1]))
+            } else {
+              jcol <- lapply(cols, function(c) { c@jc })
+              sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol))
+            }
+            groupedData(sgd)
+          })
+
+
+setMethod("agg",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            agg(groupBy(x), ...)
+          })
+
+
 ############################## RDD Map Functions ##################################
 # All of the following functions mirror the existing RDD map functions,           #
 # but allow for use with DataFrames by first converting to an RRDD before calling #
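
As the two branches above suggest, `groupBy()` should accept either column names or Column objects; a small sketch (the Column form follows the `else` branch and is not exercised by the new tests):

```r
# Assumes df is a SparkR DataFrame, e.g. df <- jsonFile(sqlCtx, jsonPath).
gd1 <- groupBy(df, "name")    # character column name(s): first branch
gd2 <- groupBy(df, df$name)   # Column object(s): else branch, via c@jc
df1 <- agg(df, age = "sum")   # agg on a DataFrame delegates to agg(groupBy(x), ...)
```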

pkg/R/column.R

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ createMethods <- function() {
     createOperator(op)
   }

-setGeneric("avg", function(x) { standardGeneric("avg") })
+setGeneric("avg", function(x, ...) { standardGeneric("avg") })
 setGeneric("last", function(x) { standardGeneric("last") })
 setGeneric("lower", function(x) { standardGeneric("lower") })
 setGeneric("upper", function(x) { standardGeneric("upper") })

pkg/R/group.R

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+############################## GroupedData ########################################
+
+setClass("GroupedData",
+         slots = list(env = "environment",
+                      sgd = "jobj"))
+
+setMethod("initialize", "GroupedData", function(.Object, sgd) {
+  .Object@env <- new.env()
+  .Object@sgd <- sgd
+  .Object
+})
+
+groupedData <- function(sgd) {
+  new("GroupedData", sgd)
+}
+
+setMethod("count",
+          signature(x = "GroupedData"),
+          function(x) {
+            dataFrame(callJMethod(x@sgd, "count"))
+          })
+
+#' Agg
+#'
+#' Aggregates on the entire DataFrame without groups.
+#'
+#' df2 <- agg(df, <column> = <aggFunction>)
+#' df2 <- agg(df, newColName = aggFunction(column))
+#' @examples
+#' \dontrun{
+#'  df2 <- agg(df, age = "sum")  # new column name will be created as 'SUM(age#0)'
+#'  df2 <- agg(df, ageSum = sum(df$age))  # creates a new column named ageSum
+#' }
+setGeneric("agg", function(x, ...) { standardGeneric("agg") })
+
+setMethod("agg",
+          signature(x = "GroupedData"),
+          function(x, ...) {
+            cols <- list(...)
+            stopifnot(length(cols) > 0)
+            if (is.character(cols[[1]])) {
+              cols <- varargsToEnv(...)
+              sdf <- callJMethod(x@sgd, "agg", cols)
+            } else if (class(cols[[1]]) == "Column") {
+              ns <- names(cols)
+              if (!is.null(ns)) {
+                for (n in ns) {
+                  if (n != "") {
+                    cols[[n]] <- alias(cols[[n]], n)
+                  }
+                }
+              }
+              jcols <- lapply(cols, function(c) { c@jc })
+              sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
+            } else {
+              stop("agg can only support Column or character")
+            }
+            dataFrame(sdf)
+          })
+
+#' sum/mean/avg/min/max
+
+methods <- c("sum", "mean", "avg", "min", "max")
+
+createMethod <- function(name) {
+  setMethod(name,
+            signature(x = "GroupedData"),
+            function(x, ...) {
+              sdf <- callJMethod(x@sgd, name, toSeq(...))
+              dataFrame(sdf)
+            })
+}
+
+createMethods <- function() {
+  for (name in methods) {
+    createMethod(name)
+  }
+}
+
+createMethods()
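
The `createMethods()` loop at the end registers sum, mean, avg, min, and max as GroupedData methods at load time; each one simply forwards its column-name arguments to the corresponding JVM aggregation. Rough usage, mirroring the new tests:

```r
# Assumes df <- jsonFile(sqlCtx, jsonPath), as in test_sparkSQL.R.
gd <- groupBy(df, "name")
sum(gd, "age")     # DataFrame of per-group sums
mean(gd, "age")    # per-group means
max(gd, "age")     # per-group maxima
```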

pkg/R/utils.R

Lines changed: 0 additions & 1 deletion

@@ -370,4 +370,3 @@ toSeq <- function(...) {
 listToSeq <- function(l) {
   callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils", "toSeq", l)
 }
-

pkg/inst/tests/test_sparkSQL.R

Lines changed: 25 additions & 0 deletions

@@ -329,6 +329,31 @@ test_that("column functions", {
   c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
 })

+test_that("group by", {
+  df <- jsonFile(sqlCtx, jsonPath)
+  df1 <- agg(df, name = "max", age = "sum")
+  expect_true(1 == count(df1))
+  df1 <- agg(df, age2 = max(df$age))
+  expect_true(1 == count(df1))
+  expect_true(columns(df1) == c("age2"))
+
+  gd <- groupBy(df, "name")
+  expect_true(inherits(gd, "GroupedData"))
+  df2 <- count(gd)
+  expect_true(inherits(df2, "DataFrame"))
+  expect_true(3 == count(df2))
+
+  df3 <- agg(gd, age = "sum")
+  expect_true(inherits(df3, "DataFrame"))
+  expect_true(3 == count(df3))
+
+  df4 <- sum(gd, "age")
+  expect_true(inherits(df4, "DataFrame"))
+  expect_true(3 == count(df4))
+  expect_true(3 == count(mean(gd, "age")))
+  expect_true(3 == count(max(gd, "age")))
+})
+
 test_that("sortDF() and orderBy() on a DataFrame", {
   df <- jsonFile(sqlCtx, jsonPath)
   sorted <- sortDF(df, df$age)

pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRBackendHandler.scala

Lines changed: 16 additions & 10 deletions

@@ -1,11 +1,8 @@
 package edu.berkeley.cs.amplab.sparkr

-import scala.collection.mutable.HashMap
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

-import java.io.ByteArrayInputStream
-import java.io.ByteArrayOutputStream
-import java.io.DataInputStream
-import java.io.DataOutputStream
+import scala.collection.mutable.HashMap

 import io.netty.channel.ChannelHandler.Sharable
 import io.netty.channel.ChannelHandlerContext
@@ -19,7 +16,8 @@ import edu.berkeley.cs.amplab.sparkr.SerDe._
  * this across connections ?
  */
 @Sharable
-class SparkRBackendHandler(server: SparkRBackend) extends SimpleChannelInboundHandler[Array[Byte]] {
+class SparkRBackendHandler(server: SparkRBackend)
+  extends SimpleChannelInboundHandler[Array[Byte]] {

   override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]) {
     val bis = new ByteArrayInputStream(msg)
@@ -100,11 +98,18 @@ class SparkRBackendHandler(server: SparkRBackend) extends SimpleChannelInboundHandler[Array[Byte]] {
       val methods = cls.get.getMethods
       val selectedMethods = methods.filter(m => m.getName == methodName)
       if (selectedMethods.length > 0) {
-        val selectedMethod = selectedMethods.filter { x =>
+        val methods = selectedMethods.filter { x =>
           matchMethod(numArgs, args, x.getParameterTypes)
-        }.head
-
-        val ret = selectedMethod.invoke(obj, args:_*)
+        }
+        if (methods.isEmpty) {
+          System.err.println(s"cannot find matching method ${cls.get}.$methodName. "
+            + s"Candidates are:")
+          selectedMethods.foreach { method =>
+            System.err.println(s"$methodName(${method.getParameterTypes.mkString(",")})")
+          }
+          throw new Exception(s"No matched method found for $cls.$methodName")
+        }
+        val ret = methods.head.invoke(obj, args:_*)

         // Write status bit
         writeInt(dos, 0)
@@ -160,6 +165,7 @@ class SparkRBackendHandler(server: SparkRBackend) extends SimpleChannelInboundHandler[Array[Byte]] {
         }
       }
       if (!parameterWrapperType.isInstance(args(i))) {
+        System.err.println(s"arg $i not match: expected type $parameterWrapperType, but got ${args(i).getClass()}")
         return false
       }
     }
