
Commit 9a6be74

include grouping columns in agg()
add docs for groupBy() and agg()
1 parent 09ff163 commit 9a6be74
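
In short, after this change an aggregation on grouped data keeps the grouping columns in its result. A minimal R sketch, mirroring the updated test in pkg/inst/tests/test_sparkSQL.R (the DataFrame df and its "name" and "age" columns are assumed from that test):

gd <- groupBy(df, "name")
df3 <- agg(gd, age = sum(df$age))
columns(df3)  # c("name", "age"): the grouping column "name" is now included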

File tree: 5 files changed, +84 -21 lines

pkg/R/DataFrame.R
pkg/R/group.R
pkg/inst/tests/test_sparkSQL.R
pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SQLUtils.scala
pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRBackendHandler.scala

pkg/R/DataFrame.R

Lines changed: 21 additions & 1 deletion
@@ -697,8 +697,23 @@ setMethod("toRDD",
 #'
 #' Groups the DataFrame using the specified columns, so we can run aggregation on them.
 #'
+#' @param x a DataFrame
+#' @return a GroupedData
+#' @seealso GroupedData
+#' @rdname DataFrame
+#' @export
+#' @examples
+#' \dontrun{
+#' # Compute the average for all numeric columns grouped by department.
+#' avg(groupBy(df, "department"))
+#'
+#' # Compute the max age and average salary, grouped by department and gender.
+#' agg(groupBy(df, "department", "gender"), salary = "avg", age = "max")
+#' }
 setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
 
+#' @rdname DataFrame
+#' @export
 setMethod("groupBy",
           signature(x = "DataFrame"),
           function(x, ...) {
@@ -712,7 +727,12 @@ setMethod("groupBy",
             groupedData(sgd)
           })
 
-
+#' Agg
+#'
+#' Compute aggregates by specifying a list of columns
+#'
+#' @rdname DataFrame
+#' @export
 setMethod("agg",
           signature(x = "DataFrame"),
           function(x, ...) {
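
For orientation, the groupBy()/agg() usage documented above as a short R sketch; the DataFrame df and its "department", "gender", "salary", and "age" columns are assumptions taken from the roxygen examples:

gd <- groupBy(df, "department", "gender")

# Average of all numeric columns per group.
df_avg <- avg(gd)

# Max age and average salary per group, using the <column> = <aggFunction>
# form documented for agg() in pkg/R/group.R.
df_agg <- agg(gd, salary = "avg", age = "max")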

pkg/R/group.R

Lines changed: 32 additions & 6 deletions
@@ -1,19 +1,39 @@
-############################## GroupedData ########################################
+# group.R - GroupedData class and methods implemented in S4 OO classes
 
+setOldClass("jobj")
+
+#' @title S4 class that represents a GroupedData
+#' @description A GroupedData can be created using groupBy() on a DataFrame
+#' @rdname GroupedData
+#' @seealso groupBy
+#'
+#' @param sgd A Java object reference to the backing Scala GroupedData
+#' @export
 setClass("GroupedData",
-         slots = list(env = "environment",
-                      sgd = "jobj"))
+         slots = list(sgd = "jobj"))
 
 setMethod("initialize", "GroupedData", function(.Object, sgd) {
-  .Object@env <- new.env()
   .Object@sgd <- sgd
   .Object
 })
 
+#' @rdname DataFrame
 groupedData <- function(sgd) {
   new("GroupedData", sgd)
 }
 
+
+#' Count
+#'
+#' Count the number of rows for each group.
+#' The resulting DataFrame will also contain the grouping columns.
+#'
+#' @param x a GroupedData
+#' @return a DataFrame
+#' @export
+#' @examples
+#' \dontrun{
+#' }
 setMethod("count",
           signature(x = "GroupedData"),
           function(x) {
@@ -23,9 +43,13 @@ setMethod("count",
 #' Agg
 #'
 #' Aggregates on the entire DataFrame without groups.
+#' The resulting DataFrame will also contain the grouping columns.
 #'
 #' df2 <- agg(df, <column> = <aggFunction>)
 #' df2 <- agg(df, newColName = aggFunction(column))
+#'
+#' @param x a GroupedData
+#' @return a DataFrame
 #' @examples
 #' \dontrun{
 #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
@@ -51,15 +75,17 @@ setMethod("agg",
                 }
               }
               jcols <- lapply(cols, function(c) { c@jc })
-              sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
+              # the GroupedData.agg(col, cols*) API does not contain grouping Column
+              sdf <- callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils", "aggWithGrouping",
+                                 x@sgd, listToSeq(jcols))
             } else {
               stop("agg can only support Column or character")
             }
             dataFrame(sdf)
           })
 
-#' sum/mean/avg/min/max
 
+# sum/mean/avg/min/max
 methods <- c("sum", "mean", "avg", "min", "max")
 
 createMethod <- function(name) {
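
The other GroupedData operations touched in this file, as a brief R sketch (same assumed df, grouped by its "name" column):

gd <- groupBy(df, "name")

# count(): one row per group; the grouping column "name" stays in the result.
counts <- count(gd)

# Shorthand aggregations generated by createMethod() for sum/mean/avg/min/max,
# e.g. a sum over a numeric column, as exercised in the test file below.
age_sums <- sum(gd, "age")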

pkg/inst/tests/test_sparkSQL.R

Lines changed: 6 additions & 1 deletion
@@ -368,7 +368,7 @@ test_that("group by", {
   expect_true(1 == count(df1))
   df1 <- agg(df, age2 = max(df$age))
   expect_true(1 == count(df1))
-  expect_true(columns(df1) == c("age2"))
+  expect_equal(columns(df1), c("age2"))
 
   gd <- groupBy(df, "name")
   expect_true(inherits(gd, "GroupedData"))
@@ -380,6 +380,11 @@ test_that("group by", {
   expect_true(inherits(df3, "DataFrame"))
   expect_true(3 == count(df3))
 
+  df3 <- agg(gd, age = sum(df$age))
+  expect_true(inherits(df3, "DataFrame"))
+  expect_true(3 == count(df3))
+  expect_equal(columns(df3), c("name", "age"))
+
   df4 <- sum(gd, "age")
   expect_true(inherits(df4, "DataFrame"))
   expect_true(3 == count(df4))

pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SQLUtils.scala

Lines changed: 19 additions & 6 deletions
@@ -1,13 +1,10 @@
 package edu.berkeley.cs.amplab.sparkr
 
-import java.io.ByteArrayOutputStream
-import java.io.DataOutputStream
+import java.io.{ByteArrayOutputStream, DataOutputStream}
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
-import org.apache.spark.sql.{SQLContext, DataFrame, Row, SaveMode}
-
-import edu.berkeley.cs.amplab.sparkr.SerDe._
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
+import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}
 
 object SQLUtils {
   def createSQLContext(jsc: JavaSparkContext): SQLContext = {
@@ -18,6 +15,22 @@ object SQLUtils {
     arr.toSeq
   }
 
+  // A helper to include grouping columns in Agg()
+  def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = {
+    val aggExprs = exprs.map{ col =>
+      val f = col.getClass.getDeclaredField("expr")
+      f.setAccessible(true)
+      val expr = f.get(col).asInstanceOf[Expression]
+      expr match {
+        case expr: NamedExpression => expr
+        case expr: Expression => Alias(expr, expr.simpleString)()
+      }
+    }
+    val toDF = gd.getClass.getDeclaredMethods.filter(f => f.getName == "toDF").head
+    toDF.setAccessible(true)
+    toDF.invoke(gd, aggExprs).asInstanceOf[DataFrame]
+  }
+
   def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
     df.map(r => rowToRBytes(r))
   }
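
aggWithGrouping exists because, per the comment added in pkg/R/group.R above, the public GroupedData.agg(col, cols*) API does not put the grouping columns into its output; the helper instead wraps each expression and reflectively invokes GroupedData's toDF, which (as this change relies on) builds the result with the grouping columns included. The R side reaches the helper through the backend's static-call path; a sketch of that call, using the same names as the group.R change (x is the GroupedData wrapper, jcols the Java Column references):

sdf <- callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils", "aggWithGrouping",
                   x@sgd, listToSeq(jcols))
dataFrame(sdf)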

pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRBackendHandler.scala

Lines changed: 6 additions & 7 deletions
@@ -80,29 +80,28 @@ class SparkRBackendHandler(server: SparkRBackend)
       dis: DataInputStream,
       dos: DataOutputStream) {
     var obj: Object = null
-    var cls: Option[Class[_]] = None
     try {
-      if (isStatic) {
-        cls = Some(Class.forName(objId))
+      val cls = if (isStatic) {
+        Class.forName(objId)
       } else {
         JVMObjectTracker.get(objId) match {
           case None => throw new IllegalArgumentException("Object not found " + objId)
           case Some(o) =>
-            cls = Some(o.getClass)
             obj = o
+            o.getClass
         }
       }
 
       val args = readArgs(numArgs, dis)
 
-      val methods = cls.get.getMethods
+      val methods = cls.getMethods
       val selectedMethods = methods.filter(m => m.getName == methodName)
       if (selectedMethods.length > 0) {
         val methods = selectedMethods.filter { x =>
           matchMethod(numArgs, args, x.getParameterTypes)
         }
         if (methods.isEmpty) {
-          System.err.println(s"cannot find matching method ${cls.get}.$methodName. "
+          System.err.println(s"cannot find matching method ${cls}.$methodName. "
            + s"Candidates are:")
           selectedMethods.foreach { method =>
             System.err.println(s"$methodName(${method.getParameterTypes.mkString(",")})")
@@ -116,7 +115,7 @@ class SparkRBackendHandler(server: SparkRBackend)
           writeObject(dos, ret.asInstanceOf[AnyRef])
         } else if (methodName == "<init>") {
           // methodName should be "<init>" for constructor
-          val ctor = cls.get.getConstructors.filter { x =>
+          val ctor = cls.getConstructors.filter { x =>
             matchMethod(numArgs, args, x.getParameterTypes)
           }.head
 