@@ -2998,227 +2998,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
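-  // SQL aggregates skip NULL inputs, so sum over an all-NULL column returns NULL
-  // in both ANSI and non-ANSI mode; no overflow checking is involved here.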
-  test("ANSI support for sum - null test") {
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        withParquetTable(
-          Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")),
-          "null_tbl") {
-          val res = sql("SELECT sum(_1) FROM null_tbl")
-          checkSparkAnswerAndOperator(res)
-          assert(res.collect() === Array(Row(null)))
-        }
-      }
-    }
-  }
-
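-  // try_sum is the error-tolerant variant of sum: on failure it returns NULL instead
-  // of raising an error, so for all-NULL input it should behave exactly like sum.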
-  test("ANSI support for try_sum - null test") {
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        withParquetTable(
-          Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")),
-          "null_tbl") {
-          val res = sql("SELECT try_sum(_1) FROM null_tbl")
-          checkSparkAnswerAndOperator(res)
-          assert(res.collect() === Array(Row(null)))
-        }
-      }
-    }
-  }
-
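-  // The GROUP BY variant routes the same all-NULL input through per-group
-  // accumulator state, so each group must independently report NULL.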
-  test("ANSI support for sum - null test (group by)") {
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        withParquetTable(
-          Seq(
-            (null.asInstanceOf[java.lang.Long], "a"),
-            (null.asInstanceOf[java.lang.Long], "a"),
-            (null.asInstanceOf[java.lang.Long], "b"),
-            (null.asInstanceOf[java.lang.Long], "b"),
-            (null.asInstanceOf[java.lang.Long], "b")),
-          "tbl") {
-          val res = sql("SELECT _2, sum(_1) FROM tbl group by 1")
-          checkSparkAnswerAndOperator(res)
-          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
-        }
-      }
-    }
-  }
-
-  test("ANSI support for try_sum - null test (group by)") {
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        withParquetTable(
-          Seq(
-            (null.asInstanceOf[java.lang.Long], "a"),
-            (null.asInstanceOf[java.lang.Long], "a"),
-            (null.asInstanceOf[java.lang.Long], "b"),
-            (null.asInstanceOf[java.lang.Long], "b"),
-            (null.asInstanceOf[java.lang.Long], "b")),
-          "tbl") {
-          val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1")
-          checkSparkAnswerAndOperator(res)
-          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
-        }
-      }
-    }
-  }
-
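-  // With spark.sql.ansi.enabled=true, SUM over longs must raise ARITHMETIC_OVERFLOW
-  // when the accumulated value wraps; with ANSI off, Spark returns the wrapped
-  // two's-complement result, and Comet must match either behavior.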
-  test("ANSI support - SUM function") {
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        // Test long overflow
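-        // Long.MaxValue + 100 wraps to Long.MinValue + 99 in two's-complement arithmetic.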
-        withParquetTable(Seq((Long.MaxValue, 1L), (100L, 1L)), "tbl") {
-          val res = sql("SELECT SUM(_1) FROM tbl")
-          if (ansiEnabled) {
-            checkSparkAnswerMaybeThrows(res) match {
-              case (Some(sparkExc), Some(cometExc)) =>
-                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-              case _ => fail("Exception should be thrown for Long overflow in ANSI mode")
-            }
-          } else {
-            checkSparkAnswerAndOperator(res)
-          }
-        }
-        // Test long underflow
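-        // Long.MinValue + (-100) wraps to Long.MaxValue - 99.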
-        withParquetTable(Seq((Long.MinValue, 1L), (-100L, 1L)), "tbl") {
-          val res = sql("SELECT SUM(_1) FROM tbl")
-          if (ansiEnabled) {
-            checkSparkAnswerMaybeThrows(res) match {
-              case (Some(sparkExc), Some(cometExc)) =>
-                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-              case _ => fail("Exception should be thrown for Long underflow in ANSI mode")
-            }
-          } else {
-            checkSparkAnswerAndOperator(res)
-          }
-        }
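-        // Spark plans integral SUM with a LongType accumulator, so summing a few
-        // Int/Short/Byte values cannot overflow and must succeed in both modes.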
3098- // Test Int SUM (should not overflow)
3099- withParquetTable(Seq ((Int .MaxValue , 1 ), (Int .MaxValue , 1 ), (100 , 1 )), " tbl" ) {
3100- val res = sql(" SELECT SUM(_1) FROM tbl" )
3101- checkSparkAnswerAndOperator(res)
3102- }
3103- // Test Short SUM (should not overflow)
3104- withParquetTable(
3105- Seq ((Short .MaxValue , 1 .toShort), (Short .MaxValue , 1 .toShort), (100 .toShort, 1 .toShort)),
3106- " tbl" ) {
3107- val res = sql(" SELECT SUM(_1) FROM tbl" )
3108- checkSparkAnswerAndOperator(res)
3109- }
3110-
3111- // Test Byte SUM (should not overflow)
3112- withParquetTable(
3113- Seq ((Byte .MaxValue , 1 .toByte), (Byte .MaxValue , 1 .toByte), (10 .toByte, 1 .toByte)),
3114- " tbl" ) {
3115- val res = sql(" SELECT SUM(_1) FROM tbl" )
3116- checkSparkAnswerAndOperator(res)
3117- }
3118- }
3119- }
3120- }
-
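-  // Same overflow scenarios, but with GROUP BY so the per-group accumulator path
-  // (rather than the single-value reduce path) performs the overflow check.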
-  test("ANSI support for SUM - GROUP BY") {
-    // Test Long overflow with GROUP BY to test GroupAccumulator with ANSI support
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        withParquetTable(
-          Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (200L, 2)),
-          "tbl") {
-          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2)
-          if (ansiEnabled) {
-            checkSparkAnswerMaybeThrows(res) match {
-              case (Some(sparkExc), Some(cometExc)) =>
-                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-              case _ =>
-                fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode")
-            }
-          } else {
-            checkSparkAnswerAndOperator(res)
-          }
-        }
-
-        withParquetTable(
-          Seq((Long.MinValue, 1), (-100L, 1), (Long.MinValue, 2), (-200L, 2)),
-          "tbl") {
-          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
-          if (ansiEnabled) {
-            checkSparkAnswerMaybeThrows(res) match {
-              case (Some(sparkExc), Some(cometExc)) =>
-                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
-              case _ =>
-                fail("Exception should be thrown for Long underflow with GROUP BY in ANSI mode")
-            }
-          } else {
-            checkSparkAnswerAndOperator(res)
-          }
-        }
-        // Test Int with GROUP BY
-        withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 2), (200, 2)), "tbl") {
-          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
-          checkSparkAnswerAndOperator(res)
-        }
-        // Test Short with GROUP BY
-        withParquetTable(
-          Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)),
-          "tbl") {
-          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
-          checkSparkAnswerAndOperator(res)
-        }
-
-        // Test Byte with GROUP BY
-        withParquetTable(
-          Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)),
-          "tbl") {
-          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
-          checkSparkAnswerAndOperator(res)
-        }
-      }
-    }
-  }
-
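-  // try_sum yields NULL only for the group whose sum overflows, so overflowing and
-  // non-overflowing groups can coexist in a single result.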
-  test("try_sum overflow - with GROUP BY") {
-    // Test Long overflow with GROUP BY - some groups overflow while some don't
-    withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (200L, 2), (300L, 2)), "tbl") {
-      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
-      // first group should return NULL (overflow) and group 2 should return 500
-      checkSparkAnswerAndOperator(res)
-    }
-
-    // Test Long underflow with GROUP BY
-    withParquetTable(Seq((Long.MinValue, 1), (-100L, 1), (-200L, 2), (-300L, 2)), "tbl") {
-      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
-      // first group should return NULL (underflow), second group should return -500
-      checkSparkAnswerAndOperator(res)
-    }
-
-    // Test all groups overflow
-    withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (100L, 2)), "tbl") {
-      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
-      // Both groups should return NULL
-      checkSparkAnswerAndOperator(res)
-    }
-
-    // Test Short with GROUP BY (should NOT overflow)
-    withParquetTable(
-      Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)),
-      "tbl") {
-      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
-      checkSparkAnswerAndOperator(res)
-    }
-
-    // Test Byte with GROUP BY (no overflow)
-    withParquetTable(
-      Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)),
-      "tbl") {
-      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
-      checkSparkAnswerAndOperator(res)
-    }
-  }
-
   test("test integral divide overflow for decimal") {
     if (isSpark40Plus) {
       Seq(true, false)