Skip to content

Commit

Permalink
chore: Improve ObjectHashAggregate fallback error message (#849)
Browse files Browse the repository at this point in the history
* add support for ObjectHashAggregate

* Revert a change
  • Loading branch information
andygrove authored Aug 20, 2024
1 parent befabdc commit 9205f0d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, Comet
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
Expand Down Expand Up @@ -424,16 +424,13 @@ class CometSparkSessionExtensions
op
}

case op @ HashAggregateExec(
_,
_,
_,
groupingExprs,
aggExprs,
_,
_,
resultExpressions,
child) =>
case op: BaseAggregateExec
if op.isInstanceOf[HashAggregateExec] ||
op.isInstanceOf[ObjectHashAggregateExec] =>
val groupingExprs = op.groupingExpressions
val aggExprs = op.aggregateExpressions
val resultExpressions = op.resultExpressions
val child = op.child
val modes = aggExprs.map(_.mode).distinct

if (!modes.isEmpty && modes.size != 1) {
Expand Down
22 changes: 11 additions & 11 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
Expand Down Expand Up @@ -2663,16 +2663,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
None
}

case HashAggregateExec(
_,
_,
_,
groupingExpressions,
aggregateExpressions,
aggregateAttributes,
_,
resultExpressions,
child) if isCometOperatorEnabled(op.conf, CometConf.OPERATOR_AGGREGATE) =>
case aggregate: BaseAggregateExec
if (aggregate.isInstanceOf[HashAggregateExec] ||
aggregate.isInstanceOf[ObjectHashAggregateExec]) &&
isCometOperatorEnabled(op.conf, CometConf.OPERATOR_AGGREGATE) =>
val groupingExpressions = aggregate.groupingExpressions
val aggregateExpressions = aggregate.aggregateExpressions
val aggregateAttributes = aggregate.aggregateAttributes
val resultExpressions = aggregate.resultExpressions
val child = aggregate.child

if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
withInfo(op, "No group by or aggregation")
return None
Expand Down

0 comments on commit 9205f0d

Please sign in to comment.