[SPARK-25942][SQL] Aggregate expressions shouldn't be resolved on AppendColumns

## What changes were proposed in this pull request?

`Dataset.groupByKey` brings in new attributes from its key serializer. If the key type is the same as the original Dataset's object type, both serializers produce the same output, so the attribute names conflict.

This isn't a problem in most cases, as long as we don't refer to the conflicting attributes:

```scala
val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
  .map(c => ClassData(c.a, c.b + 1))
  .groupByKey(p => p).count()
```

But if we do refer to the conflicting attributes, the `Analyzer` complains about ambiguous references:

```scala
val ds = Seq(1, 2, 3).toDS()
val agg = ds.groupByKey(_ >= 2).agg(sum("value").as[Long], sum($"value" + 1).as[Long])
```
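The clash can be illustrated with a toy resolver (plain Scala, no Spark dependency; `Attr` and `resolve` are hypothetical stand-ins for Spark's `Attribute` and `AttributeSeq.resolve`): `AppendColumns` exposes both the child's `value` attribute and the key serializer's `value` attribute, so a lookup by name finds two candidates.

```scala
// Toy model of attribute resolution. Attr and resolve are hypothetical
// stand-ins, not Spark internals.
case class Attr(name: String, exprId: Int)

// Resolve a column name against a plan's output attributes.
def resolve(name: String, output: Seq[Attr]): Either[String, Attr] =
  output.filter(_.name == name) match {
    case Seq(one) => Right(one)
    case Seq()    => Left(s"cannot resolve '$name'")
    case many     => Left(s"ambiguous reference '$name': ${many.size} candidates")
  }

val childOutput      = Seq(Attr("value", 1)) // from the original Dataset
val serializerOutput = Seq(Attr("value", 2)) // key serializer reuses the name
val appendColumnsOut = childOutput ++ serializerOutput

println(resolve("value", appendColumnsOut)) // ambiguous: two attributes named "value"
```

This mirrors the failure mode only in shape: the real `Analyzer` resolves against a plan's output attributes and raises an ambiguous-reference error when more than one matches.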

Two fixes were discussed in apache#22944 (comment):

1. Implicitly add an alias to the key attribute:

This works for primitive types, but not for product types: we can't implicitly alias the key attributes because methods like `mapGroups` may need to access them by name.

2. Detect conflicts among key attributes and warn users to add aliases manually:

This might work, but it would require hacks in `Analyzer` or `AttributeSeq.resolve`.

This patch applies a simpler fix: aggregate expressions are resolved against `AppendColumns`'s children instead of `AppendColumns` itself. `AppendColumns`'s output is its children's output plus the serializer output, and aggregate expressions should never reference the serializer output.
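The intuition behind the fix can be sketched with a toy resolver (plain Scala, no Spark dependency; all names here are hypothetical): lookup is ambiguous on the combined `AppendColumns` output but unique on the child's output, which is all an aggregate expression should ever need.

```scala
// Toy model (hypothetical names, not Spark internals): resolving against the
// child's output sidesteps the serializer's duplicate attribute names.
case class Attr(name: String, exprId: Int)

// A name resolves only when exactly one output attribute matches it.
def lookup(name: String, output: Seq[Attr]): Option[Attr] =
  output.filter(_.name == name) match {
    case Seq(one) => Some(one)
    case _        => None // missing or ambiguous
  }

val childOutput      = Seq(Attr("value", 1))
val serializerOutput = Seq(Attr("value", 2))           // key attributes
val appendColumnsOut = childOutput ++ serializerOutput // old resolution scope

assert(lookup("value", appendColumnsOut).isEmpty)               // ambiguous before the fix
assert(lookup("value", childOutput).contains(Attr("value", 1))) // unique after the fix
```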

## How was this patch tested?

Added tests.

Closes apache#22944 from viirya/dataset_agg.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
viirya authored and cloud-fan committed Nov 13, 2018
1 parent 4b95562 commit f26cd18
Showing 3 changed files with 23 additions and 0 deletions.
```
@@ -953,6 +953,12 @@ class Analyzer(
       // rule: ResolveDeserializer.
       case plan if containsDeserializer(plan.expressions) => plan

+      // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of
+      // `AppendColumns`, because `AppendColumns`'s serializer might produce conflict attribute
+      // names leading to ambiguous references exception.
+      case a @ Aggregate(groupingExprs, aggExprs, appendColumns: AppendColumns) =>
+        a.mapExpressions(resolve(_, appendColumns))
+
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
         q.mapExpressions(resolve(_, q))
```
`sql/core/src/main/scala/org/apache/spark/sql/Column.scala` (3 additions, 0 deletions):

```
@@ -74,6 +74,9 @@ class TypedColumn[-T, U](
       inputEncoder: ExpressionEncoder[_],
       inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
     val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
+
+    // This only inserts inputs into typed aggregate expressions. For untyped aggregate expressions,
+    // the resolving is handled in the analyzer directly.
     val newExpr = expr transform {
       case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
         ta.withInputInfo(
```
`sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala` (14 additions, 0 deletions):

```
@@ -1556,6 +1556,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       df.where($"city".contains(java.lang.Character.valueOf('A'))),
       Seq(Row("Amsterdam")))
   }
+
+  test("SPARK-25942: typed aggregation on primitive type") {
+    val ds = Seq(1, 2, 3).toDS()
+
+    val agg = ds.groupByKey(_ >= 2)
+      .agg(sum("value").as[Long], sum($"value" + 1).as[Long])
+    checkDatasetUnorderly(agg, (false, 1L, 2L), (true, 5L, 7L))
+  }
+
+  test("SPARK-25942: typed aggregation on product type") {
+    val ds = Seq((1, 2), (2, 3), (3, 4)).toDS()
+    val agg = ds.groupByKey(x => x).agg(sum("_1").as[Long], sum($"_2" + 1).as[Long])
+    checkDatasetUnorderly(agg, ((1, 2), 1L, 3L), ((2, 3), 2L, 4L), ((3, 4), 3L, 5L))
+  }
 }

 case class TestDataUnion(x: Int, y: Int, z: Int)
```
