Skip to content

Commit

Permalink
chore: Add CometEvalMode enum to replace string literals (apache#539)
Browse files Browse the repository at this point in the history
* Add CometEvalMode enum

* address feedback
  • Loading branch information
andygrove authored Jun 7, 2024
1 parent f75aeef commit 311e13e
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 45 deletions.
6 changes: 3 additions & 3 deletions spark/src/main/scala/org/apache/comet/GenerateDocs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import scala.io.Source

import org.apache.spark.sql.catalyst.expressions.Cast

import org.apache.comet.expressions.{CometCast, Compatible, Incompatible}
import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible}

/**
* Utility for generating markdown documentation from the configs.
Expand Down Expand Up @@ -72,7 +72,7 @@ object GenerateDocs {
if (Cast.canCast(fromType, toType) && fromType != toType) {
val fromTypeName = fromType.typeName.replace("(10,2)", "")
val toTypeName = toType.typeName.replace("(10,2)", "")
CometCast.isSupported(fromType, toType, None, "LEGACY") match {
CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match {
case Compatible(notes) =>
val notesStr = notes.getOrElse("").trim
w.write(s"| $fromTypeName | $toTypeName | $notesStr |\n".getBytes)
Expand All @@ -89,7 +89,7 @@ object GenerateDocs {
if (Cast.canCast(fromType, toType) && fromType != toType) {
val fromTypeName = fromType.typeName.replace("(10,2)", "")
val toTypeName = toType.typeName.replace("(10,2)", "")
CometCast.isSupported(fromType, toType, None, "LEGACY") match {
CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match {
case Incompatible(notes) =>
val notesStr = notes.getOrElse("").trim
w.write(s"| $fromTypeName | $toTypeName | $notesStr |\n".getBytes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ object CometCast {
fromType: DataType,
toType: DataType,
timeZoneId: Option[String],
evalMode: String): SupportLevel = {
evalMode: CometEvalMode.Value): SupportLevel = {

if (fromType == toType) {
return Compatible()
Expand Down Expand Up @@ -102,7 +102,7 @@ object CometCast {
private def canCastFromString(
toType: DataType,
timeZoneId: Option[String],
evalMode: String): SupportLevel = {
evalMode: CometEvalMode.Value): SupportLevel = {
toType match {
case DataTypes.BooleanType =>
Compatible()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.expressions

/**
* We cannot reference Spark's EvalMode directly because the package is different between Spark
* versions, so we copy it here.
*
* Expression evaluation modes.
* - LEGACY: the default evaluation mode, which is compliant to Hive SQL.
* - ANSI: a evaluation mode which is compliant to ANSI SQL standard.
* - TRY: a evaluation mode for `try_*` functions. It is identical to ANSI evaluation mode
* except for returning null result on errors.
*/
object CometEvalMode extends Enumeration {
val LEGACY, ANSI, TRY = Value

def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) {
ANSI
} else {
LEGACY
}

def fromString(str: String): CometEvalMode.Value = CometEvalMode.withName(str)
}
57 changes: 22 additions & 35 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

package org.apache.comet.serde

import java.util.Locale

import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
Expand All @@ -45,7 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus, withInfo}
import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported}
import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible, Unsupported}
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
Expand Down Expand Up @@ -578,6 +576,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}
}

def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode = {
evalMode match {
case CometEvalMode.LEGACY => ExprOuterClass.EvalMode.LEGACY
case CometEvalMode.TRY => ExprOuterClass.EvalMode.TRY
case CometEvalMode.ANSI => ExprOuterClass.EvalMode.ANSI
case _ => throw new IllegalStateException(s"Invalid evalMode $evalMode")
}
}

/**
* Convert a Spark expression to protobuf.
*
Expand All @@ -590,18 +597,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
* @return
* The protobuf representation of the expression, or None if the expression is not supported
*/

def stringToEvalMode(evalModeStr: String): ExprOuterClass.EvalMode =
evalModeStr.toUpperCase(Locale.ROOT) match {
case "LEGACY" => ExprOuterClass.EvalMode.LEGACY
case "TRY" => ExprOuterClass.EvalMode.TRY
case "ANSI" => ExprOuterClass.EvalMode.ANSI
case invalid =>
throw new IllegalArgumentException(
s"Invalid eval mode '$invalid' "
) // Assuming we want to catch errors strictly
}

