@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.types.IntegerType
 
@@ -230,7 +230,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
     }
   }
 
-  test("SPARK-36194: Remove aggregation from aggregation") {
+  test("SPARK-36194: Child distinct keys is the subset of required keys") {
     val originalQuery = relation
       .groupBy('a)('a, count('b).as("cnt"))
       .groupBy('a, 'cnt)('a, 'cnt)
@@ -243,48 +243,38 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("SPARK-36194: Negative case: The grouping expressions not same") {
-    Seq(LeftSemi, LeftAnti).foreach { joinType =>
-      val originalQuery = x.groupBy('a, 'b)('a, 'b)
-        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
-        .groupBy("x.a".attr)("x.a".attr)
-
-      val optimized = Optimize.execute(originalQuery.analyze)
-      comparePlans(optimized, originalQuery.analyze)
-    }
+  test("SPARK-36194: Child distinct keys are subsets and aggregateExpressions are foldable") {
+    val originalQuery = x.groupBy('a, 'b)('a, 'b)
+      .join(y, LeftSemi, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
+      .groupBy("x.a".attr, "x.b".attr)(TrueLiteral)
+      .analyze
+    val correctAnswer = x.groupBy('a, 'b)('a, 'b)
+      .join(y, LeftSemi, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
+      .select(TrueLiteral)
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
   }
 
-  test("SPARK-36194: Negative case: The aggregate expressions not the sub aggregateExprs") {
+  test("SPARK-36194: Negative case: child distinct keys is not the subset of required keys") {
     Seq(LeftSemi, LeftAnti).foreach { joinType =>
-      val originalQuery = x.groupBy('a, 'b)('a, 'b)
+      val originalQuery1 = x.groupBy('a, 'b)('a, 'b)
         .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
-        .groupBy("x.a".attr, "x.b".attr)(TrueLiteral)
-
-      val optimized = Optimize.execute(originalQuery.analyze)
-      comparePlans(optimized, originalQuery.analyze)
-    }
-  }
+        .groupBy("x.a".attr)("x.a".attr)
+        .analyze
+      comparePlans(Optimize.execute(originalQuery1), originalQuery1)
 
-  test("SPARK-36194: Negative case: The aggregate expressions not same") {
-    Seq(LeftSemi, LeftAnti).foreach { joinType =>
-      val originalQuery = x.groupBy('a, 'b)('a, 'b)
+      val originalQuery2 = x.groupBy('a, 'b)('a, 'b)
         .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
         .groupBy("x.a".attr)(count("x.b".attr))
-
-      val optimized = Optimize.execute(originalQuery.analyze)
-      comparePlans(optimized, originalQuery.analyze)
+        .analyze
+      comparePlans(Optimize.execute(originalQuery2), originalQuery2)
     }
   }
 
-  test("SPARK-36194: Negative case: The aggregate expressions with Literal") {
-    Seq(LeftSemi, LeftAnti).foreach { joinType =>
-      val originalQuery = x.groupBy('a, 'b)('a, TrueLiteral)
-        .join(y, joinType, Some("x.a".attr === "y.a".attr))
-        .groupBy("x.a".attr)("x.a".attr, TrueLiteral)
-
-      val optimized = Optimize.execute(originalQuery.analyze)
-      comparePlans(optimized, originalQuery.analyze)
-    }
+  test("SPARK-36194: Negative case: child distinct keys is empty") {
+    val originalQuery = Distinct(x.groupBy('a, 'b)('a, TrueLiteral)).analyze
+    comparePlans(Optimize.execute(originalQuery), originalQuery)
   }
 
   test("SPARK-36194: Negative case: Remove aggregation from contains non-deterministic") {