Skip to content

Commit 0b9d749

Browse files
committed
feat: add avg distinct support for decimal
1 parent 3bb414d commit 0b9d749

File tree

3 files changed

+256
-0
lines changed

3 files changed

+256
-0
lines changed

datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
mod decimal;
1819
mod numeric;
1920

21+
pub use decimal::DecimalDistinctAvgAccumulator;
2022
pub use numeric::Float64DistinctAvgAccumulator;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::{
19+
array::{ArrayRef, ArrowNumericType},
20+
datatypes::{i256, Decimal128Type, Decimal256Type, DecimalType},
21+
};
22+
use datafusion_common::{Result, ScalarValue};
23+
use datafusion_expr_common::accumulator::Accumulator;
24+
use std::fmt::Debug;
25+
use std::mem::size_of_val;
26+
27+
use crate::aggregate::sum_distinct::DistinctSumAccumulator;
28+
use crate::utils::DecimalAverager;
29+
30+
/// Generic implementation of `AVG DISTINCT` for Decimal types.
31+
/// Handles both Decimal128Type and Decimal256Type.
32+
#[derive(Debug)]
33+
pub struct DecimalDistinctAvgAccumulator<T: DecimalType + Debug> {
34+
sum_accumulator: DistinctSumAccumulator<T>,
35+
sum_scale: i8,
36+
target_precision: u8,
37+
target_scale: i8,
38+
}
39+
40+
impl<T: DecimalType + Debug> DecimalDistinctAvgAccumulator<T> {
41+
pub fn with_decimal_params(
42+
sum_scale: i8,
43+
target_precision: u8,
44+
target_scale: i8,
45+
) -> Self {
46+
let data_type = T::TYPE_CONSTRUCTOR(T::MAX_PRECISION, sum_scale);
47+
48+
Self {
49+
sum_accumulator: DistinctSumAccumulator::try_new(&data_type).unwrap(),
50+
sum_scale,
51+
target_precision,
52+
target_scale,
53+
}
54+
}
55+
}
56+
57+
impl<T: DecimalType + ArrowNumericType + Debug> Accumulator
58+
for DecimalDistinctAvgAccumulator<T>
59+
{
60+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
61+
self.sum_accumulator.state()
62+
}
63+
64+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
65+
self.sum_accumulator.update_batch(values)
66+
}
67+
68+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
69+
self.sum_accumulator.merge_batch(states)
70+
}
71+
72+
fn evaluate(&mut self) -> Result<ScalarValue> {
73+
if self.sum_accumulator.distinct_count() == 0 {
74+
return ScalarValue::new_primitive::<T>(
75+
None,
76+
&T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
77+
);
78+
}
79+
80+
let sum_scalar = self.sum_accumulator.evaluate()?;
81+
82+
match sum_scalar {
83+
ScalarValue::Decimal128(Some(sum), _, _) => {
84+
let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
85+
self.sum_scale,
86+
self.target_precision,
87+
self.target_scale,
88+
)?;
89+
let avg = decimal_averager
90+
.avg(sum, self.sum_accumulator.distinct_count() as i128)?;
91+
Ok(ScalarValue::Decimal128(
92+
Some(avg),
93+
self.target_precision,
94+
self.target_scale,
95+
))
96+
}
97+
ScalarValue::Decimal256(Some(sum), _, _) => {
98+
let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
99+
self.sum_scale,
100+
self.target_precision,
101+
self.target_scale,
102+
)?;
103+
// `distinct_count` returns `u64`, but `avg` expects `i256`
104+
// first convert `u64` to `i128`, then convert `i128` to `i256` to avoid overflow
105+
let distinct_cnt: i128 = self.sum_accumulator.distinct_count() as i128;
106+
let count: i256 = i256::from_i128(distinct_cnt);
107+
let avg = decimal_averager.avg(sum, count)?;
108+
Ok(ScalarValue::Decimal256(
109+
Some(avg),
110+
self.target_precision,
111+
self.target_scale,
112+
))
113+
}
114+
115+
_ => unreachable!("Unsupported decimal type: {:?}", sum_scalar),
116+
}
117+
}
118+
119+
fn size(&self) -> usize {
120+
let fixed_size = size_of_val(self);
121+
122+
// Account for the size of the sum_accumulator with its contained values
123+
fixed_size + self.sum_accumulator.size()
124+
}
125+
}
126+
127+
#[cfg(test)]
128+
mod tests {
129+
use super::*;
130+
use arrow::array::{Decimal128Array, Decimal256Array};
131+
use std::sync::Arc;
132+
133+
#[test]
fn test_decimal128_distinct_avg_accumulator() -> Result<()> {
    // `single_distinct_to_groupby` rewrites the aggregate into a `GroupBy`
    // plan, so the accumulator is exercised directly through the Rust API.
    // See also `aggregate.slt`.
    let input_precision = 10_u8;
    let input_scale = 4_i8;
    let output_precision = 14_u8;
    let output_scale = 8_i8;

    // Five distinct values, one duplicate and two nulls, at scale 4.
    let input: ArrayRef = Arc::new(
        Decimal128Array::from(vec![
            Some(100_0000), // 100.0000
            Some(125_0000), // 125.0000
            Some(175_0000), // 175.0000
            Some(200_0000), // 200.0000
            Some(200_0000), // 200.0000 (duplicate)
            Some(300_0000), // 300.0000
            None,
            None,
        ])
        .with_precision_and_scale(input_precision, input_scale)?,
    );

    let mut acc = DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
        input_scale,
        output_precision,
        output_scale,
    );
    acc.update_batch(&[input])?;
    let actual = acc.evaluate()?;