def exprToProto(
expr: Expression,
input: Seq[Attribute],
Expand All @@ -610,15 +605,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
timeZoneId: Option[String],
dt: DataType,
childExpr: Option[Expr],
evalMode: String): Option[Expr] = {
evalMode: CometEvalMode.Value): Option[Expr] = {
val dataType = serializeDataType(dt)
val evalModeEnum = stringToEvalMode(evalMode) // Convert string to enum

if (childExpr.isDefined && dataType.isDefined) {
val castBuilder = ExprOuterClass.Cast.newBuilder()
castBuilder.setChild(childExpr.get)
castBuilder.setDatatype(dataType.get)
castBuilder.setEvalMode(evalModeEnum) // Set the enum in protobuf
castBuilder.setEvalMode(evalModeToProto(evalMode))

val timeZone = timeZoneId.getOrElse("UTC")
castBuilder.setTimezone(timeZone)
Expand Down Expand Up @@ -646,26 +640,26 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
inputs: Seq[Attribute],
dt: DataType,
timeZoneId: Option[String],
actualEvalModeStr: String): Option[Expr] = {
evalMode: CometEvalMode.Value): Option[Expr] = {

val childExpr = exprToProtoInternal(child, inputs)
if (childExpr.isDefined) {
val castSupport =
CometCast.isSupported(child.dataType, dt, timeZoneId, actualEvalModeStr)
CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode)

def getIncompatMessage(reason: Option[String]): String =
"Comet does not guarantee correct results for cast " +
s"from ${child.dataType} to $dt " +
s"with timezone $timeZoneId and evalMode $actualEvalModeStr" +
s"with timezone $timeZoneId and evalMode $evalMode" +
reason.map(str => s" ($str)").getOrElse("")

castSupport match {
case Compatible(_) =>
castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
castToProto(timeZoneId, dt, childExpr, evalMode)
case Incompatible(reason) =>
if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
logWarning(getIncompatMessage(reason))
castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
castToProto(timeZoneId, dt, childExpr, evalMode)
} else {
withInfo(
expr,
Expand All @@ -677,7 +671,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
withInfo(
expr,
s"Unsupported cast from ${child.dataType} to $dt " +
s"with timezone $timeZoneId and evalMode $actualEvalModeStr")
s"with timezone $timeZoneId and evalMode $evalMode")
None
}
} else {
Expand All @@ -701,17 +695,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim

case UnaryExpression(child) if expr.prettyName == "trycast" =>
val timeZoneId = SQLConf.get.sessionLocalTimeZone
handleCast(child, inputs, expr.dataType, Some(timeZoneId), "TRY")
handleCast(child, inputs, expr.dataType, Some(timeZoneId), CometEvalMode.TRY)

case Cast(child, dt, timeZoneId, evalMode) =>
val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
// Spark 3.2 & 3.3 has ansiEnabled boolean
if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
} else {
// Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
evalMode.toString
}
handleCast(child, inputs, dt, timeZoneId, evalModeStr)
case c @ Cast(child, dt, timeZoneId, _) =>
handleCast(child, inputs, dt, timeZoneId, evalMode(c))

case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
val leftExpr = exprToProtoInternal(left, inputs)
Expand Down Expand Up @@ -2006,7 +1993,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
// TODO: Remove this once we have new DataFusion release which includes
// the fix: https://github.com/apache/arrow-datafusion/pull/9459
if (childExpr.isDefined) {
castToProto(None, a.dataType, childExpr, "LEGACY")
castToProto(None, a.dataType, childExpr, CometEvalMode.LEGACY)
} else {
withInfo(expr, a.children: _*)
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.comet.shims

import org.apache.comet.expressions.CometEvalMode
import org.apache.spark.sql.catalyst.expressions._

/**
Expand All @@ -27,7 +28,10 @@ trait CometExprShim {
/**
* Returns a tuple of expressions for the `unhex` function.
*/
def unhexSerde(unhex: Unhex): (Expression, Expression) = {
protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(false))
}

protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled)
}

Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.comet.shims

import org.apache.comet.expressions.CometEvalMode
import org.apache.spark.sql.catalyst.expressions._

/**
Expand All @@ -27,7 +28,9 @@ trait CometExprShim {
/**
* Returns a tuple of expressions for the `unhex` function.
*/
def unhexSerde(unhex: Unhex): (Expression, Expression) = {
protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(false))
}

protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.comet.shims

import org.apache.comet.expressions.CometEvalMode
import org.apache.spark.sql.catalyst.expressions._

/**
Expand All @@ -27,7 +28,19 @@ trait CometExprShim {
/**
* Returns a tuple of expressions for the `unhex` function.
*/
def unhexSerde(unhex: Unhex): (Expression, Expression) = {
protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(unhex.failOnError))
}

protected def evalMode(c: Cast): CometEvalMode.Value =
CometEvalModeUtil.fromSparkEvalMode(c.evalMode)
}

object CometEvalModeUtil {
def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match {
case EvalMode.LEGACY => CometEvalMode.LEGACY
case EvalMode.TRY => CometEvalMode.TRY
case EvalMode.ANSI => CometEvalMode.ANSI
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.comet.shims

import org.apache.comet.expressions.CometEvalMode
import org.apache.spark.sql.catalyst.expressions._

/**
Expand All @@ -30,4 +31,16 @@ trait CometExprShim {
protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(unhex.failOnError))
}

protected def evalMode(c: Cast): CometEvalMode.Value =
CometEvalModeUtil.fromSparkEvalMode(c.evalMode)
}

object CometEvalModeUtil {
def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match {
case EvalMode.LEGACY => CometEvalMode.LEGACY
case EvalMode.TRY => CometEvalMode.TRY
case EvalMode.ANSI => CometEvalMode.ANSI
}
}

4 changes: 2 additions & 2 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType}

import org.apache.comet.expressions.{CometCast, Compatible}
import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible}

class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

Expand Down Expand Up @@ -76,7 +76,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
} else {
val testIgnored =
tags.get(expectedTestName).exists(s => s.contains("org.scalatest.Ignore"))
CometCast.isSupported(fromType, toType, None, "LEGACY") match {
CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match {
case Compatible(_) =>
if (testIgnored) {
fail(
Expand Down

0 comments on commit 311e13e

Please sign in to comment.