Skip to content

Commit 968acf3

Browse files
committed
[SPARK-11889][SQL] Fix type inference for GroupedDataset.agg in REPL
In this PR I delete a method that breaks type inference for aggregators (only in the REPL) The error when this method is present is: ``` <console>:38: error: missing parameter type for expanded function ((x$2) => x$2._2) ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() ``` Author: Michael Armbrust <michael@databricks.com> Closes #9870 from marmbrus/dataset-repl-agg.
1 parent 58b4e4f commit 968acf3

File tree

3 files changed

+30
-29
lines changed

3 files changed

+30
-29
lines changed

repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,30 @@ class ReplSuite extends SparkFunSuite {
339339
}
340340
}
341341

342+
test("Datasets agg type-inference") {
343+
val output = runInterpreter("local",
344+
"""
345+
|import org.apache.spark.sql.functions._
346+
|import org.apache.spark.sql.Encoder
347+
|import org.apache.spark.sql.expressions.Aggregator
348+
|import org.apache.spark.sql.TypedColumn
349+
|/** An `Aggregator` that adds up any numeric type returned by the given function. */
350+
|class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
351+
| val numeric = implicitly[Numeric[N]]
352+
| override def zero: N = numeric.zero
353+
| override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
354+
| override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
355+
| override def finish(reduction: N): N = reduction
356+
|}
357+
|
358+
|def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
359+
|val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
360+
|ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
361+
""".stripMargin)
362+
assertDoesNotContain("error:", output)
363+
assertDoesNotContain("Exception", output)
364+
}
365+
342366
test("collecting objects of class defined in repl") {
343367
val output = runInterpreter("local[2]",
344368
"""

sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -146,31 +146,10 @@ class GroupedDataset[K, T] private[sql](
146146
reduce(f.call _)
147147
}
148148

149-
/**
150-
* Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]].
151-
* We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again.
152-
*
153-
* The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
154-
*
155-
* {{{
156-
* // Selects the age of the oldest employee and the aggregate expense for each department
157-
*
158-
* // Scala:
159-
* import org.apache.spark.sql.functions._
160-
* df.groupBy("department").agg(max("age"), sum("expense"))
161-
*
162-
* // Java:
163-
* import static org.apache.spark.sql.functions.*;
164-
* df.groupBy("department").agg(max("age"), sum("expense"));
165-
* }}}
166-
*
167-
* We can also use `Aggregator.toColumn` to pass in typed aggregate functions.
168-
*
169-
* @since 1.6.0
170-
*/
149+
// This is here to prevent us from adding overloads that would be ambiguous.
171150
@scala.annotation.varargs
172-
def agg(expr: Column, exprs: Column*): DataFrame =
173-
groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*)
151+
private def agg(exprs: Column*): DataFrame =
152+
groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*)
174153

175154
private def withEncoder(c: Column): Column = c match {
176155
case tc: TypedColumn[_, _] =>

sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,9 @@ public String call(Tuple2<String, Integer> value) throws Exception {
404404
grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
405405
Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
406406

407-
Dataset<Tuple4<String, Integer, Long, Long>> agged2 = grouped.agg(
408-
new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()),
409-
expr("sum(_2)"),
410-
count("*"))
411-
.as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG()));
407+
Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
408+
new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
409+
.as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
412410
Assert.assertEquals(
413411
Arrays.asList(
414412
new Tuple4<>("a", 3, 3L, 2L),

0 commit comments

Comments
 (0)