Skip to content

Commit 13b0e54

Browse files
robert3005a10y
authored andcommitted
fix: Restore fast constant sum behaviour, don't fuzz floating point sum compute function (#5527)
Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent 81a7e19 commit 13b0e54

File tree

2 files changed

+10
-8
lines changed
  • fuzz/src/array
  • vortex-array/src/arrays/constant/compute

2 files changed

+10
-8
lines changed

fuzz/src/array/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,11 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
287287
(Action::Cast(to), ExpectedValue::Array(result))
288288
}
289289
ActionType::Sum => {
290+
// Do not try to fuzz float operations, they have unpredictable error behavior
291+
if current_array.dtype().is_float() {
292+
return Err(EmptyChoose);
293+
}
294+
290295
// Sum - returns a scalar, does NOT update current_array (terminal operation)
291296
let sum_result =
292297
sum_canonical_array(current_array.to_canonical()).vortex_unwrap();

vortex-array/src/arrays/constant/compute/sum.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use num_traits::AsPrimitive;
45
use num_traits::CheckedAdd;
56
use num_traits::CheckedMul;
67
use vortex_dtype::DType;
@@ -159,13 +160,9 @@ fn sum_float(
159160
let v = primitive_scalar
160161
.as_::<f64>()
161162
.vortex_expect("cannot be null");
163+
let len_f64: f64 = array_len.as_();
162164

163-
// Preserve numerical behaviour of summation of floats by using a loop instead of simplifying to multiplication.
164-
let mut sum = initial;
165-
for _ in 0..array_len {
166-
sum += v;
167-
}
168-
Ok(Some(sum))
165+
Ok(Some(initial + v * len_f64))
169166
}
170167

171168
register_kernel!(SumKernelAdapter(ConstantVTable).lift());
@@ -298,8 +295,8 @@ mod tests {
298295
let sum =
299296
sum_with_accumulator(array.as_ref(), &Scalar::primitive(acc, Nullable)).vortex_unwrap();
300297
assert_eq!(
301-
sum,
302-
Scalar::primitive(-2048669274505641600000000000f64, Nullable)
298+
f64::try_from(sum).vortex_unwrap(),
299+
-2048669274505644600000000000f64
303300
);
304301
}
305302
}

0 commit comments

Comments
 (0)