feat: Support ANSI mode sum expr #10
base: main
Conversation
Walkthrough: The PR extends SUM aggregation to support 8/16/32/64-bit integer types via a new SumInteger UDF. The boolean fail_on_error field is replaced with an EvalMode enum in the protobuf definition. Scala serialization logic is updated, and comprehensive ANSI mode tests for overflow handling are added.
Review completed. 1 suggestion posted.
```rust
                running_sum,
            )?,
            _ => {
                panic!("Unsupported data type {}", values.data_type())
```
Consider returning a DataFusionError here instead of panicking on an unexpected type; a panic in executor code can crash the process. The GroupsAccumulator path already returns an error for the same case, so aligning behavior would be safer.
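As a rough illustration of this suggestion (the function name and supported-type list below are placeholders, not code from the PR), the fallback arm could surface an error instead of panicking:

```rust
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result};

// Hypothetical sketch: report an unsupported input type as a query error so
// the plan fails gracefully instead of aborting the executor process.
fn check_supported_sum_type(data_type: &DataType) -> Result<()> {
    match data_type {
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(()),
        other => Err(DataFusionError::Internal(format!(
            "Unsupported data type {other} for integer SUM accumulator"
        ))),
    }
}
```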
value:useful; category:bug; feedback:The AI reviewer is correct that library code should avoid panicking if possible. It is better to return an Err and let the user decide what to do with it. Panicking will abort the whole application.
Pull Request Review: ANSI Mode SUM Expression Support

Thank you for this contribution! This PR adds important ANSI mode support for integer SUM aggregations. Here's my comprehensive review:

✅ Strengths
🔍 Issues & Concerns

CRITICAL: .DS_Store File Committed
Potential Bug: State Management in Try Mode
Location: native/spark-expr/src/agg_funcs/sum_int.rs:109-119
The initialization pattern for Try mode is unusual. Consider adding a comment explaining why Try mode needs sum: Some(0) initialization instead of None.

Code Clarity: Nested Match Statements
Location: native/spark-expr/src/agg_funcs/sum_int.rs:272-288
The nested match in merge_batch could be simplified. Consider extracting overflow handling into a helper method (see the sketch below).

Minor: Inconsistent Whitespace in Test Names
Location: spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Test names have extra spaces before the dash (e.g., "sum  - null test" should be "sum - null test").
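A minimal sketch of the helper extraction suggested under "Code Clarity" above. The helper name and signature are illustrative, not code from the PR; EvalMode and arithmetic_overflow_error are the crate items already used in sum_int.rs.

```rust
use crate::{arithmetic_overflow_error, EvalMode};
use datafusion::common::DataFusionError;

// Illustrative helper (hypothetical): one checked addition, with the eval mode
// deciding whether overflow wraps (Legacy), raises an error (Ansi), or turns
// the running sum into NULL (Try).
fn add_with_eval_mode(
    left: i64,
    right: i64,
    eval_mode: EvalMode,
) -> Result<Option<i64>, DataFusionError> {
    match eval_mode {
        EvalMode::Legacy => Ok(Some(left.wrapping_add(right))),
        EvalMode::Ansi => left
            .checked_add(right)
            .map(Some)
            .ok_or_else(|| DataFusionError::from(arithmetic_overflow_error("integer"))),
        // The caller remains responsible for clearing `has_all_nulls` when TRY
        // mode maps an overflow to NULL (None).
        EvalMode::Try => Ok(left.checked_add(right)),
    }
}
```

Both update_batch and merge_batch could then delegate to such a helper instead of repeating the nested match.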
🎯 Performance Considerations

🔒 Security Review
📋 Test Coverage Assessment

Well Covered: Int8/16/32/64 types, overflow/underflow, all eval modes, group by, null handling

Missing Coverage (Consider adding):
📝 Code Quality Notes
✏️ Recommendations

Must Fix:
Should Fix:
Nice to Have:
📊 Overall Assessment

This is a solid implementation that correctly implements ANSI mode integer SUM with proper overflow semantics. The code follows existing patterns, has good test coverage, and properly integrates with the protobuf/serde layer.

Rating: 8/10

Main issues:
Once the .DS_Store file is removed and test names are fixed, this PR will be ready to merge.

Generated by Claude Code Review
Actionable comments posted: 2
🧹 Nitpick comments (6)
spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala (6)
3055-3067: Remove redundant assertion after checkSparkAnswerAndOperator.

Line 3063 contains an explicit assertion that duplicates the validation already performed by checkSparkAnswerAndOperator(res) on line 3062. The helper method already verifies that Comet produces the same results as Spark, making the explicit assertion redundant and reducing maintainability.

Apply this diff:
```diff
     withParquetTable(
       Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")),
       "null_tbl") {
       val res = sql("SELECT sum(_1) FROM null_tbl")
       checkSparkAnswerAndOperator(res)
-      assert(res.collect() === Array(Row(null)))
     }
```
3069-3081: Remove redundant assertion after checkSparkAnswerAndOperator.

Same issue as the previous test - line 3077 contains a redundant assertion.
Apply this diff:
```diff
     withParquetTable(
       Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")),
       "null_tbl") {
       val res = sql("SELECT try_sum(_1) FROM null_tbl")
       checkSparkAnswerAndOperator(res)
-      assert(res.collect() === Array(Row(null)))
     }
```
3083-3100: Remove redundant assertion after checkSparkAnswerAndOperator.

Lines 3095-3096 contain a redundant assertion after the result has already been validated.
Apply this diff:
"tbl") { val res = sql("SELECT _2, sum(_1) FROM tbl group by 1") checkSparkAnswerAndOperator(res) - assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) }
3102-3119: Remove redundant assertion after checkSparkAnswerAndOperator.

Lines 3114-3115 contain a redundant assertion.
Apply this diff:
"tbl") { val res = sql("SELECT _2, sum(_1) FROM tbl group by 1") checkSparkAnswerAndOperator(res) - assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) }
3176-3235: Inconsistent repartitioning strategy across similar test cases.

Line 3183 uses .repartition(2) while the similar test at line 3200 doesn't use any repartitioning, and other GROUP BY tests in this file don't use repartitioning either. This inconsistency makes it unclear whether:
- The repartitioning is intentional to test different execution paths
- It was added arbitrarily
For GROUP BY aggregation tests, consider either:
- Using consistent repartitioning across all similar tests, preferably with the grouping column: .repartition(2, col("_2"))
- Removing repartitioning if not specifically testing partitioning behavior
- Adding a comment explaining why repartitioning is used in some cases but not others
3055-3055: Minor: Inconsistent spacing in test names.

The test names at lines 3055 and 3069 have double spaces before the dash ("sum  - null test"), while other test names use single spaces. Consider standardizing to single spaces for consistency.
- test("ANSI support for sum - null test") { + test("ANSI support for sum - null test") {- test("ANSI support for try_sum - null test") { + test("ANSI support for try_sum - null test") {Also applies to: 3069-3069, 3083-3083, 3102-3102
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
native/.DS_Store is excluded by !**/.DS_Store
📒 Files selected for processing (6)
- native/core/src/execution/planner.rs (2 hunks)
- native/proto/src/proto/expr.proto (1 hunks)
- native/spark-expr/src/agg_funcs/mod.rs (2 hunks)
- native/spark-expr/src/agg_funcs/sum_int.rs (1 hunks)
- spark/src/main/scala/org/apache/comet/serde/aggregates.scala (3 hunks)
- spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-04T14:26:48.750Z
Learnt from: martin-augment
Repo: martin-augment/datafusion-comet PR: 7
File: native/spark-expr/src/math_funcs/abs.rs:201-302
Timestamp: 2025-11-04T14:26:48.750Z
Learning: In the abs function in native/spark-expr/src/math_funcs/abs.rs (Rust), NULL values for signed integers (Int8, Int16, Int32, Int64) and decimals (Decimal128, Decimal256) should return the argument as-is (e.g., ColumnarValue::Scalar(ScalarValue::Int8(None))) rather than panicking on unwrap().
Applied to files:
- native/core/src/execution/planner.rs
- native/spark-expr/src/agg_funcs/sum_int.rs
🧬 Code graph analysis (4)
native/core/src/execution/planner.rs (1)
native/spark-expr/src/agg_funcs/sum_int.rs (2)
try_new (41-51), new (106-121)
native/spark-expr/src/agg_funcs/sum_int.rs (1)
native/spark-expr/src/lib.rs (1)
arithmetic_overflow_error(110-114)
spark/src/main/scala/org/apache/comet/serde/aggregates.scala (1)
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala (4)
QueryPlanSerde (63-1737), evalModeToProto (600-607), exprToProto (627-636), serializeDataType (291-385)
spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala (1)
spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala (7)
withParquetTable (390-398), withParquetTable (400-403), withParquetTable (405-408), checkSparkAnswerAndOperator (169-171), checkSparkAnswerAndOperator (173-177), checkSparkAnswerAndOperator (179-186), checkSparkMaybeThrows (262-270)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: claude-review
- GitHub Check: codex
🔇 Additional comments (2)
spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala (2)
3121-3174: Comprehensive overflow/underflow testing.

Excellent coverage of SUM overflow and underflow scenarios across different integer types (Long, Int, Short, Byte). The test properly validates that:
- Long types can overflow/underflow and throw exceptions in ANSI mode
- Smaller integer types (Int, Short, Byte) are promoted to Long and don't overflow
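As a small standalone illustration of the second point (assuming, as described above, that values of the smaller integer types are widened into a 64-bit running sum; this snippet is an editor's sketch, not test code from the PR):

```rust
// Even a column of i32::MAX values stays far below the i64 limits once each
// value is widened to i64 before being accumulated.
fn main() {
    let values = vec![i32::MAX; 3];
    let sum: i64 = values.iter().map(|&v| i64::from(v)).sum();
    assert_eq!(sum, 3 * i64::from(i32::MAX)); // 6_442_450_941, far from i64::MAX
    println!("sum = {sum}");
}
```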
3237-3274: Excellent GROUP BY aggregation test with proper repartitioning.

This test demonstrates the correct pattern for testing GROUP BY aggregations:
- Consistently uses .repartition(2, col("_2")) to distribute groups across partitions
- Validates both overflow (returns NULL) and non-overflow scenarios
- Tests partial overflow (some groups overflow, others don't) which is a valuable edge case
- Covers multiple integer types
This repartitioning strategy ensures that the GROUP BY accumulator behavior is tested across multiple partitions, which is important for validating distributed aggregation correctness.
```rust
use crate::{arithmetic_overflow_error, EvalMode};
use arrow::array::{
    cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray,
    Int64Array, PrimitiveArray,
};
use arrow::datatypes::{
    ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type,
};
use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::Volatility::Immutable;
use datafusion::logical_expr::{
    Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature,
};
use std::{any::Any, sync::Arc};
```
Add the missing import for size_of_val.
size_of_val isn’t in scope, so both size() implementations fail to compile. Please either import it or use the fully-qualified path.
Apply this diff:
```diff
-use crate::{arithmetic_overflow_error, EvalMode};
+use crate::{arithmetic_overflow_error, EvalMode};
+use std::mem::size_of_val;
```

📝 Committable suggestion
```rust
use crate::{arithmetic_overflow_error, EvalMode};
use std::mem::size_of_val;
use arrow::array::{
    cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray,
    Int64Array, PrimitiveArray,
};
use arrow::datatypes::{
    ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type,
};
use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::Volatility::Immutable;
use datafusion::logical_expr::{
    Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature,
};
use std::{any::Any, sync::Arc};
```
🤖 Prompt for AI Agents
In native/spark-expr/src/agg_funcs/sum_int.rs around lines 18 to 32, the code
calls size_of_val but never imports it; add the missing import by bringing
size_of_val into scope (e.g., add use std::mem::size_of_val;) or replace calls
with the fully-qualified path std::mem::size_of_val(...) so the two size()
implementations compile.
value:useful; category:bug; feedback:The CodeRabbit AI reviewer is correct that size_of_val function is not imported and it will cause build failure.
@martin-augment, thank you for confirming! I'm glad I could catch this build issue before it caused problems. The missing size_of_val import should be resolved with the suggested fix.
🐰✨
🧠 Learnings used
Learnt from: martin-augment
Repo: martin-augment/datafusion-comet PR: 7
File: native/spark-expr/src/math_funcs/abs.rs:201-302
Timestamp: 2025-11-04T14:26:48.750Z
Learning: In the abs function in native/spark-expr/src/math_funcs/abs.rs (Rust), NULL values for signed integers (Int8, Int16, Int32, Int64) and decimals (Decimal128, Decimal256) should return the argument as-is (e.g., ColumnarValue::Scalar(ScalarValue::Int8(None))) rather than panicking on unwrap().
```rust
    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
        let that_sum_array = states[0].as_primitive::<Int64Type>();
        let that_sum = if that_sum_array.is_null(0) {
            None
        } else {
            Some(that_sum_array.value(0))
        };

        // Check for overflow for early termination
        if self.eval_mode == EvalMode::Try {
            let that_has_all_nulls = states[1].as_boolean().value(0);
            let that_overflowed = !that_has_all_nulls && that_sum.is_none();
            let this_overflowed = !self.has_all_nulls && self.sum.is_none();
            if that_overflowed || this_overflowed {
                self.sum = None;
                self.has_all_nulls = false;
                return Ok(());
            }
            self.has_all_nulls = self.has_all_nulls && that_has_all_nulls;
            if that_has_all_nulls {
                return Ok(());
            }
            if self.has_all_nulls {
                self.sum = that_sum;
                return Ok(());
            }
        } else {
            if that_sum.is_none() {
                return Ok(());
            }
            if self.sum.is_none() {
                self.sum = that_sum;
                return Ok(());
            }
        }

        let left = self.sum.unwrap();
        let right = that_sum.unwrap();

        match self.eval_mode {
            EvalMode::Legacy => {
                self.sum = Some(left.add_wrapping(right));
            }
            EvalMode::Ansi | EvalMode::Try => match left.add_checked(right) {
                Ok(v) => self.sum = Some(v),
                Err(_) => {
                    if self.eval_mode == EvalMode::Ansi {
                        return Err(DataFusionError::from(arithmetic_overflow_error("integer")));
                    } else {
                        self.sum = None;
                        self.has_all_nulls = false;
                    }
                }
            },
        }
        Ok(())
```
Merge all incoming states, not just the first row.
merge_batch currently inspects only states[...].value(0). When an upstream partial aggregate ships more than one state row (common once batches spill), every row beyond index 0 is silently ignored, so the final SUM can drop contributions. Iterate over every row like DataFusion’s built-ins do.
Apply this diff (sketch):
```diff
-        let that_sum_array = states[0].as_primitive::<Int64Type>();
-        let that_sum = if that_sum_array.is_null(0) {
-            None
-        } else {
-            Some(that_sum_array.value(0))
-        };
+        let that_sum_array = states[0].as_primitive::<Int64Type>();
+        let rows = that_sum_array.len();
+        for row in 0..rows {
+            let that_sum = if that_sum_array.is_null(row) {
+                None
+            } else {
+                Some(that_sum_array.value(row))
+            };
+            // existing overflow and accumulation logic goes here, using `row`
+            // when reading boolean state in TRY mode as well.
+        }
-        // existing overflow/accumulation logic...
+        // existing overflow/accumulation logic must move inside the loop and
+        // use `row` for the boolean state (`states[1]`) when TRY mode is active.
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In native/spark-expr/src/agg_funcs/sum_int.rs around lines 240 to 295,
merge_batch only reads the first row of the incoming partial-aggregate state
(value(0)) and thus drops contributions when multiple state rows are present;
replace the single-index access with a loop over all rows (use the arrays'
length) and apply the same per-row logic: for each i read that_sum from
states[0].as_primitive::<Int64Type>().value(i)/is_null(i) and that_has_all_nulls
from states[1].as_boolean().value(i), then perform the existing
overflow/has_all_nulls checks and merging logic per row (respecting EvalMode
behavior: Try/Ansi/Legacy), updating self.sum and self.has_all_nulls
incrementally and returning early only when the original logic would for that
row; ensure unwraps are safe by only unwrapping when the per-row null checks
passed.
value:useful; category:bug; feedback:The CodeRabbit AI reviewer is correct that all items/rows of the state should be used for the aggregation/accumulation. Otherwise for spilled batches it will miss all items but the first one.
@martin-augment, thank you for confirming! I'm glad the analysis was helpful in catching this issue before it could impact production workloads with spilled aggregations.
If you'd like any assistance while implementing the fix—such as verifying the loop logic handles all three EvalMode variants correctly, or ensuring the boolean state array is indexed properly in Try mode—feel free to ask.
🧠 Learnings used
Learnt from: martin-augment
Repo: martin-augment/datafusion-comet PR: 7
File: native/spark-expr/src/math_funcs/abs.rs:201-302
Timestamp: 2025-11-04T14:26:48.750Z
Learning: In the abs function in native/spark-expr/src/math_funcs/abs.rs (Rust), NULL values for signed integers (Int8, Int16, Int32, Int64) and decimals (Decimal128, Decimal256) should return the argument as-is (e.g., ColumnarValue::Scalar(ScalarValue::Int8(None))) rather than panicking on unwrap().
Findings
value:incorrect-but-reasonable; category:bug; feedback:The Codex AI reviewer is not correct! The protocol buffer definition is used as a transport protocol between Spark and DataFusion at runtime, so it is not possible to use one version of the protocol on one side and a different version on the other.
value:useful; category:bug; feedback:The Codex AI reviewer is correct that the size_of_val function is not imported, which will cause a build failure.
value:good-to-have; category:bug; feedback:The Codex AI reviewer is correct that this file should not be added to Git. It should be added to .gitignore so that it is not re-added later. It does not cause any harm, though.
2600: To review by AI
Note
Add integer SUM UDF with ANSI/TRY overflow semantics and plumb eval_mode through proto/serde to planner, with comprehensive tests.
- SumInteger UDF with overflow semantics for Int8/16/32/64, supporting Legacy/Ansi/Try modes, including grouped accumulators and state handling.
- Planner maps SUM to SumInteger using eval_mode from protobuf; decimals still use SumDecimal; other types fall back to sum_udaf.
- Proto: the Sum field fail_on_error becomes eval_mode: EvalMode and is plumbed through the Spark serde (CometSum) via evalModeToProto (CometEvalModeUtil...).
- New sum_int module (native/spark-expr).
- SUM and try_sum tests (nulls, overflow/underflow, GROUP BY) in CometExpressionSuite.

Written by Cursor Bugbot for commit 970d693. This will update automatically on new commits. Configure here.