Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change result type of count/count_distinct from uint64 to int64 #2636

Merged
merged 1 commit into from
May 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions datafusion/core/src/optimizer/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ mod tests {
.build()?;

// Should work
let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):UInt64;N]\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\
let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):Int64;N]\
\n Aggregate: groupBy=[[#test.b AS alias1]], aggr=[[]] [alias1:UInt32]\
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -255,8 +255,8 @@ mod tests {
.aggregate(Vec::<Expr>::new(), vec![count_distinct(lit(2) * col("b"))])?
.build()?;

let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):UInt64;N]\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\
let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):Int64;N]\
\n Aggregate: groupBy=[[Int32(2) * #test.b AS alias1]], aggr=[[]] [alias1:Int32]\
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -273,8 +273,8 @@ mod tests {
.build()?;

// Should work
let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N]\
\n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N]\
let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\
\n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -294,7 +294,7 @@ mod tests {
.build()?;

// Do nothing
let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(DISTINCT #test.c)]] [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, COUNT(DISTINCT test.c):UInt64;N]\
let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(DISTINCT #test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(DISTINCT test.c):Int64;N]\
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -319,8 +319,8 @@ mod tests {
)?
.build()?;
// Should work
let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N, MAX(alias1):UInt32;N]\
let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\
\n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -340,7 +340,7 @@ mod tests {
.build()?;

// Do nothing
let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(#test.c)]] [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, COUNT(test.c):UInt64;N]\
let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(#test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_eq(&plan, expected);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_plan/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ mod tests {

// c3 is small int

let count: &UInt64Array = as_primitive_array(&columns[0]);
let count: &Int64Array = as_primitive_array(&columns[0]);
assert_eq!(count.value(0), 100);
assert_eq!(count.value(99), 100);

Expand Down
22 changes: 11 additions & 11 deletions datafusion/core/tests/path_partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ async fn parquet_distinct_partition_col() -> Result<()> {
.await?;

let mut max_limit = match ScalarValue::try_from_array(results[0].column(0), 0)? {
ScalarValue::UInt64(Some(count)) => count,
s => panic!("Expected count as Int64 found {}", s),
ScalarValue::Int64(Some(count)) => count,
s => panic!("Expected count as Int64 found {}", s.get_datatype()),
};

max_limit += 1;
Expand All @@ -117,54 +117,54 @@ async fn parquet_distinct_partition_col() -> Result<()> {
let last_row_idx = last_batch.num_rows() - 1;
let mut min_limit =
match ScalarValue::try_from_array(last_batch.column(0), last_row_idx)? {
ScalarValue::UInt64(Some(count)) => count,
s => panic!("Expected count as Int64 found {}", s),
ScalarValue::Int64(Some(count)) => count,
s => panic!("Expected count as Int64 found {}", s.get_datatype()),
};

min_limit -= 1;

let sql_cross_partition_boundary = format!("SELECT month FROM t limit {}", max_limit);
let resulting_limit: u64 = ctx
let resulting_limit: i64 = ctx
.sql(sql_cross_partition_boundary.as_str())
.await?
.collect()
.await?
.into_iter()
.map(|r| r.num_rows() as u64)
.map(|r| r.num_rows() as i64)
.sum();

assert_eq!(max_limit, resulting_limit);

let sql_within_partition_boundary =
format!("SELECT month from t limit {}", min_limit);
let resulting_limit: u64 = ctx
let resulting_limit: i64 = ctx
.sql(sql_within_partition_boundary.as_str())
.await?
.collect()
.await?
.into_iter()
.map(|r| r.num_rows() as u64)
.map(|r| r.num_rows() as i64)
.sum();

assert_eq!(min_limit, resulting_limit);

let month = match ScalarValue::try_from_array(results[0].column(1), 0)? {
ScalarValue::Utf8(Some(month)) => month,
s => panic!("Expected count as Int64 found {}", s),
s => panic!("Expected count as Int64 found {}", s.get_datatype()),
};

let sql_on_partition_boundary = format!(
"SELECT month from t where month = '{}' LIMIT {}",
month,
max_limit - 1
);
let resulting_limit: u64 = ctx
let resulting_limit: i64 = ctx
.sql(sql_on_partition_boundary.as_str())
.await?
.collect()
.await?
.into_iter()
.map(|r| r.num_rows() as u64)
.map(|r| r.num_rows() as i64)
.sum();
let partition_row_count = max_limit - 1;
assert_eq!(partition_row_count, resulting_limit);
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/provider_filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{as_primitive_array, Int32Builder, UInt64Array};
use arrow::array::{as_primitive_array, Int32Builder, Int64Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
Expand Down Expand Up @@ -170,7 +170,7 @@ impl TableProvider for CustomProvider {
}
}

async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<()> {
async fn assert_provider_row_count(value: i64, expected_count: i64) -> Result<()> {
let provider = CustomProvider {
zero_batch: create_batch(0, 10)?,
one_batch: create_batch(1, 5)?,
Expand All @@ -183,7 +183,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<()
.aggregate(vec![], vec![count(col("flag"))])?;

let results = df.collect().await?;
let result_col: &UInt64Array = as_primitive_array(results[0].column(0));
let result_col: &Int64Array = as_primitive_array(results[0].column(0));
assert_eq!(result_col.value(0), expected_count);

ctx.register_table("data", Arc::new(provider))?;
Expand All @@ -193,7 +193,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<()
.collect()
.await?;

let sql_result_col: &UInt64Array = as_primitive_array(sql_results[0].column(0));
let sql_result_col: &Int64Array = as_primitive_array(sql_results[0].column(0));
assert_eq!(sql_result_col.value(0), expected_count);

Ok(())
Expand Down
3 changes: 1 addition & 2 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,8 @@ pub fn return_type(
let coerced_data_types = coerce_types(fun, input_expr_types, &signature(fun))?;

match fun {
// TODO If the datafusion is compatible with PostgreSQL, the returned data type should be INT64.
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Ok(DataType::UInt64)
Ok(DataType::Int64)
}
AggregateFunction::Max | AggregateFunction::Min => {
// For min and max agg function, the returned type is same as input type.
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,10 @@ mod tests {
fn test_count_return_type() -> Result<()> {
let fun = WindowFunction::from_str("count")?;
let observed = return_type(&fun, &[DataType::Utf8])?;
assert_eq!(DataType::UInt64, observed);
assert_eq!(DataType::Int64, observed);

let observed = return_type(&fun, &[DataType::UInt64])?;
assert_eq!(DataType::UInt64, observed);
assert_eq!(DataType::Int64, observed);

Ok(())
}
Expand Down
10 changes: 5 additions & 5 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ mod tests {
assert!(result_agg_phy_exprs.as_any().is::<Count>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", DataType::UInt64, true),
Field::new("c1", DataType::Int64, true),
result_agg_phy_exprs.field().unwrap()
);
}
Expand Down Expand Up @@ -347,7 +347,7 @@ mod tests {
assert!(result_distinct.as_any().is::<DistinctCount>());
assert_eq!("c1", result_distinct.name());
assert_eq!(
Field::new("c1", DataType::UInt64, true),
Field::new("c1", DataType::Int64, true),
result_distinct.field().unwrap()
);
}
Expand Down Expand Up @@ -954,14 +954,14 @@ mod tests {
#[test]
fn test_count_return_type() -> Result<()> {
let observed = return_type(&AggregateFunction::Count, &[DataType::Utf8])?;
assert_eq!(DataType::UInt64, observed);
assert_eq!(DataType::Int64, observed);

let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?;
assert_eq!(DataType::UInt64, observed);
assert_eq!(DataType::Int64, observed);

let observed =
return_type(&AggregateFunction::Count, &[DataType::Decimal(28, 13)])?;
assert_eq!(DataType::UInt64, observed);
assert_eq!(DataType::Int64, observed);
Ok(())
}

Expand Down
46 changes: 22 additions & 24 deletions datafusion/physical-expr/src/aggregate/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@ use std::sync::Arc;

use crate::aggregate::row_accumulator::RowAccumulator;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::array::Int64Array;
use arrow::compute;
use arrow::datatypes::DataType;
use arrow::{
array::{ArrayRef, UInt64Array},
datatypes::Field,
};
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::Accumulator;
Expand Down Expand Up @@ -110,7 +108,7 @@ impl AggregateExpr for Count {

#[derive(Debug)]
struct CountAccumulator {
count: u64,
count: i64,
}

impl CountAccumulator {
Expand All @@ -123,12 +121,12 @@ impl CountAccumulator {
impl Accumulator for CountAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array = &values[0];
self.count += (array.len() - array.data().null_count()) as u64;
self.count += (array.len() - array.data().null_count()) as i64;
Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap();
let counts = states[0].as_any().downcast_ref::<Int64Array>().unwrap();
let delta = &compute::sum(counts);
if let Some(d) = delta {
self.count += *d;
Expand All @@ -137,11 +135,11 @@ impl Accumulator for CountAccumulator {
}

fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::UInt64(Some(self.count))])
Ok(vec![ScalarValue::Int64(Some(self.count))])
}

fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::UInt64(Some(self.count)))
Ok(ScalarValue::Int64(Some(self.count)))
}
}

Expand Down Expand Up @@ -173,16 +171,16 @@ impl RowAccumulator for CountRowAccumulator {
states: &[ArrayRef],
accessor: &mut RowAccessor,
) -> Result<()> {
let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap();
let counts = states[0].as_any().downcast_ref::<Int64Array>().unwrap();
let delta = &compute::sum(counts);
if let Some(d) = delta {
accessor.add_u64(self.state_index, *d);
accessor.add_i64(self.state_index, *d);
}
Ok(())
}

fn evaluate(&self, accessor: &RowAccessor) -> Result<ScalarValue> {
Ok(accessor.get_as_scalar(&DataType::UInt64, self.state_index))
Ok(accessor.get_as_scalar(&DataType::Int64, self.state_index))
}

#[inline(always)]
Expand All @@ -208,8 +206,8 @@ mod tests {
a,
DataType::Int32,
Count,
ScalarValue::from(5u64),
DataType::UInt64
ScalarValue::from(5i64),
DataType::Int64
)
}

Expand All @@ -227,8 +225,8 @@ mod tests {
a,
DataType::Int32,
Count,
ScalarValue::from(3u64),
DataType::UInt64
ScalarValue::from(3i64),
DataType::Int64
)
}

Expand All @@ -241,8 +239,8 @@ mod tests {
a,
DataType::Boolean,
Count,
ScalarValue::from(0u64),
DataType::UInt64
ScalarValue::from(0i64),
DataType::Int64
)
}

Expand All @@ -254,8 +252,8 @@ mod tests {
a,
DataType::Boolean,
Count,
ScalarValue::from(0u64),
DataType::UInt64
ScalarValue::from(0i64),
DataType::Int64
)
}

Expand All @@ -267,8 +265,8 @@ mod tests {
a,
DataType::Utf8,
Count,
ScalarValue::from(5u64),
DataType::UInt64
ScalarValue::from(5i64),
DataType::Int64
)
}

Expand All @@ -280,8 +278,8 @@ mod tests {
a,
DataType::LargeUtf8,
Count,
ScalarValue::from(5u64),
DataType::UInt64
ScalarValue::from(5i64),
DataType::Int64
)
}
}
Loading