[SPARK-11889][SQL] Fix type inference for GroupedDataset.agg in REPL
In this PR I delete a method that breaks type inference for aggregators (only in the REPL).

When this method is present, the error is:
```
<console>:38: error: missing parameter type for expanded function ((x$2) => x$2._2)
              ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
```
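For context, here is a minimal, non-Spark sketch of the Scala overload-resolution behavior behind this kind of error (all names below are hypothetical stand-ins, not the real Spark classes). With a single applicable `agg` alternative, the expected type `TypedColumn[T, U]` flows into `sum` and fixes the lambda's parameter type; an additional untyped varargs overload forces full overload resolution, so the argument is typed without an expected type and the lambda's parameter type can no longer be inferred. In Spark itself the breakage surfaced only in the REPL, as noted above.

```
// Hypothetical stand-ins for Column / TypedColumn / GroupedDataset.agg.
object OverloadSketch {
  class Column
  class TypedColumn[T, U] extends Column

  // Stand-in for an Aggregator-based helper like the `sum` in the test below.
  def sum[I, N](f: I => N): TypedColumn[I, N] = new TypedColumn[I, N]

  class Grouped[T] {
    // Typed aggregation: the expected type TypedColumn[T, U] tells `sum`
    // that the lambda's input type is T.
    def agg[U](col: TypedColumn[T, U]): Unit = ()

    // Re-adding a public untyped overload like this one forces full overload
    // resolution and reproduces "missing parameter type for expanded function":
    // def agg(expr: Column, exprs: Column*): Unit = ()
  }

  val grouped = new Grouped[(Int, Long)]
  grouped.agg(sum(_._2)) // compiles: the lambda's input is inferred as (Int, Long)
}
```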

Author: Michael Armbrust <michael@databricks.com>

Closes #9870 from marmbrus/dataset-repl-agg.
marmbrus committed Nov 20, 2015
1 parent 58b4e4f commit 968acf3
Showing 3 changed files with 30 additions and 29 deletions.
@@ -339,6 +339,30 @@ class ReplSuite extends SparkFunSuite {
}
}

test("Datasets agg type-inference") {
val output = runInterpreter("local",
"""
|import org.apache.spark.sql.functions._
|import org.apache.spark.sql.Encoder
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|/** An `Aggregator` that adds up any numeric type returned by the given function. */
|class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
| val numeric = implicitly[Numeric[N]]
| override def zero: N = numeric.zero
| override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
| override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
| override def finish(reduction: N): N = reduction
|}
|
|def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
|val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
|ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
27 changes: 3 additions & 24 deletions sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -146,31 +146,10 @@ class GroupedDataset[K, T] private[sql](
reduce(f.call _)
}

/**
* Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]].
* We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again.
*
* The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
*
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
*
* // Scala:
* import org.apache.spark.sql.functions._
* df.groupBy("department").agg(max("age"), sum("expense"))
*
* // Java:
* import static org.apache.spark.sql.functions.*;
* df.groupBy("department").agg(max("age"), sum("expense"));
* }}}
*
* We can also use `Aggregator.toColumn` to pass in typed aggregate functions.
*
* @since 1.6.0
*/
// This is here to prevent us from adding overloads that would be ambiguous.
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame =
groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*)
private def agg(exprs: Column*): DataFrame =
groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*)

private def withEncoder(c: Column): Column = c match {
case tc: TypedColumn[_, _] =>
@@ -404,11 +404,9 @@ public String call(Tuple2<String, Integer> value) throws Exception {
grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());

Dataset<Tuple4<String, Integer, Long, Long>> agged2 = grouped.agg(
new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()),
expr("sum(_2)"),
count("*"))
.as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG()));
Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
.as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
Assert.assertEquals(
Arrays.asList(
new Tuple4<>("a", 3, 3L, 2L),
