@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121
2222import scala .math .min
2323
24- import org .apache .spark .sql .catalyst .expressions .{Add , Attribute , Divide , EqualTo , EvalMode , Expression , If , IntegralDivide , Literal , Multiply , Remainder , Subtract }
24+ import org .apache .spark .sql .catalyst .expressions .{Add , Attribute , Cast , Divide , EqualTo , EvalMode , Expression , If , IntegralDivide , Literal , Multiply , Remainder , Subtract }
2525import org .apache .spark .sql .types .{ByteType , DataType , DecimalType , DoubleType , FloatType , IntegerType , LongType , ShortType }
2626
2727import org .apache .comet .CometSparkSessionExtensions .withInfo
@@ -201,7 +201,6 @@ object CometIntegralDivide extends CometExpressionSerde with MathBase {
201201 inputs : Seq [Attribute ],
202202 binding : Boolean ): Option [ExprOuterClass .Expr ] = {
203203 val div = expr.asInstanceOf [IntegralDivide ]
204- val rightExpr = nullIfWhenPrimitive(div.right)
205204
206205 if (! supportedDataType(div.left.dataType)) {
207206 withInfo(div, s " Unsupported datatype ${div.left.dataType}" )
@@ -212,17 +211,28 @@ object CometIntegralDivide extends CometExpressionSerde with MathBase {
212211 return None
213212 }
214213
215- val dataType = (div.left.dataType, div.right.dataType) match {
214+ // Precision is set to 19 (max precision for a numerical data type except DecimalType)
215+
216+ val left =
217+ if (div.left.dataType.isInstanceOf [DecimalType ]) div.left
218+ else Cast (div.left, DecimalType (19 , 0 ))
219+ val right =
220+ if (div.right.dataType.isInstanceOf [DecimalType ]) div.right
221+ else Cast (div.right, DecimalType (19 , 0 ))
222+
223+ val rightExpr = nullIfWhenPrimitive(right)
224+
225+ val dataType = (left.dataType, right.dataType) match {
216226 case (l : DecimalType , r : DecimalType ) =>
217227 // copy from IntegralDivide.resultDecimalType
218228 val intDig = l.precision - l.scale + r.scale
219229 DecimalType (min(if (intDig == 0 ) 1 else intDig, DecimalType .MAX_PRECISION ), 0 )
220- case _ => div. left.dataType
230+ case _ => left.dataType
221231 }
222232
223233 val divideExpr = createMathExpression(
224234 expr,
225- div. left,
235+ left,
226236 rightExpr,
227237 inputs,
228238 binding,
0 commit comments