
Commit 5f3464b

sum_ansi_mode_checks_fix_tests_rebase_main
1 parent 7cd491c commit 5f3464b


6 files changed (+249, -273 lines)


native/spark-expr/src/agg_funcs/sum_int.rs

Lines changed: 27 additions & 37 deletions
@@ -17,8 +17,8 @@

 use crate::{arithmetic_overflow_error, EvalMode};
 use arrow::array::{
-    cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray,
-    Int64Array, PrimitiveArray,
+    as_primitive_array, cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType,
+    BooleanArray, Int64Array, PrimitiveArray,
 };
 use arrow::datatypes::{
     ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type,
@@ -138,7 +138,12 @@ impl Accumulator for SumIntegerAccumulator {
     {
         for i in 0..int_array.len() {
             if !int_array.is_null(i) {
-                let v = int_array.value(i).to_i64().unwrap();
+                let v = int_array.value(i).to_i64().ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Failed to convert value {:?} to i64",
+                        int_array.value(i)
+                    ))
+                })?;
                 match eval_mode {
                     EvalMode::Legacy => {
                         sum = v.add_wrapping(sum);
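
The hunk above replaces a panicking unwrap() on the i64 conversion with an error that propagates through the accumulator's Result. A minimal standalone sketch of the same Option-to-Result pattern, assuming to_i64 comes from the num-traits ToPrimitive trait (which returns Option<i64> when a value does not fit):

    // Sketch only; `num-traits` as the source of `to_i64` is an assumption.
    use num_traits::ToPrimitive;

    // Hypothetical helper mirroring the diff: report a failed conversion as an
    // error instead of panicking.
    fn to_i64_checked(v: u64) -> Result<i64, String> {
        v.to_i64()
            .ok_or_else(|| format!("Failed to convert value {:?} to i64", v))
    }

    fn main() {
        assert_eq!(to_i64_checked(42), Ok(42));
        assert!(to_i64_checked(u64::MAX).is_err()); // u64::MAX does not fit in an i64
    }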
@@ -175,34 +180,22 @@ impl Accumulator for SumIntegerAccumulator {
         let running_sum = self.sum.unwrap_or(0);
         let sum = match values.data_type() {
             DataType::Int64 => update_sum_internal(
-                values
-                    .as_any()
-                    .downcast_ref::<PrimitiveArray<Int64Type>>()
-                    .unwrap(),
+                as_primitive_array::<Int64Type>(values),
                 self.eval_mode,
                 running_sum,
             )?,
             DataType::Int32 => update_sum_internal(
-                values
-                    .as_any()
-                    .downcast_ref::<PrimitiveArray<Int32Type>>()
-                    .unwrap(),
+                as_primitive_array::<Int32Type>(values),
                 self.eval_mode,
                 running_sum,
             )?,
             DataType::Int16 => update_sum_internal(
-                values
-                    .as_any()
-                    .downcast_ref::<PrimitiveArray<Int16Type>>()
-                    .unwrap(),
+                as_primitive_array::<Int16Type>(values),
                 self.eval_mode,
                 running_sum,
             )?,
             DataType::Int8 => update_sum_internal(
-                values
-                    .as_any()
-                    .downcast_ref::<PrimitiveArray<Int8Type>>()
-                    .unwrap(),
+                as_primitive_array::<Int8Type>(values),
                 self.eval_mode,
                 running_sum,
             )?,
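
For context, as_primitive_array is arrow's helper for the manual as_any().downcast_ref::<PrimitiveArray<T>>().unwrap() chain that this hunk removes. A small sketch of the two forms side by side (crate setup is assumed, not taken from this commit):

    use std::sync::Arc;

    use arrow::array::{as_primitive_array, Array, ArrayRef, Int64Array, PrimitiveArray};
    use arrow::datatypes::Int64Type;

    fn main() {
        let values: ArrayRef = Arc::new(Int64Array::from(vec![1_i64, 2, 3]));

        // Old style: verbose manual downcast that panics with a generic message on mismatch.
        let manual: &PrimitiveArray<Int64Type> = values
            .as_any()
            .downcast_ref::<PrimitiveArray<Int64Type>>()
            .unwrap();

        // New style: the same downcast via the arrow helper, which panics with a
        // descriptive message if the array is not of the requested type.
        let helper: &PrimitiveArray<Int64Type> = as_primitive_array::<Int64Type>(&values);

        assert_eq!(manual.len(), helper.len());
        assert_eq!(manual.value(2), helper.value(2));
    }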
@@ -278,8 +271,17 @@ impl Accumulator for SumIntegerAccumulator {
             }
         }

-        let left = self.sum.unwrap();
-        let right = that_sum.unwrap();
+        // Safe to unwrap (nulls were checked above), but handle the error in case the state is corrupt.
+        let left = self.sum.ok_or_else(|| {
+            DataFusionError::Internal(
+                "Invalid state in merging batch. Current batch's sum is None".to_string(),
+            )
+        })?;
+        let right = that_sum.ok_or_else(|| {
+            DataFusionError::Internal(
+                "Invalid state in merging batch. Incoming sum is None".to_string(),
+            )
+        })?;

         match self.eval_mode {
             EvalMode::Legacy => {
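
The match self.eval_mode that follows the merged sums is where overflow semantics diverge: Legacy wraps (add_wrapping in the context above), ANSI raises an ARITHMETIC_OVERFLOW error, and Try yields NULL — the behaviours exercised by the test changes further down. A hedged sketch of those three behaviours on plain i64 values (the enum and function here are illustrative, not Comet's actual API):

    // Illustrative names only.
    enum EvalMode {
        Legacy,
        Ansi,
        Try,
    }

    // Add two partial sums under the three Spark eval modes:
    // Legacy wraps on overflow, ANSI reports an error, Try produces NULL (None).
    fn add_sums(left: i64, right: i64, mode: EvalMode) -> Result<Option<i64>, String> {
        match mode {
            EvalMode::Legacy => Ok(Some(left.wrapping_add(right))),
            EvalMode::Ansi => left
                .checked_add(right)
                .map(Some)
                .ok_or_else(|| "ARITHMETIC_OVERFLOW: long overflow in SUM".to_string()),
            EvalMode::Try => Ok(left.checked_add(right)),
        }
    }

    fn main() {
        assert_eq!(add_sums(i64::MAX, 100, EvalMode::Legacy), Ok(Some(i64::MIN + 99)));
        assert!(add_sums(i64::MAX, 100, EvalMode::Ansi).is_err());
        assert_eq!(add_sums(i64::MAX, 100, EvalMode::Try), Ok(None));
    }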
@@ -392,40 +394,28 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {

         match values.data_type() {
             DataType::Int64 => update_groups_sum_internal(
-                values
-                    .as_any()
-                    .downcast_ref::<PrimitiveArray<Int64Type>>()
-                    .unwrap(),
+                as_primitive_array::<Int64Type>(values),
                 group_indices,
                 &mut self.sums,
                 &mut self.has_all_nulls,
                 self.eval_mode,
             )?,
             DataType::Int32 => update_groups_sum_internal(
-                values
-                    .as_any()
-                    .downcast_ref::<PrimitiveArray<Int32Type>>()
-                    .unwrap(),
+                as_primitive_array::<Int32Type>(values),
                 group_indices,
                 &mut self.sums,
                 &mut self.has_all_nulls,
                 self.eval_mode,
             )?,
             DataType::Int16 => update_groups_sum_internal(
-                values
-                    .as_any()
-                    .downcast_ref::<PrimitiveArray<Int16Type>>()
-                    .unwrap(),
+                as_primitive_array::<Int16Type>(values),
                 group_indices,
                 &mut self.sums,
                 &mut self.has_all_nulls,
                 self.eval_mode,
             )?,
             DataType::Int8 => update_groups_sum_internal(
-                values
-                    .as_any()
-                    .downcast_ref::<PrimitiveArray<Int8Type>>()
-                    .unwrap(),
+                as_primitive_array::<Int8Type>(values),
                 group_indices,
                 &mut self.sums,
                 &mut self.has_all_nulls,

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 0 additions & 2 deletions
@@ -44,8 +44,6 @@ import org.apache.comet.shims.CometExprShim
  */
 object QueryPlanSerde extends Logging with CometExprShim {

-  private val integerTypes = Seq(ByteType, ShortType, IntegerType, LongType)
-
   private val arrayExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
     classOf[ArrayAppend] -> CometArrayAppend,
     classOf[ArrayCompact] -> CometArrayCompact,

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 0 additions & 11 deletions
@@ -213,17 +213,6 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {

 object CometSum extends CometAggregateExpressionSerde[Sum] {

-  override def getSupportLevel(sum: Sum): SupportLevel = {
-    sum.evalMode match {
-      case EvalMode.ANSI =>
-        Incompatible(Some("ANSI mode is not supported"))
-      case EvalMode.TRY =>
-        Incompatible(Some("TRY mode is not supported"))
-      case _ =>
-        Compatible()
-    }
-  }
-
   override def convert(
       aggExpr: AggregateExpression,
       sum: Sum,

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 0 additions & 221 deletions
@@ -2998,227 +2998,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

-  test("ANSI support for sum - null test") {
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        withParquetTable(
-          Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")),
-          "null_tbl") {
-          val res = sql("SELECT sum(_1) FROM null_tbl")
-          checkSparkAnswerAndOperator(res)
-          assert(res.collect() === Array(Row(null)))
-        }
-      }
-    }
-  }
-
-  test("ANSI support for try_sum - null test") {
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        withParquetTable(
-          Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")),
-          "null_tbl") {
-          val res = sql("SELECT try_sum(_1) FROM null_tbl")
-          checkSparkAnswerAndOperator(res)
-          assert(res.collect() === Array(Row(null)))
-        }
-      }
-    }
-  }
-
-  test("ANSI support for sum - null test (group by)") {
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        withParquetTable(
-          Seq(
-            (null.asInstanceOf[java.lang.Long], "a"),
-            (null.asInstanceOf[java.lang.Long], "a"),
-            (null.asInstanceOf[java.lang.Long], "b"),
-            (null.asInstanceOf[java.lang.Long], "b"),
-            (null.asInstanceOf[java.lang.Long], "b")),
-          "tbl") {
-          val res = sql("SELECT _2, sum(_1) FROM tbl group by 1")
-          checkSparkAnswerAndOperator(res)
-          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
-        }
-      }
-    }
-  }
-
-  test("ANSI support for try_sum - null test (group by)") {
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        withParquetTable(
-          Seq(
-            (null.asInstanceOf[java.lang.Long], "a"),
-            (null.asInstanceOf[java.lang.Long], "a"),
-            (null.asInstanceOf[java.lang.Long], "b"),
-            (null.asInstanceOf[java.lang.Long], "b"),
-            (null.asInstanceOf[java.lang.Long], "b")),
-          "tbl") {
-          val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1")
-          checkSparkAnswerAndOperator(res)
-          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
-        }
-      }
-    }
-  }
-
-  test("ANSI support - SUM function") {
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        // Test long overflow
-        withParquetTable(Seq((Long.MaxValue, 1L), (100L, 1L)), "tbl") {
-          val res = sql("SELECT SUM(_1) FROM tbl")
-          if (ansiEnabled) {
-            checkSparkAnswerMaybeThrows(res) match {
-              case (Some(sparkExc), Some(cometExc)) =>
-                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-              case _ => fail("Exception should be thrown for Long overflow in ANSI mode")
-            }
-          } else {
-            checkSparkAnswerAndOperator(res)
-          }
-        }
-        // Test long underflow
-        withParquetTable(Seq((Long.MinValue, 1L), (-100L, 1L)), "tbl") {
-          val res = sql("SELECT SUM(_1) FROM tbl")
-          if (ansiEnabled) {
-            checkSparkAnswerMaybeThrows(res) match {
-              case (Some(sparkExc), Some(cometExc)) =>
-                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-              case _ => fail("Exception should be thrown for Long underflow in ANSI mode")
-            }
-          } else {
-            checkSparkAnswerAndOperator(res)
-          }
-        }
-        // Test Int SUM (should not overflow)
-        withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 1)), "tbl") {
-          val res = sql("SELECT SUM(_1) FROM tbl")
-          checkSparkAnswerAndOperator(res)
-        }
-        // Test Short SUM (should not overflow)
-        withParquetTable(
-          Seq((Short.MaxValue, 1.toShort), (Short.MaxValue, 1.toShort), (100.toShort, 1.toShort)),
-          "tbl") {
-          val res = sql("SELECT SUM(_1) FROM tbl")
-          checkSparkAnswerAndOperator(res)
-        }
-
-        // Test Byte SUM (should not overflow)
-        withParquetTable(
-          Seq((Byte.MaxValue, 1.toByte), (Byte.MaxValue, 1.toByte), (10.toByte, 1.toByte)),
-          "tbl") {
-          val res = sql("SELECT SUM(_1) FROM tbl")
-          checkSparkAnswerAndOperator(res)
-        }
-      }
-    }
-  }
-
-  test("ANSI support for SUM - GROUP BY") {
-    // Test Long overflow with GROUP BY to test GroupAccumulator with ANSI support
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        withParquetTable(
-          Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (200L, 2)),
-          "tbl") {
-          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2)
-          if (ansiEnabled) {
-            checkSparkAnswerMaybeThrows(res) match {
-              case (Some(sparkExc), Some(cometExc)) =>
-                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-              case _ =>
-                fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode")
-            }
-          } else {
-            checkSparkAnswerAndOperator(res)
-          }
-        }
-
-        withParquetTable(
-          Seq((Long.MinValue, 1), (-100L, 1), (Long.MinValue, 2), (-200L, 2)),
-          "tbl") {
-          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
-          if (ansiEnabled) {
-            checkSparkAnswerMaybeThrows(res) match {
-              case (Some(sparkExc), Some(cometExc)) =>
-                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-              case _ =>
-                fail("Exception should be thrown for Long underflow with GROUP BY in ANSI mode")
-            }
-          } else {
-            checkSparkAnswerAndOperator(res)
-          }
-        }
-        // Test Int with GROUP BY
-        withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 2), (200, 2)), "tbl") {
-          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
-          checkSparkAnswerAndOperator(res)
-        }
-        // Test Short with GROUP BY
-        withParquetTable(
-          Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)),
-          "tbl") {
-          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
-          checkSparkAnswerAndOperator(res)
-        }
-
-        // Test Byte with GROUP BY
-        withParquetTable(
-          Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)),
-          "tbl") {
-          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
-          checkSparkAnswerAndOperator(res)
-        }
-      }
-    }
-  }
-
-  test("try_sum overflow - with GROUP BY") {
-    // Test Long overflow with GROUP BY - some groups overflow while some don't
-    withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (200L, 2), (300L, 2)), "tbl") {
-      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
-      // first group should return NULL (overflow) and group 2 should return 500
-      checkSparkAnswerAndOperator(res)
-    }
-
-    // Test Long underflow with GROUP BY
-    withParquetTable(Seq((Long.MinValue, 1), (-100L, 1), (-200L, 2), (-300L, 2)), "tbl") {
-      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
-      // first group should return NULL (underflow), second group should return neg 500
-      checkSparkAnswerAndOperator(res)
-    }
-
-    // Test all groups overflow
-    withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (100L, 2)), "tbl") {
-      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
-      // Both groups should return NULL
-      checkSparkAnswerAndOperator(res)
-    }
-
-    // Test Short with GROUP BY (should NOT overflow)
-    withParquetTable(
-      Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)),
-      "tbl") {
-      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
-      checkSparkAnswerAndOperator(res)
-    }
-
-    // Test Byte with GROUP BY (no overflow)
-    withParquetTable(
-      Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)),
-      "tbl") {
-      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
-      checkSparkAnswerAndOperator(res)
-    }
-  }
-
   test("test integral divide overflow for decimal") {
     if (isSpark40Plus) {
       Seq(true, false)