    // (100 + 125 + 175 + 200 + 300) / 5 = 180, expressed at scale 8.
    let expected = ScalarValue::Decimal128(
        Some(180_00000000), // 180.00000000
        output_precision,
        output_scale,
    );
    assert_eq!(actual, expected);

    Ok(())
}
182+
183+
#[test]
fn test_decimal256_distinct_avg_accumulator() -> Result<()> {
    let input_precision = 50_u8;
    let input_scale = 2_i8;
    let output_precision = 54_u8;
    let output_scale = 6_i8;

    // Five distinct values, one duplicate and two nulls, at scale 2.
    let input: ArrayRef = Arc::new(
        Decimal256Array::from(vec![
            Some(i256::from_i128(100_00)), // 100.00
            Some(i256::from_i128(125_00)), // 125.00
            Some(i256::from_i128(175_00)), // 175.00
            Some(i256::from_i128(200_00)), // 200.00
            Some(i256::from_i128(200_00)), // 200.00 (duplicate)
            Some(i256::from_i128(300_00)), // 300.00
            None,
            None,
        ])
        .with_precision_and_scale(input_precision, input_scale)?,
    );

    let mut acc = DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
        input_scale,
        output_precision,
        output_scale,
    );
    acc.update_batch(&[input])?;
    let actual = acc.evaluate()?;

    // (100 + 125 + 175 + 200 + 300) / 5 = 180, expressed at scale 6.
    let expected = ScalarValue::Decimal256(
        Some(i256::from_i128(180_000000)), // 180.000000
        output_precision,
        output_scale,
    );
    assert_eq!(actual, expected);

    Ok(())
}
230+
}

datafusion/functions-aggregate/src/average.rs

+24
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,30 @@ impl AggregateUDFImpl for Avg {
124124
match &data_type {
125125
// Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation
126126
Float64 => Ok(Box::new(Float64DistinctAvgAccumulator::new()?)),
127+
Decimal128(_, scale) => {
128+
let target_type = &acc_args.return_type;
129+
if let Decimal128(target_precision, target_scale) = target_type {
130+
Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
131+
*scale,
132+
*target_precision,
133+
*target_scale,
134+
)))
135+
} else {
136+
exec_err!("AVG(DISTINCT) for Decimal128 expected Decimal128 return type, got {}", target_type)
137+
}
138+
}
139+
Decimal256(_, scale) => {
140+
let target_type = &acc_args.return_type;
141+
if let Decimal256(target_precision, target_scale) = target_type {
142+
Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
143+
*scale,
144+
*target_precision,
145+
*target_scale,
146+
)))
147+
} else {
148+
exec_err!("AVG(DISTINCT) for Decimal256 expected Decimal256 return type, got {}", target_type)
149+
}
150+
}
127151
_ => exec_err!("AVG(DISTINCT) for {} not supported", data_type),
128152
}
129153
} else {

0 commit comments

Comments
 (0)